Skip to content
Snippets Groups Projects
Commit de2c38ea authored by Tulir Asokan's avatar Tulir Asokan :cat2:
Browse files

Add support for more dice, larger dice and function calls

parent 0e30d1b0
No related branches found
No related tags found
No related merge requests found
......@@ -13,20 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Match
from typing import Match, Union, Any
import operator
import random
import math
import ast
import re
from maubot import Plugin, MessageEvent
from maubot.handlers import command
ARG_PATTERN = "$pattern"
COMMAND_ROLL = f"roll {ARG_PATTERN}"
COMMAND_ROLL_DEFAULT = "roll"
pattern_regex = re.compile("([0-9]{0,2})d([0-9]{1,2})")
pattern_regex = re.compile("([0-9]{0,9})d([0-9]{1,9})")
_OP_MAP = {
ast.Add: operator.add,
......@@ -46,6 +43,9 @@ _OP_MAP = {
ast.LShift: operator.lshift,
}
_NUM_MAX = 1_000_000_000_000_000
_NUM_MIN = -_NUM_MAX
_OP_LIMITS = {
ast.Pow: (1000, 1000),
ast.LShift: (1000, 1000),
......@@ -55,10 +55,27 @@ _OP_LIMITS = {
ast.Mod: (1_000_000_000_000_000, 1_000_000_000_000_000),
}
_ALLOWED_FUNCS = ["ceil", "copysign", "fabs", "factorial", "gcd", "remainder", "trunc",
"exp", "log", "log1p", "log2", "log10", "sqrt",
"acos", "asin", "atan", "atan2", "cos", "hypot", "sin", "tan",
"degrees", "radians",
"acosh", "asinh", "atanh", "cosh", "sinh", "tanh",
"erf", "erfc", "gamma", "lgamma"]
_FUNC_MAP = {func: getattr(math, func) for func in _ALLOWED_FUNCS if hasattr(math, func)}
_FUNC_LIMITS = {
"factorial": 1000,
"exp": 709,
"sqrt": 1_000_000_000_000_000,
}
_ARG_COUNT_LIMIT = 5
# AST-based calculator from https://stackoverflow.com/a/33030616/2120293
class Calc(ast.NodeVisitor):
def visit_BinOp(self, node):
def visit_BinOp(self, node: ast.BinOp) -> Any:
left = self.visit(node.left)
right = self.visit(node.right)
op_type = type(node.op)
......@@ -74,7 +91,7 @@ class Calc(ast.NodeVisitor):
raise SyntaxError(f"Operator {op_type.__name__} not allowed")
return op(left, right)
def visit_UnaryOp(self, node):
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
operand = self.visit(node.operand)
try:
op = _OP_MAP[type(node.op)]
......@@ -82,14 +99,47 @@ class Calc(ast.NodeVisitor):
raise SyntaxError(f"Operator {type(node.op).__name__} not allowed")
return op(operand)
def visit_Num(self, node):
def visit_Num(self, node: ast.Num) -> Any:
if node.n > _NUM_MAX or node.n < _NUM_MIN:
raise ValueError(f"Number out of bounds")
return node.n
def visit_Expr(self, node):
def visit_Name(self, node: ast.Name) -> Any:
if node.id == "pi":
return math.pi
elif node.id == "tau":
return math.tau
elif node.id == "e":
return math.e
def visit_Call(self, node: ast.Call) -> Any:
if isinstance(node.func, ast.Name):
try:
func = _FUNC_MAP[node.func.id]
except KeyError:
raise NameError(f"Function {node.func.id} is not defined")
args = [self.visit(arg) for arg in node.args]
kwargs = {kwarg.arg: self.visit(kwarg.value) for kwarg in node.keywords}
if len(args) + len(kwargs) > _ARG_COUNT_LIMIT:
raise ValueError("Too many arguments")
try:
limit = _FUNC_LIMITS[node.func.id]
for value in args:
if value > limit:
raise ValueError(f"Value over bounds for function {node.func.id}")
for value in kwargs.values():
if value > limit:
raise ValueError(f"Value over bounds for function {node.func.id}")
except KeyError:
pass
return func(*args, **kwargs)
raise SyntaxError("Indirect call")
def visit_Expr(self, node: ast.Expr) -> Any:
return self.visit(node.value)
@classmethod
def evaluate(cls, expression):
def evaluate(cls, expression: str) -> Union[int, float]:
tree = ast.parse(expression)
return cls().visit(tree.body[0])
......@@ -104,8 +154,14 @@ class DiceBot(Plugin):
elif size == 1:
return number
result = 0
for i in range(number):
result += random.randint(1, size)
if number < 100:
for i in range(number):
result += random.randint(1, size)
else:
mean = number * (size + 1) / 2
variance = number * (size ** 2 - 1) / 12
while result < number or result > number * size:
result = int(random.gauss(mean, math.sqrt(variance)))
return result
@classmethod
......
maubot: 0.1.0
id: xyz.maubot.dice
version: 1.0.0
license: AGPL-3.0-or-later
modules:
- dice
main_class: DiceBot
database: true
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment