1"""A Module to safely parse/evaluate Mathematical Expressions"""
2import ast
3import operator as op
4import math
5
6from numpy import int64
7
# Upper bound on the magnitude of numbers allowed while evaluating an
# expression. Keeping values bounded prevents denial-of-service (DoS) attacks
# through expressions whose results are astronomically large.
max_value = 1e17
10
11
# Arithmetic wrappers that raise ValueError instead of working with or producing
# overly large numbers, to prevent denial-of-service (DoS) attacks.
def add(a, b):
    """Return ``a + b`` after checking the size of both operands.

    Only the operands are checked here: each must have a magnitude of at
    most ``max_value``.

    Raises:
        ValueError: with ``(a, b)`` as its argument if an operand is too large.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.add(a, b)
18
19
def sub(a, b):
    """Return ``a - b`` after checking the size of both operands.

    Raises:
        ValueError: with ``(a, b)`` as its argument if the magnitude of either
            operand exceeds ``max_value``.
    """
    operand_too_large = abs(a) > max_value or abs(b) > max_value
    if operand_too_large:
        raise ValueError((a, b))
    return op.sub(a, b)
25
26
def mul(a, b):
    """Return ``a * b`` unless the product's magnitude would exceed ``max_value``.

    The size of the product is estimated in log space (sum of base-10
    logarithms), so large operands are rejected without being multiplied.
    If either factor is zero the check is skipped, since ``log10(0)`` is
    undefined.

    Raises:
        ValueError: with ``(a, b)`` as its argument if the product is too large.
    """
    has_zero_factor = a == 0.0 or b == 0.0
    if not has_zero_factor:
        product_exponent = math.log10(abs(a)) + math.log10(abs(b))
        if product_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.mul(a, b)
34
35
def div(a, b):
    """Return the true quotient ``a / b`` with zero-divisor and size guards.

    The quotient's magnitude is estimated in log space (difference of
    base-10 logarithms); a zero numerator skips that estimate.

    Raises:
        ValueError: with ``(a, b)`` as its argument if ``b`` is zero or the
            quotient would exceed ``max_value``.
    """
    if b == 0.0:
        raise ValueError((a, b))
    if a != 0.0:
        quotient_exponent = math.log10(abs(a)) - math.log10(abs(b))
        if quotient_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)
45
46
def power(a, b):
    """Return ``a ** b`` unless the result's magnitude would reach ``max_value``.

    Special cases handled before the logarithmic size estimate:

    * ``a == 0``: follows Python semantics (``0 ** 0 == 1``,
      ``0 ** positive == 0``). A negative exponent would divide by zero and
      raises ``ValueError``, like ``div`` does for a zero divisor.
    * ``abs(a) == 1``: the result has magnitude 1 for any real ``b``, and
      ``math.log`` cannot use 1 as a base, so the size check is skipped.

    Raises:
        ValueError: with ``(a, b)`` as its argument if the result would be too
            large, or if zero is raised to a negative power.
    """
    if a == 0.0:
        if b < 0:
            raise ValueError((a, b))
        return op.pow(a, b)
    if b < 0 and isinstance(a, (int, int64)):
        # The limit() decorator converts integer results to numpy int64, and
        # numpy refuses integer bases with negative integer exponents. Python's
        # int ** negative-int yields a float, so promote the base to match.
        a = float(a)
    if abs(a) == 1:
        return op.pow(a, b)
    # Reject when b >= log_{|a|}(max_value), i.e. |a| ** b >= max_value.
    # For |a| < 1 the logarithm is negative, which correctly catches
    # large negative exponents.
    if b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)
54
55
def exp(a):
    """Return ``e ** a`` if the result stays within ``max_value``.

    Raises:
        ValueError: with ``a`` as its argument if ``a`` exceeds
            ``ln(max_value)``, i.e. the result would be larger than ``max_value``.
    """
    ceiling = math.log(max_value)
    if a > ceiling:
        raise ValueError(a)
    return math.exp(a)
61
62
# Mapping from AST operator node types to the functions implementing them.
# Operators missing from this mapping cannot be evaluated.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    ast.UAdd: op.pos,  # unary plus, e.g. "+3"
    ast.Mod: op.mod,
    # Binary ``//``: the plain ``floordiv``, not the in-place ``ifloordiv``.
    ast.FloorDiv: op.floordiv,
}
74
# Whitelist of function names that expressions may call. This is a curated
# subset of the ``math`` module plus the builtin ``round`` and the guarded
# ``exp`` wrapper; lookups of names not listed here fail.
allowed_math_fxn = {
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "hypot": math.hypot,
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,
    "radians": math.radians,
    "degrees": math.degrees,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "log2": math.log2,
    "fmod": math.fmod,
    "abs": math.fabs,  # math.fabs: always returns a float
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,  # Python builtin, not from ``math``
    "exp": exp,  # the size-checked wrapper defined in this module, not math.exp
}
104
105
def get_function(node):
    """Return the name of the function called by an ``ast.Call`` node.

    A bare call such as ``sin(x)`` yields ``"sin"``. For an attribute call
    such as ``math.sin(x)``, only the final attribute name (``"sin"``) is
    returned; the object it is accessed on is ignored.

    Raises:
        TypeError: if the callee is neither a bare name nor an attribute.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        return callee.attr
    if isinstance(callee, ast.Name):
        return callee.id
    raise TypeError("node.func is of the wrong type")
116
117
def limit(max_=None):
    """Return a decorator that bounds the magnitude of a function's result.

    The wrapped function is called normally.

    * If its result does not support ``abs()``, it is returned unchanged.
    * Otherwise, if the magnitude exceeds ``max_``, ``ValueError(result)``
      is raised.
    * Python ``int`` results are then converted to ``numpy.int64``. That
      conversion raises ``OverflowError`` for integers outside the int64
      range, which can only happen when ``max_`` is None or very large.

    Args:
        max_: Largest allowed magnitude; ``None`` (the default) disables the
            magnitude check.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            try:
                mag = abs(ret)
            except TypeError:
                return ret  # no magnitude (e.g. a string): nothing to bound
            # Without the None guard, ``mag > None`` raises TypeError.
            if max_ is not None and mag > max_:
                raise ValueError(ret)
            if isinstance(ret, int):
                ret = int64(ret)
            return ret

        return wrapper

    return decorator
140
141
@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate a node of a parsed mathematical expression.

    Supported nodes:

    * numeric literals (int, float, complex);
    * binary and unary operators present in ``operators``;
    * calls to functions present in ``allowed_math_fxn``, with positional
      and keyword arguments;
    * the names ``pi``, ``e`` and ``tau`` (case-insensitive).

    The ``limit`` decorator rejects any result whose magnitude exceeds
    ``max_value``.

    Raises:
        KeyError: for an operator type or function name that is not allowed.
        TypeError: for an unknown name or any other kind of node.
    """
    # Numeric literal. ``ast.Constant`` replaces ``ast.Num``, which has been
    # deprecated since Python 3.8 and removed in 3.14. Checking the exact type
    # excludes bool literals, exactly as ``ast.Num`` did.
    if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
        return node.value
    if isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp):  # <operator> <operand>, e.g. -1
        return operators[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Call):  # e.g. sin(x) or math.sin(x)
        func = get_function(node)
        evaled_args = [_eval(arg) for arg in node.args]
        # Forward keyword arguments too (e.g. round(2.567, ndigits=1)) rather
        # than silently dropping them.
        evaled_kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:  # ``**mapping`` unpacking is not supported
                raise TypeError(keyword)
            evaled_kwargs[keyword.arg] = _eval(keyword.value)
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    if isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        if name == "e":
            return math.e
        if name == "tau":
            return math.pi * 2.0
        raise TypeError("Found a str in the expression, either param_dct/the expression has a mistake in the parameter names or attempting to parse non-mathematical code")
    raise TypeError(node)
168
169
def eval_expression(expression, param_dct=None):
    """Safely evaluate a mathematical expression given as a string.

    Each occurrence of a ``param_dct`` key is replaced by its value, wrapped
    in parentheses. A key is only replaced where it stands on its own, not
    inside a longer identifier or number. The result is then parsed with
    ``ast`` and evaluated.

    Args:
        expression: The expression, e.g. ``"2*x + sin(pi/2)"``.
        param_dct: Optional mapping of parameter names to values.

    Returns:
        The numeric value of the expression.

    Raises:
        TypeError: if ``expression`` is not a string, or it contains an
            unsupported construct or unknown name.
        ValueError: if the expression is longer than 10,000 characters,
            contains ``"()"``, or produces a value that is too large.
        SyntaxError: if the substituted expression cannot be parsed.
    """
    import re

    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    if param_dct:
        # Substitute all parameters in a single regex pass and only at word
        # boundaries. Plain sequential str.replace would corrupt other names
        # (key "x" turns "exp" into "e2p") and would re-substitute inside
        # values already inserted (key "e" inside the value "1e-05").
        # Parentheses keep the value atomic, so x=-3 gives (-3)**2, not -3**2.
        keys = sorted((key for key in param_dct if key), key=len, reverse=True)
        if keys:
            pattern = re.compile(
                r"(?<!\w)(?:" + "|".join(map(re.escape, keys)) + r")(?!\w)"
            )
            expression_rep = pattern.sub(
                lambda match: "(" + str(param_dct[match.group(0)]) + ")",
                expression_rep,
            )

    return _eval(ast.parse(expression_rep, mode="eval").body)
190