1"""
This module implements the functionality to take any Python expression as a
string, wrap its integer and float literals in SymPy's Integer and Float
classes and turn unknown names into Symbols before evaluating it. Thus

1/2

is rewritten as

Integer(1)/Integer(2)

We use the ast module for this. It is well documented at docs.python.org.

Some tips to understand how this works: use dump() to get a nice
representation of any node. Then write a string of what you want to get,
e.g. "Integer(1)", parse it, dump it and you'll see that you need to do
"Call(func=Name('Integer', Load()), args=[node], keywords=[])". You don't
need to bother with lineno and col_offset, just call fix_missing_locations()
before returning the node.
20"""
21
from ast import parse, NodeTransformer, Call, Name, Load, \
    fix_missing_locations, Constant, Tuple

from sympy.core.basic import Basic
from sympy.core.sympify import SympifyError

try:
    # ast.Str was deprecated in Python 3.8 and removed in 3.14; string nodes
    # are ast.Constant instances on every supported interpreter.
    from ast import Str
except ImportError:
    Str = Constant
27
class Transform(NodeTransformer):
    """
    AST transformer that rewrites a parsed expression into SymPy form.

    Integer and float literals are wrapped in ``Integer(...)``/``Float(...)``
    calls, names that do not resolve to something usable are replaced by
    ``Symbol('name')`` calls, and ``lambda`` expressions become ``Lambda``
    calls. The inserted calls refer to the names ``Integer``, ``Float``,
    ``Symbol`` and ``Lambda``, so the rewritten tree must be evaluated in a
    namespace that provides them.
    """

    def __init__(self, local_dict, global_dict):
        NodeTransformer.__init__(self)
        # Names found here are always left untouched.
        self.local_dict = local_dict
        # Names found here are left untouched only when bound to a Basic
        # instance, a class or a callable.
        self.global_dict = global_dict

    def visit_Constant(self, node):
        """
        Wrap int literals in ``Integer`` and float literals in ``Float``.

        Every other constant (strings, complex numbers, ``True``/``False``,
        ``None``, ...) is returned unchanged.
        """
        value = node.value
        # bool is a subclass of int; True/False must stay plain booleans
        # rather than becoming Integer(1)/Integer(0).
        if isinstance(value, bool):
            return node
        if isinstance(value, int):
            func = 'Integer'
        elif isinstance(value, float):
            func = 'Float'
        else:
            return node
        return fix_missing_locations(Call(func=Name(func, Load()),
                args=[node], keywords=[]))

    def visit_Name(self, node):
        """
        Return ``node`` if its name resolves, otherwise a ``Symbol`` call.

        A name is kept when it is in ``local_dict``, or when it is in
        ``global_dict`` and bound to a ``Basic`` instance, a class or a
        callable. Any other name becomes ``Symbol('<name>')``.
        """
        name = node.id
        if name in self.local_dict:
            return node
        if name in self.global_dict:
            obj = self.global_dict[name]
            if callable(obj) or isinstance(obj, (Basic, type)):
                return node
        elif name in ('True', 'False'):
            # Python 3 parses True/False as constants, so this only matters
            # for hand-built trees.
            return node
        return fix_missing_locations(Call(func=Name('Symbol', Load()),
                args=[Constant(name)], keywords=[]))

    def visit_Lambda(self, node):
        """
        Rewrite ``lambda x, y: body`` as ``Lambda((x, y), body)``.

        Each positional parameter is resolved exactly like a name appearing
        in the body (normally becoming ``Symbol('x')``), so a parameter and
        its uses in the body are converted the same way. Defaults, ``*args``
        and ``**kwargs`` are ignored.
        """
        # Python 3 stores parameters as ast.arg nodes, which are not
        # expressions and cannot be placed in a Tuple; visit an equivalent
        # Load-context Name instead.
        params = [self.visit(Name(arg.arg, Load())) for arg in node.args.args]
        body = self.visit(node.body)
        n = Call(func=Name('Lambda', Load()),
            args=[Tuple(params, Load()), body], keywords=[])
        return fix_missing_locations(n)
63
def parse_expr(s, local_dict=None):
    """
    Converts the string "s" to a SymPy expression, in local_dict.

    It converts all integer and float literals to Integer/Float before
    feeding the expression to Python, and automatically creates Symbols for
    names that are neither in ``local_dict`` nor bound by
    ``from sympy import *`` to a SymPy object, class or callable.

    Parameters
    ==========

    s : str
        The expression to parse; surrounding whitespace is ignored.
    local_dict : dict, optional
        Names that take precedence over SymPy's namespace and are left
        untouched by the rewrite. ``None`` (the default) means no extra
        names.

    Raises
    ======

    SympifyError
        If ``s`` is not a valid Python expression.

    Notes
    =====

    The rewritten expression is executed with ``eval``, so arbitrary code
    contained in ``s`` can run. Never call this on untrusted input.
    """
    if local_dict is None:
        local_dict = {}
    global_dict = {}
    exec('from sympy import *', global_dict)
    try:
        a = parse(s.strip(), mode="eval")
    except (SyntaxError, ValueError) as exc:
        # Before Python 3.12, source containing null bytes raises ValueError
        # instead of SyntaxError.
        raise SympifyError("Cannot parse %s." % repr(s)) from exc
    a = Transform(local_dict, global_dict).visit(a)
    e = compile(a, "<string>", "eval")
    return eval(e, global_dict, local_dict)
80