"""A module to safely parse and evaluate mathematical expressions.

The expression is parsed with :mod:`ast` and the resulting tree is walked by
``_eval``. Only numeric literals, the operators in ``operators``, the
functions in ``allowed_math_fxn`` and the named constants pi, e and tau are
accepted; anything else raises an exception instead of being executed.
"""
import ast
import functools
import math
import numbers
import operator as op
import re

from numpy import int64

# Maximum magnitude an operand or result may reach. Bounding it prevents
# denial-of-service (DoS) attacks built from huge numbers.
max_value = 1e17


def _check_operands(a, b):
    """Raise ValueError if the magnitude of ``a`` or ``b`` exceeds max_value."""
    if abs(a) > max_value or abs(b) > max_value:
        raise ValueError((a, b))


# Guarded replacements for the arithmetic operators (DoS protection).
def add(a, b):
    """Return ``a + b``; raise ValueError if an operand exceeds max_value."""
    _check_operands(a, b)
    return op.add(a, b)


def sub(a, b):
    """Return ``a - b``; raise ValueError if an operand exceeds max_value."""
    _check_operands(a, b)
    return op.sub(a, b)


def mul(a, b):
    """Return ``a * b``; raise ValueError if ``|a * b|`` would exceed max_value.

    The bound is checked in log space so the oversized product is never built.
    """
    if a == 0.0 or b == 0.0:
        pass  # log10(0) is undefined, and the product is zero anyway
    elif math.log10(abs(a)) + math.log10(abs(b)) > math.log10(max_value):
        raise ValueError((a, b))
    return op.mul(a, b)


def div(a, b):
    """Return ``a / b`` (true division).

    Raises:
        ValueError: if ``b`` is zero or ``|a / b|`` would exceed max_value.
    """
    if b == 0.0:
        raise ValueError((a, b))
    if a != 0.0 and (
        math.log10(abs(a)) - math.log10(abs(b)) > math.log10(max_value)
    ):
        raise ValueError((a, b))
    return op.truediv(a, b)


def mod(a, b):
    """Return ``a % b``; raise ValueError if ``b`` is zero (as ``div`` does).

    Integer results are numpy integers (see ``limit``), which only warn and
    return 0 on a zero divisor, so the check has to be explicit.
    """
    if b == 0:
        raise ValueError((a, b))
    return op.mod(a, b)


def floordiv(a, b):
    """Return ``a // b``; raise ValueError if ``b`` is zero (as ``div`` does)."""
    if b == 0:
        raise ValueError((a, b))
    return op.floordiv(a, b)


def power(a, b):
    """Return ``a ** b``, refusing results whose magnitude reaches max_value.

    Raises:
        ValueError: if ``|a ** b|`` would reach max_value, or if zero is
            raised to a negative power.
    """
    # numpy integers reject negative integer exponents ("Integers to negative
    # integer powers are not allowed"), so promote the base: 2 ** -1 == 0.5.
    if isinstance(a, numbers.Integral) and isinstance(b, numbers.Integral) and b < 0:
        a = float(a)
    if a == 0:
        # 0 ** 0 == 1 and 0 ** b == 0 for b > 0; a negative b divides by zero.
        if b < 0:
            raise ValueError((a, b))
        return op.pow(a, b)
    if abs(a) == 1:
        # |a ** b| == 1 for real b, and math.log(x, 1) would divide by zero.
        return op.pow(a, b)
    # log_|a|(max_value) is the exponent at which |a| ** b reaches max_value.
    if b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)


def exp(a):
    """Return ``e ** a``; raise ValueError if the result would exceed max_value."""
    if a > math.log(max_value):
        raise ValueError(a)
    return math.exp(a)


# AST operator node types that may be evaluated, mapped to their implementation.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
    ast.Mod: mod,
    ast.FloorDiv: floordiv,
}

# Whitelist of callable names: a curated subset of the math module plus
# round and the guarded exp.
allowed_math_fxn = {
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "hypot": math.hypot,
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,
    "radians": math.radians,
    "degrees": math.degrees,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "log2": math.log2,
    "fmod": math.fmod,
    "abs": math.fabs,
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,
    "exp": exp,
}


def get_function(node):
    """Return the name of the function called by the ``ast.Call`` *node*.

    Bare calls (``sin(x)``) and attribute calls (``math.sin(x)``) are both
    supported; for the latter only the attribute name is returned and the
    object it is looked up on is ignored.

    Raises:
        TypeError: if the callee is neither a name nor an attribute.
    """
    if isinstance(node.func, ast.Name):
        return node.func.id
    if isinstance(node.func, ast.Attribute):
        return node.func.attr
    raise TypeError("node.func is of the wrong type")


def limit(max_=None):
    """Return a decorator that rejects results larger than ``max_`` in magnitude.

    The wrapped function's return value is passed through ``abs()``; if that
    exceeds ``max_`` a ValueError carrying the value is raised. Values that do
    not support ``abs()`` pass through unchecked, and ``max_=None`` disables
    the check. Python ``int`` results are converted to ``numpy.int64``
    (presumably to keep integers fixed-width -- confirm before changing).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            if max_ is not None:
                try:
                    mag = abs(ret)
                except TypeError:
                    pass  # abs() not applicable to this type; nothing to check
                else:
                    if mag > max_:
                        raise ValueError(ret)
            if isinstance(ret, int):
                ret = int64(ret)
            return ret

        return wrapper

    return decorator


# Names (matched case-insensitively) that evaluate to constants.
_NAMED_CONSTANTS = {
    "pi": math.pi,
    "e": math.e,
    "tau": math.pi * 2.0,
}


def _is_number(value):
    """Return True for int, float and complex values, excluding bool."""
    return isinstance(value, (int, float, complex)) and not isinstance(value, bool)


@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate the ``ast`` expression *node*.

    Every value returned here passes through ``limit``: results whose
    magnitude exceeds max_value raise ValueError and Python ints come back as
    ``numpy.int64``.

    Raises:
        TypeError: for any node that is not a numeric literal, an operator
            application, a call or a named constant.
        KeyError: for an operator or function name outside the whitelists.
        ValueError: when a guarded operation or ``limit`` rejects a value.
    """
    if isinstance(node, ast.Constant) and _is_number(node.value):  # <number>
        return node.value
    if isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp):  # <operator> <operand>, e.g. -1
        return operators[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Call):  # func(args) or module.func(args)
        func = get_function(node)
        evaled_args = [_eval(arg) for arg in node.args]
        evaled_kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:  # ``**mapping`` unpacking is not supported
                raise TypeError(node)
            evaled_kwargs[keyword.arg] = _eval(keyword.value)
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    if isinstance(node, ast.Name):  # named constant
        try:
            return _NAMED_CONSTANTS[node.id.lower()]
        except KeyError:
            raise TypeError(
                "Found a str in the expression, either param_dct/the expression "
                "has a mistake in the parameter names or attempting to parse "
                "non-mathematical code"
            ) from None
    raise TypeError(node)


def _param_pattern(key):
    """Return a regex matching *key* only where it is not part of a longer token."""
    pattern = re.escape(key)
    if re.match(r"\w", key):
        # Not preceded by a word character (inside a longer name, or the
        # exponent of a literal such as ``2e3``) nor by an attribute dot.
        pattern = r"(?<![\w.])" + pattern
    if re.search(r"\w\Z", key):
        pattern += r"(?!\w)"  # not followed by a word character
    return pattern


def eval_expression(expression, param_dct=None):
    """Evaluate the mathematical *expression* and return its value.

    Each key of *param_dct* that appears in the expression as a whole token
    is replaced, in dict order, by ``(str(value))`` before parsing. The
    parentheses keep precedence intact (``x**2`` with ``x=-3`` is 9), and
    because replacement is sequential a value may refer to later keys.

    Args:
        expression: the expression string, at most 10 000 characters.
        param_dct: optional mapping of parameter name to value.

    Returns:
        The numeric result; integer results are ``numpy.int64``.

    Raises:
        TypeError: if *expression* is not a string or uses unsupported
            syntax or unknown names.
        ValueError: if the expression is too long, contains ``()``, divides
            by zero, or a value exceeds max_value.
        KeyError: for an operator or function outside the whitelists.
        SyntaxError: if the substituted expression is not valid Python.
    """
    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    # No whitelisted function is callable without arguments, so empty
    # parentheses are rejected up front.
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    for key, val in (param_dct or {}).items():
        if not key:
            continue  # an empty key would match between every character
        replacement = "(" + str(val) + ")"
        expression_rep = re.sub(
            _param_pattern(key), lambda _match, text=replacement: text, expression_rep
        )

    return _eval(ast.parse(expression_rep, mode="eval").body)