1# Ported from latex2sympy by @augustt198
2# https://github.com/augustt198/latex2sympy
3# See license in LICENSE.txt
4
5import sympy
6from sympy.external import import_module
7from sympy.printing.str import StrPrinter
8from sympy.physics.quantum.state import Bra, Ket
9
10from .errors import LaTeXParsingError
11
12
# The ANTLR-generated lexer/parser and the error listener are optional
# pieces.  Every name is bound to None first so this module always imports;
# parse_latex() checks MathErrorListener (and the antlr4 module) for None
# before it tries to parse anything.
LaTeXParser = LaTeXLexer = MathErrorListener = None

try:
    # Best-effort import of the generated grammar classes.  Any exception --
    # e.g. the AttributeError from reading `.LaTeXParser` on a None returned
    # by import_module() -- is swallowed, leaving both names as None.
    LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser',
                                import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser
    LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer',
                               import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer
except Exception:
    pass

# antlr4's ErrorListener module, or a falsy value when it cannot be imported;
# MathErrorListener below is only defined when this import succeeded.
ErrorListener = import_module('antlr4.error.ErrorListener',
                              warn_not_installed=True,
                              import_kwargs={'fromlist': ['ErrorListener']}
                              )
27
28
29
if ErrorListener:
    class MathErrorListener(ErrorListener.ErrorListener):  # type: ignore
        """ANTLR error listener that raises LaTeXParsingError on any error.

        The raised message has three lines: a description of the problem,
        the original LaTeX source, and a ``~~~^`` marker under the column
        where the error was reported.
        """

        def __init__(self, src):
            """Store *src*, the full input string, for use in error messages."""
            # Cooperative super() through this class's MRO so the antlr4 base
            # class initialiser is not skipped.
            super(MathErrorListener, self).__init__()
            self.src = src

        def syntaxError(self, recog, symbol, line, col, msg, e):
            """Translate an ANTLR syntax error into a LaTeXParsingError.

            The arguments follow antlr4's ``ErrorListener.syntaxError``
            signature; only ``col``, ``msg`` and (for "mismatched" errors)
            ``e`` are used.

            Raises
            ======

            LaTeXParsingError
                Always.
            """
            fmt = "%s\n%s\n%s"
            marker = "~" * col + "^"

            if msg.startswith("missing"):
                err = fmt % (msg, self.src, marker)
            elif msg.startswith("no viable"):
                err = fmt % ("I expected something else here", self.src, marker)
            elif msg.startswith("mismatched"):
                names = LaTeXParser.literalNames
                # Keep only token types that index into literalNames.  The
                # lower bound matters: ANTLR's EOF token type is -1, which
                # would otherwise wrap around to the last literal name.
                expected = [
                    names[i] for i in e.getExpectedTokens()
                    if 0 <= i < len(names)
                ]
                if len(expected) < 10:
                    err = fmt % ("I expected one of these: " + " ".join(expected),
                                 self.src, marker)
                else:
                    # Too many alternatives to be a helpful hint.
                    err = fmt % ("I expected something else here", self.src,
                                 marker)
            else:
                err = fmt % ("I don't understand this", self.src, marker)
            raise LaTeXParsingError(err)
59
60
def parse_latex(sympy):
    """Parse a LaTeX math string into a SymPy expression.

    Parameters
    ==========

    sympy : str
        The LaTeX source text.  The parameter name shadows the ``sympy``
        module inside this function; it is kept as-is so the public
        signature does not change.

    Raises
    ======

    ImportError
        If the antlr4 runtime, the error listener built on it, or the
        generated LaTeX lexer/parser is unavailable.
    LaTeXParsingError
        Raised by MathErrorListener on the first lexing or parsing error.
    """
    antlr4 = import_module('antlr4', warn_not_installed=True)

    if antlr4 is None or MathErrorListener is None:
        raise ImportError("LaTeX parsing requires the antlr4 python package,"
                          " provided by pip (antlr4-python2-runtime or"
                          " antlr4-python3-runtime) or"
                          " conda (antlr-python-runtime)")
    # The generated grammar modules are imported best-effort at module load;
    # fail clearly here instead of with "'NoneType' object is not callable".
    if LaTeXParser is None or LaTeXLexer is None:
        raise ImportError("LaTeX parsing requires the generated ANTLR lexer and"
                          " parser (sympy.parsing.latex._antlr), which could"
                          " not be imported")

    # A single listener is attached to both lexer and parser so that either
    # stage raises LaTeXParsingError with the source text for context.
    matherror = MathErrorListener(sympy)

    stream = antlr4.InputStream(sympy)
    lex = LaTeXLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = antlr4.CommonTokenStream(lex)
    parser = LaTeXParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    relation = parser.math().relation()
    return convert_relation(relation)
88
89
def convert_relation(rel):
    """Convert a ``relation`` parse node into a SymPy relational.

    A node holding a plain ``expr`` is converted directly.  Otherwise both
    sides are converted recursively and combined with the SymPy class that
    matches the operator token present; ``None`` is returned when no known
    operator token is found.
    """
    if rel.expr():
        return convert_expr(rel.expr())

    lhs = convert_relation(rel.relation(0))
    rhs = convert_relation(rel.relation(1))
    # Checked in this order; the first operator token present wins.
    operator_table = (
        (rel.LT, sympy.StrictLessThan),
        (rel.LTE, sympy.LessThan),
        (rel.GT, sympy.StrictGreaterThan),
        (rel.GTE, sympy.GreaterThan),
        (rel.EQUAL, sympy.Eq),
        (rel.NEQ, sympy.Ne),
    )
    for has_operator, relational in operator_table:
        if has_operator():
            return relational(lhs, rhs)
    return None
108
109
def convert_expr(expr):
    """Convert an ``expr`` parse node by converting its single additive child."""
    additive_node = expr.additive()
    return convert_add(additive_node)
112
113
def convert_add(add):
    """Convert an ``additive`` parse node into an unevaluated SymPy sum.

    ``a + b`` becomes ``Add(a, b)`` and ``a - b`` becomes
    ``Add(a, Mul(-1, b))``, both with ``evaluate=False`` so the written
    structure is preserved.  A node without ADD/SUB wraps a single
    multiplicative term.
    """
    if not (add.ADD() or add.SUB()):
        return convert_mp(add.mp())

    left = convert_add(add.additive(0))
    right = convert_add(add.additive(1))
    if not add.ADD():
        # Subtraction: add the negated right-hand side.
        right = sympy.Mul(-1, right, evaluate=False)
    return sympy.Add(left, right, evaluate=False)
126
127
def convert_mp(mp):
    """Convert a multiplicative (``mp`` or ``mp_nofunc``) parse node.

    MUL, CMD_TIMES and CMD_CDOT tokens give an unevaluated product; DIV,
    CMD_DIV and COLON tokens give an unevaluated ``left * right**-1``.  A node
    with none of these wraps a single unary operand.
    """
    if hasattr(mp, 'mp'):
        operands = (mp.mp(0), mp.mp(1))
    else:
        operands = (mp.mp_nofunc(0), mp.mp_nofunc(1))

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        left, right = (convert_mp(node) for node in operands)
        return sympy.Mul(left, right, evaluate=False)

    if mp.DIV() or mp.CMD_DIV() or mp.COLON():
        numerator, denominator = (convert_mp(node) for node in operands)
        reciprocal = sympy.Pow(denominator, -1, evaluate=False)
        return sympy.Mul(numerator, reciprocal, evaluate=False)

    if hasattr(mp, 'unary'):
        return convert_unary(mp.unary())
    return convert_unary(mp.unary_nofunc())
149
150
def convert_unary(unary):
    """Convert a ``unary`` / ``unary_nofunc`` parse node.

    A leading plus sign is dropped, and a leading minus sign negates the
    converted operand with Python's unary ``-`` (so a number becomes a
    negative number rather than ``Mul(-1, n)``).  Otherwise the node holds a
    sequence of postfix items that are implicitly multiplied together.
    Returns ``None`` if the node has neither.
    """
    if hasattr(unary, 'unary'):
        operand = unary.unary()
    else:
        operand = unary.unary_nofunc()

    if hasattr(unary, 'postfix_nofunc'):
        # The "nofunc" variant splits its first postfix item from the rest.
        postfixes = [unary.postfix()] + unary.postfix_nofunc()
    else:
        postfixes = unary.postfix()

    if unary.ADD():
        return convert_unary(operand)
    if unary.SUB():
        return -convert_unary(operand)
    if postfixes:
        return convert_postfix_list(postfixes)
    return None
171
172
def convert_postfix_list(arr, i=0):
    """Convert postfix nodes ``arr[i:]`` into one implicitly-multiplied product.

    Each element is converted with convert_postfix.  An ordinary expression
    is multiplied (unevaluated) by the conversion of everything after it.  A
    bare derivative operator (a one-element ``[wrt]`` list) is applied to
    everything after it as ``Derivative(rest, wrt)``.

    Special case: a Symbol printing as ``x`` that sits between two
    symbol-free factors is treated as a multiplication sign, so ``2 x 3``
    reads as ``2*3``.

    Raises
    ======

    LaTeXParsingError
        If ``i`` is past the end of *arr*, or a derivative operator has no
        expression after it.
    """
    if i >= len(arr):
        raise LaTeXParsingError("Index out of bounds")

    res = convert_postfix(arr[i])
    if not isinstance(res, sympy.Expr):  # must be derivative
        wrt = res[0]
        if i == len(arr) - 1:
            raise LaTeXParsingError("Expected expression for derivative")
        return sympy.Derivative(convert_postfix_list(arr, i + 1), wrt)

    if i == len(arr) - 1:
        return res  # nothing to multiply by

    if i > 0:
        # Convert each neighbour once and reuse the result; converting again
        # just to call .atoms() duplicated work (fractions may even re-parse
        # their numerator text).
        left = convert_postfix(arr[i - 1])
        right = convert_postfix(arr[i + 1])
        if isinstance(left, sympy.Expr) and isinstance(right, sympy.Expr):
            # if the left and right sides contain no variables and the
            # symbol in between is 'x', treat as multiplication.
            if (not left.atoms(sympy.Symbol)
                    and not right.atoms(sympy.Symbol)
                    and str(res) == "x"):
                return convert_postfix_list(arr, i + 1)

    # multiply by next
    return sympy.Mul(res, convert_postfix_list(arr, i + 1), evaluate=False)
205
206
def do_subs(expr, at):
    """Apply one bound of an evaluation bar (``|_{...}`` / ``|^{...}``) to *expr*.

    For an equality bound ``lhs = rhs``, every ``lhs`` in *expr* is replaced
    by ``rhs``.  For a plain expression bound, one of the bound's Symbols is
    replaced by the whole bound expression.  If the bound has no Symbols,
    *expr* is returned unchanged.  Returns ``None`` if the node has neither
    kind of bound.

    NOTE(review): with several Symbols in a plain bound, which one gets
    substituted follows set iteration order, i.e. it is arbitrary.
    """
    if at.expr():
        bound = convert_expr(at.expr())
        symbols = bound.atoms(sympy.Symbol)
        if not symbols:
            return expr
        target = next(iter(symbols))
        return expr.subs(target, bound)

    if at.equality():
        equality = at.equality()
        old = convert_expr(equality.expr(0))
        new = convert_expr(equality.expr(1))
        return expr.subs(old, new)

    return None
220
221
def convert_postfix(postfix):
    """Convert a ``postfix`` node: a base expression followed by postfix ops.

    Supported operators are ``!`` (unevaluated factorial) and an evaluation
    bar.  For the bar, an upper bound and/or a lower bound are substituted
    with do_subs; when both are present the result is the unevaluated
    difference ``F(upper) - F(lower)``.

    Raises
    ======

    LaTeXParsingError
        If any postfix operator is applied to a bare derivative operator,
        which convert_exp yields as a ``[wrt]`` list.
    """
    if hasattr(postfix, 'exp'):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise LaTeXParsingError("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            # do_subs() calls exp.subs(), which a derivative-operator list
            # lacks (AttributeError); report it like the "!" case instead.
            if isinstance(exp, list):
                raise LaTeXParsingError("Cannot apply postfix to derivative")
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = sympy.Add(at_b, -1 * at_a, evaluate=False)
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a

    return exp
250
251
def convert_exp(exp):
    """Convert an ``exp`` / ``exp_nofunc`` node: a component with optional powers.

    A node with a nested ``exp`` child is ``base ^ exponent``, where the
    exponent comes from an ``atom`` or an ``expr`` child; it becomes an
    unevaluated Pow.  Otherwise the node wraps a single component.

    Raises
    ======

    LaTeXParsingError
        If the base is a derivative operator (a list).
    """
    nested = exp.exp() if hasattr(exp, 'exp') else exp.exp_nofunc()
    if not nested:
        component = exp.comp() if hasattr(exp, 'comp') else exp.comp_nofunc()
        return convert_comp(component)

    base = convert_exp(nested)
    if isinstance(base, list):
        raise LaTeXParsingError("Cannot raise derivative to power")
    if exp.atom():
        exponent = convert_atom(exp.atom())
    elif exp.expr():
        exponent = convert_expr(exp.expr())
    return sympy.Pow(base, exponent, evaluate=False)
272
273
def convert_comp(comp):
    """Convert a ``comp`` node by dispatching on whichever child it holds.

    Parenthesised groups convert their inner ``expr``; absolute-value groups
    wrap it in an unevaluated Abs.  Every other alternative is passed to its
    dedicated converter.  Returns ``None`` if no alternative is present.
    """
    group = comp.group()
    if group:
        return convert_expr(group.expr())

    abs_group = comp.abs_group()
    if abs_group:
        return sympy.Abs(convert_expr(abs_group.expr()), evaluate=False)

    # Checked in this order; the first child present is converted.
    dispatch = (
        (comp.atom, convert_atom),
        (comp.frac, convert_frac),
        (comp.binom, convert_binom),
        (comp.floor, convert_floor),
        (comp.ceil, convert_ceil),
        (comp.func, convert_func),
    )
    for child_of, converter in dispatch:
        child = child_of()
        if child:
            return converter(child)
    return None
291
292
def convert_atom(atom):
    """Convert an ``atom`` parse node into a SymPy object.

    * a letter, or a backslash symbol name (backslash dropped), becomes a
      Symbol, with any subscript appended to the name as ``_{...}``;
      the symbol ``infty`` becomes ``sympy.oo``;
    * a number becomes ``sympy.Number`` after stripping commas;
    * a differential token becomes the Symbol ``d<var>``;
    * ``mathit`` text becomes a Symbol named by the verbatim source text;
    * bra / ket nodes become ``Bra`` / ``Ket`` of the converted expression.

    Returns ``None`` if none of these alternatives is present.
    """
    if atom.LETTER():
        name = atom.LETTER().getText()
        if atom.subexpr():
            name += _atom_subscript_suffix(atom.subexpr())
        return sympy.Symbol(name)

    if atom.SYMBOL():
        name = atom.SYMBOL().getText()[1:]  # drop the leading backslash
        if name == "infty":
            return sympy.oo
        if atom.subexpr():
            name += _atom_subscript_suffix(atom.subexpr())
        return sympy.Symbol(name)

    if atom.NUMBER():
        digits = atom.NUMBER().getText().replace(",", "")
        return sympy.Number(digits)

    if atom.DIFFERENTIAL():
        variable = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol('d' + variable.name)

    if atom.mathit():
        return sympy.Symbol(rule2text(atom.mathit().mathit_text()))

    if atom.bra():
        return Bra(convert_expr(atom.bra().expr()))

    if atom.ket():
        return Ket(convert_expr(atom.ket().expr()))

    return None


def _atom_subscript_suffix(subexpr):
    """Return ``'_{<printed subscript>}'`` for a ``subexpr`` node."""
    if subexpr.expr():  # subscript is expr
        subscript = convert_expr(subexpr.expr())
    else:  # subscript is atom
        subscript = convert_atom(subexpr.atom())
    return '_{' + StrPrinter().doprint(subscript) + '}'
333
334
def rule2text(ctx):
    """Return the verbatim source text covered by parse-tree node *ctx*.

    The text is read from the input stream of the node's first token. It
    spans from that token's starting character index to the last token's
    stopping character index, so spacing and commands inside the span are
    kept exactly as written.
    """
    first_token = ctx.start
    last_token = ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)
343
344
def convert_frac(frac):
    """Convert a ``frac`` node into a quotient or a derivative.

    A denominator that is a single differential token (``dx``) or a
    ``\\partial`` symbol followed by one letter/symbol marks a derivative:

    * a numerator of just ``d`` / ``\\partial`` yields ``[wrt]``, a bare
      derivative operator that convert_postfix_list applies to what follows;
    * a numerator starting with ``d`` / ``\\partial`` has that prefix cut
      off, the rest re-parsed with parse_latex, and becomes
      ``Derivative(rest, wrt)``.

    Anything else becomes the unevaluated quotient ``upper * lower**-1``, or
    just ``lower**-1`` when the numerator equals 1.
    """
    diff_op = False
    partial_op = False
    lower_itv = frac.lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1
    if (frac.lower.start == frac.lower.stop
            and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
        wrt = get_differential_var_str(frac.lower.start.text)
        diff_op = True
    elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
          and frac.lower.start.text == '\\partial'
          and (frac.lower.stop.type == LaTeXLexer.LETTER
               or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
        partial_op = True
        wrt = frac.lower.stop.text
        if frac.lower.stop.type == LaTeXLexer.SYMBOL:
            wrt = wrt[1:]  # drop the backslash of a symbol such as \theta

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt)
        if (diff_op and frac.upper.start == frac.upper.stop
                and frac.upper.start.type == LaTeXLexer.LETTER
                and frac.upper.start.text == 'd'):
            return [wrt]
        elif (partial_op and frac.upper.start == frac.upper.stop
              and frac.upper.start.type == LaTeXLexer.SYMBOL
              and frac.upper.start.text == '\\partial'):
            return [wrt]
        upper_text = rule2text(frac.upper)

        expr_top = None
        if diff_op and upper_text.startswith('d'):
            expr_top = parse_latex(upper_text[1:])
        elif partial_op and frac.upper.start.text == '\\partial':
            expr_top = parse_latex(upper_text[len('\\partial'):])
        # Test for None explicitly: a re-parsed numerator equal to 0 is
        # falsy as a SymPy number but is still a valid derivative target.
        if expr_top is not None:
            return sympy.Derivative(expr_top, wrt)

    expr_top = convert_expr(frac.upper)
    expr_bot = convert_expr(frac.lower)
    inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
    if expr_top == 1:
        return inverse_denom
    else:
        return sympy.Mul(expr_top, inverse_denom, evaluate=False)
390
def convert_binom(binom):
    """Convert a binomial node (fields ``n`` and ``k``) to unevaluated ``binomial(n, k)``."""
    top, bottom = convert_expr(binom.n), convert_expr(binom.k)
    return sympy.binomial(top, bottom, evaluate=False)
395
def convert_floor(floor):
    """Convert a floor node (field ``val``) to an unevaluated ``floor``."""
    return sympy.floor(convert_expr(floor.val), evaluate=False)
399
def convert_ceil(ceil):
    """Convert a ceiling node (field ``val``) to an unevaluated ``ceiling``."""
    return sympy.ceiling(convert_expr(ceil.val), evaluate=False)
403
def convert_func(func):
    """Convert a ``func`` parse node into a SymPy expression.

    Dispatches on the kind of function node:

    * built-in named functions (trig, hyperbolic, exp, log, ln), handled by
      ``_convert_func_normal``;
    * user functions written as a letter or backslash symbol, optionally
      subscripted, applied to a comma-separated argument list, which become
      an undefined ``sympy.Function`` call;
    * integrals, square/n-th roots, overline (complex conjugate), sums,
      products and limits, each with a dedicated converter.

    Returns ``None`` if none of the alternatives is present.
    """
    if func.func_normal():
        return _convert_func_normal(func)
    elif func.LETTER() or func.SYMBOL():
        if func.LETTER():
            fname = func.LETTER().getText()
        elif func.SYMBOL():
            fname = func.SYMBOL().getText()[1:]
        fname = str(fname)  # can't be unicode
        if func.subexpr():
            if func.subexpr().expr():  # subscript is expr
                subscript = convert_expr(func.subexpr().expr())
            else:  # subscript is atom
                subscript = convert_atom(func.subexpr().atom())
            subscriptName = StrPrinter().doprint(subscript)
            fname += '_{' + subscriptName + '}'
        # ``args`` nests as (expr, args?); walk the chain to collect them all.
        input_args = func.args()
        output_args = []
        while input_args.args():  # handle multiple arguments to function
            output_args.append(convert_expr(input_args.expr()))
            input_args = input_args.args()
        output_args.append(convert_expr(input_args.expr()))
        return sympy.Function(fname)(*output_args)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.root(expr, r, evaluate=False)
        else:
            return sympy.sqrt(expr, evaluate=False)
    elif func.FUNC_OVERLINE():
        expr = convert_expr(func.base)
        return sympy.conjugate(expr, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)


def _convert_func_normal(func):
    """Convert a built-in function call such as ``\\sin x`` or ``\\log_{b}(x)``.

    A superscript is applied as an unevaluated power of the result, except
    that an exponent of ``-1`` on a plain trig/hyperbolic function selects
    the inverse function instead (``\\sin^{-1}`` -> ``asin``).  ``\\log``
    defaults to base 10 and ``\\ln`` to base e unless a subscript is given.

    Raises
    ======

    LaTeXParsingError
        If the function name is not one this converter knows.
    """
    if func.L_PAREN():  # function called with parenthesis
        arg = convert_func_arg(func.func_arg())
    else:
        arg = convert_func_arg(func.func_arg_noparens())

    # Token text such as "\\sin"; drop the leading backslash.
    name = func.func_normal().start.text[1:]
    expr = None

    # change arc<trig> -> a<trig>
    if name in [
            "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"
    ]:
        name = "a" + name[3:]
        expr = getattr(sympy.functions, name)(arg, evaluate=False)
    # change ar<hyp> -> a<hyp>
    if name in ["arsinh", "arcosh", "artanh"]:
        name = "a" + name[2:]
        expr = getattr(sympy.functions, name)(arg, evaluate=False)

    if name == "exp":
        expr = sympy.exp(arg, evaluate=False)

    if name in ("log", "ln"):
        if func.subexpr():
            if func.subexpr().expr():
                base = convert_expr(func.subexpr().expr())
            else:
                base = convert_atom(func.subexpr().atom())
        elif name == "log":
            base = 10
        else:  # ln
            base = sympy.E
        expr = sympy.log(arg, base, evaluate=False)

    func_pow = None
    should_pow = True
    if func.supexpr():
        if func.supexpr().expr():
            func_pow = convert_expr(func.supexpr().expr())
        else:
            func_pow = convert_atom(func.supexpr().atom())

    if name in [
            "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh",
            "tanh"
    ]:
        if func_pow == -1:
            name = "a" + name
            should_pow = False
        expr = getattr(sympy.functions, name)(arg, evaluate=False)

    if expr is None:
        raise LaTeXParsingError("Unsupported function: %s" % name)

    # Compare with None explicitly: an exponent of 0 is falsy as a SymPy
    # number but must still produce Pow(expr, 0).
    if func_pow is not None and should_pow:
        expr = sympy.Pow(expr, func_pow, evaluate=False)

    return expr
498
499
def convert_func_arg(arg):
    """Convert a function argument node.

    A parenthesised argument node exposes ``expr``; a bare argument node
    exposes ``mp_nofunc`` instead.
    """
    if not hasattr(arg, 'expr'):
        return convert_mp(arg.mp_nofunc())
    return convert_expr(arg.expr())
505
506
def handle_integral(func):
    """Convert an integral node into an unevaluated ``sympy.Integral``.

    The integrand comes from an ``additive`` or ``frac`` child, else it is 1.
    The integration variable comes from an explicit DIFFERENTIAL token.
    Failing that, the integrand is scanned for a Symbol whose name starts
    with ``d`` (such as the ``d<var>`` Symbols convert_atom builds for
    differential tokens); that Symbol is replaced by 1 and the rest of its
    name is used as the variable.  With no such Symbol, ``x`` is assumed.
    Sub/superscripts, when present, give the lower/upper limits.

    NOTE(review): the scan accepts any name beginning with ``d`` (e.g. a
    ``delta`` Symbol would yield variable ``elta``), and when several match,
    the last one in set-iteration order wins -- confirm this is intended.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        int_var = None
        for candidate in integrand.atoms(sympy.Symbol):
            cand_name = str(candidate)
            if len(cand_name) < 2 or cand_name[0] != 'd':
                continue
            # Skip the 'd', plus a following backslash if there is one.
            prefix_len = 2 if cand_name[1] == '\\' else 1
            int_var = sympy.Symbol(cand_name[prefix_len:])
            int_sym = candidate
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol('x')

    if not func.subexpr():
        return sympy.Integral(integrand, int_var)

    lower_node = func.subexpr()
    upper_node = func.supexpr()
    lower = (convert_atom(lower_node.atom()) if lower_node.atom()
             else convert_expr(lower_node.expr()))
    upper = (convert_atom(upper_node.atom()) if upper_node.atom()
             else convert_expr(upper_node.expr()))
    return sympy.Integral(integrand, (int_var, lower, upper))
545
546
def handle_sum_or_prod(func, name):
    """Convert a sum or product node into an unevaluated ``Sum`` / ``Product``.

    The body is the node's ``mp`` child; the subscript equality gives the
    iteration variable and start, and the superscript gives the end.
    *name* selects the class: ``"summation"`` or ``"product"``; any other
    value returns ``None`` (after the children have been converted).
    """
    body = convert_mp(func.mp())
    equality = func.subeq().equality()
    iter_var = convert_expr(equality.expr(0))
    start = convert_expr(equality.expr(1))
    sup = func.supexpr()
    if sup.expr():  # ^{expr}
        end = convert_expr(sup.expr())
    else:  # ^atom
        end = convert_atom(sup.atom())

    limits = (iter_var, start, end)
    if name == "summation":
        return sympy.Sum(body, limits)
    if name == "product":
        return sympy.Product(body, limits)
    return None
560
561
def handle_limit(func):
    """Convert a limit node into an unevaluated ``sympy.Limit``.

    The variable comes from the ``limit_sub`` node's letter or backslash
    symbol (defaulting to ``x``). The direction is ``"-"`` when that node
    has a SUB token and ``"+"`` otherwise. The limited expression is the
    node's ``mp`` child.
    """
    sub = func.limit_sub()
    if sub.LETTER():
        var_name = sub.LETTER().getText()
    elif sub.SYMBOL():
        var_name = sub.SYMBOL().getText()[1:]
    else:
        var_name = 'x'
    var = sympy.Symbol(var_name)
    direction = "-" if sub.SUB() else "+"
    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())
    return sympy.Limit(content, var, approaching, direction)
578
579
def get_differential_var(d):
    """Return the Symbol named by a differential token node such as ``dx``."""
    return sympy.Symbol(get_differential_var_str(d.getText()))
583
584
def get_differential_var_str(text):
    """Return the variable name from the text of a differential token.

    The leading ``d`` is removed together with any spaces, carriage returns,
    newlines or tabs after it, and a backslash introducing a command name is
    dropped: ``"dx"`` -> ``"x"``, ``"d \\theta"`` -> ``"theta"``.

    Raises
    ======

    LaTeXParsingError
        If nothing but whitespace follows the ``d``.  Previously this
        surfaced as an UnboundLocalError.
    """
    # Skip the 'd' and the whitespace separating it from the variable.
    var = text[1:].lstrip(" \r\n\t")
    if not var:
        raise LaTeXParsingError(
            "Expected a variable after the differential in %r" % text)
    if var.startswith("\\"):
        var = var[1:]
    return var
595