1from sympy.parsing.sympy_parser import parse_expr, auto_number, rationalize
2from sympy.parsing.sympy_tokenize import NUMBER, STRING, NAME, OP
3from sympy import Basic, Symbol, Expr
4import sympy as sym
5import re
6
# Namespace used when evaluating parsed expressions: everything exported by
# sympy, minus the names in ``global_ignore`` so that those identifiers are
# parsed as ordinary symbols instead.
global_dict = {}
exec('from sympy import *', global_dict)
global_ignore = ('C', 'O', 'S', 'N', 'E', 'E1', 'Q')
for symbol in global_ignore:
    # Not every sympy release exports all of these names (e.g. ``C`` no
    # longer exists in sympy >= 1.0), so tolerate missing ones rather than
    # failing at import time with KeyError.
    global_dict.pop(symbol, None)
# delta gets printed as DiracDelta; could override
global_dict['delta'] = global_dict['DiracDelta']
global_dict['step'] = global_dict['Heaviside']
global_dict['u'] = global_dict['Heaviside']

# Component-name prefixes; canonical_name rewrites e.g. R1 as R_1.
cpt_names = ('C', 'E', 'F', 'G', 'H', 'I', 'L', 'R', 'V', 'Y', 'Z')
cpt_name_pattern = re.compile(r"(%s)([\w']*)" % '|'.join(cpt_names))

# Matches a braced subscript or superscript such as _{out} or ^{2}.
sub_super_pattern = re.compile(r"([_\^]){([\w]+)}")
21
def canonical_name(name):
    """Return ``name`` rewritten into the form used for sympy symbols.

    Non-string arguments are returned unchanged.  For strings:

    * braced sub/superscripts are flattened, e.g. ``R_{out}`` -> ``R_out``;
    * a name that then contains an underscore is returned as is;
    * a name starting with one of the component prefixes and followed by
      further characters gets an underscore after the prefix,
      e.g. ``R1`` -> ``R_1``.
    """
    if not isinstance(name, str):
        return name

    # Drop the braces so that sympy recognises the subscript/superscript.
    name = sub_super_pattern.sub(lambda m: m.group(1) + m.group(2), name)

    if '_' in name:
        return name

    found = cpt_name_pattern.match(name)
    if found is None:
        return name

    prefix, suffix = found.groups()
    # A bare prefix (e.g. just "R") is left alone.
    return prefix + '_' + suffix if suffix else name
45
def symbols_find(arg):
    """Return a list of the names of the symbols in ``arg``.

    No symbols are created or cached.

    ``arg`` may be a string, an object with an ``expr`` attribute, or a
    sympy expression; anything else yields an empty list.

    For a string, every name that is not in the parser's namespace is
    reported once, in order of first appearance.  Names are reported as
    written (e.g. ``R1``), without ``canonical_name`` rewriting, and ``j``
    is treated as ``I``.  For an expression, the names come from
    ``atoms(Symbol)`` and so have no particular order.
    """

    symbols = []

    def find_symbol(tokens, local_dict, global_dict):
        # Parser transformation that only records names.  It replaces the
        # whole token stream with ``0``, so parse_expr evaluates nothing
        # from the input.
        for tokNum, tokVal in tokens:
            if tokNum != NAME:
                continue
            name = 'I' if tokVal == 'j' else tokVal
            if (name not in local_dict and name not in global_dict
                    and name not in symbols):
                symbols.append(name)
        return [(NUMBER, '0')]

    if isinstance(arg, str):
        parse_expr(arg, transformations=(find_symbol, ),
                   global_dict=global_dict, local_dict={}, evaluate=False)
        return symbols

    # Hack: unwrap objects that carry their sympy expression in ``expr``.
    if hasattr(arg, 'expr'):
        arg = arg.expr

    if not isinstance(arg, (Symbol, Expr)):
        return []
    return [symbol.name for symbol in arg.atoms(Symbol)]
76
def parse(string, symbols=None, evaluate=True, local_dict=None,
          **assumptions):
    """Parse ``string`` into a sympy expression.

    Names that are not sympy objects, classes or callables in the module
    namespace are rewritten with ``canonical_name`` (e.g. ``R1`` -> ``R_1``).
    They are then looked up in ``local_dict``, and otherwise turned into
    ``Symbol(name, **assumptions)``.  ``j`` is read as the imaginary unit
    ``I``.

    Parameters
    ----------
    string : str
        Expression to parse.
    symbols : dict, optional
        Symbol cache keyed by name.  Unless ``cache=False`` is given, each
        symbol in the result whose name is not already a key is added.
        Existing entries are never overwritten.  A fresh dict is used when
        omitted.
    evaluate : bool
        Passed through to ``parse_expr``.
    local_dict : dict, optional
        Names emitted as-is instead of becoming new symbols.
    cache : bool, passed as a keyword inside ``assumptions``
        If false, ``symbols`` is left untouched.
    **assumptions
        Remaining keywords (e.g. ``real=True``) are applied to every newly
        created ``Symbol``.

    Returns
    -------
    The object produced by ``parse_expr``.
    """
    cache = assumptions.pop('cache', True)
    # Avoid shared mutable defaults: ``symbols`` is written to below.
    if symbols is None:
        symbols = {}
    if local_dict is None:
        local_dict = {}

    def auto_symbol(tokens, local_dict, global_dict):
        """Insert calls to ``Symbol`` for undefined variables."""
        result = []
        for tokNum, tokVal in tokens:
            if tokNum != NAME:
                result.append((tokNum, tokVal))
                continue

            name = 'I' if tokVal == 'j' else tokVal

            # Known sympy objects, classes and functions pass through.
            if name in global_dict:
                obj = global_dict[name]
                if isinstance(obj, (Basic, type)) or callable(obj):
                    result.append((NAME, name))
                    continue

            name = canonical_name(str(name))

            if name in local_dict:
                # Reuse the caller's object; its assumptions are not checked.
                result.append((NAME, name))
                continue

            # Automatically add Symbol('name', assumption=value, ...).
            result.extend([(NAME, 'Symbol'), (OP, '('), (NAME, repr(name))])
            for assumption, val in assumptions.items():
                result.extend([(OP, ','),
                               (NAME, '%s=%s' % (assumption, val))])
            result.append((OP, ')'))

        return result

    s = parse_expr(string, transformations=(auto_symbol, auto_number,
                                            rationalize),
                   global_dict=global_dict, local_dict=local_dict,
                   evaluate=evaluate)
    # Non-sympy results (e.g. a tuple from "1, 2") have no atoms to cache.
    if not cache or not isinstance(s, Basic):
        return s

    # Record newly defined symbols.  A cached symbol of the same name is
    # kept even if the new one has different assumptions (real, positive,
    # etc.).
    for symbol in s.atoms(Symbol):
        symbols.setdefault(symbol.name, symbol)

    return s
153
154
def sympify1(arg, symbols=None, evaluate=True, **assumptions):
    """Create a sympy expression from ``arg``.

    * Objects with an ``expr`` attribute yield that attribute.
    * Sympy expressions are returned unchanged.
    * ``complex`` and ``float`` values are converted through their string
      form with ``rational=True``, so the result is rational, not Float.
    * Strings are parsed with ``parse``.  ``symbols`` serves both as the
      symbol cache and as the local namespace, and ``assumptions`` apply to
      newly created symbols.
    * Anything else goes through ``sympy.sympify(..., rational=True)`` with
      ``symbols`` as locals.

    ``symbols`` defaults to a fresh dict per call.  A shared default would
    let symbols created in one call leak into every later call.
    """
    if symbols is None:
        symbols = {}

    if hasattr(arg, 'expr'):
        return arg.expr

    if isinstance(arg, (Symbol, Expr)):
        return arg

    # Why doesn't sympy do this?
    if isinstance(arg, complex):
        real = sym.sympify(str(arg.real), rational=True, evaluate=evaluate)
        imag = sym.sympify(str(arg.imag), rational=True, evaluate=evaluate)
        if imag == 1.0:
            return real + sym.I
        return real + sym.I * imag

    if isinstance(arg, float):
        # Note, need to convert to string to achieve a rational
        # representation.
        return sym.sympify(str(arg), rational=True, evaluate=evaluate)

    if isinstance(arg, str):
        return parse(arg, symbols, evaluate=evaluate,
                     local_dict=symbols, **assumptions)

    return sym.sympify(arg, rational=True, locals=symbols,
                       evaluate=evaluate)
185
def test():
    """Smoke test: print the assumptions attached to an auto-created symbol."""
    symbols = {}
    sympify1('5 * E1 + a', symbols)
    sympify1('5 * R1 + a', symbols, real=True)
    print(symbols['R_1'].assumptions0)
    # R_1 is already in ``symbols``, which is also the parser's local
    # namespace.  The cached real symbol is therefore reused, and
    # positive=True is not applied to it.
    sympify1('5 * R1 + a', symbols, positive=True)
    print(symbols['R_1'].assumptions0)
193
194