# Ported from latex2sympy by @augustt198
# https://github.com/augustt198/latex2sympy
# See license in LICENSE.txt

import sympy
from sympy.external import import_module
from sympy.printing.str import StrPrinter
from sympy.physics.quantum.state import Bra, Ket

from .errors import LaTeXParsingError


LaTeXParser = LaTeXLexer = MathErrorListener = None

try:
    LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser',
                                import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser
    LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer',
                               import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer
except Exception:
    # Best effort: if the generated ANTLR modules cannot be imported the
    # names simply stay None.
    pass

ErrorListener = import_module('antlr4.error.ErrorListener',
                              warn_not_installed=True,
                              import_kwargs={'fromlist': ['ErrorListener']}
                              )


if ErrorListener:
    class MathErrorListener(ErrorListener.ErrorListener):  # type: ignore
        """ANTLR error listener that turns every syntax error into a
        LaTeXParsingError.

        The error message contains a description, the full source string and
        a ``~~~^`` marker pointing at the offending column.
        """

        def __init__(self, src):
            # Initialise through the normal MRO starting at the direct base
            # class (ErrorListener.ErrorListener).
            super(MathErrorListener, self).__init__()
            self.src = src

        def syntaxError(self, recog, symbol, line, col, msg, e):
            """Raise LaTeXParsingError describing the error at ``col``."""
            fmt = "%s\n%s\n%s"
            marker = "~" * col + "^"

            if msg.startswith("missing"):
                err = fmt % (msg, self.src, marker)
            elif msg.startswith("no viable"):
                err = fmt % ("I expected something else here", self.src, marker)
            elif msg.startswith("mismatched"):
                names = LaTeXParser.literalNames
                expected = [
                    names[i] for i in e.getExpectedTokens() if i < len(names)
                ]
                # Only list the expected tokens when the list is short.
                if len(expected) < 10:
                    expected = " ".join(expected)
                    err = (fmt % ("I expected one of these: " + expected, self.src,
                                  marker))
                else:
                    err = (fmt % ("I expected something else here", self.src,
                                  marker))
            else:
                err = fmt % ("I don't understand this", self.src, marker)
            raise LaTeXParsingError(err)


def parse_latex(sympy):
    """Parse a LaTeX math string and return the corresponding SymPy object.

    The parameter holds the LaTeX source. Its name shadows the ``sympy``
    module inside this function, which does not use the module.

    Raises:
        ImportError: if the antlr4 runtime (or the error listener built on
            it) is unavailable.
        LaTeXParsingError: raised by MathErrorListener on lexer/parser errors.
    """
    antlr4 = import_module('antlr4', warn_not_installed=True)

    if None in [antlr4, MathErrorListener]:
        raise ImportError("LaTeX parsing requires the antlr4 python package,"
                          " provided by pip (antlr4-python2-runtime or"
                          " antlr4-python3-runtime) or"
                          " conda (antlr-python-runtime)")

    matherror = MathErrorListener(sympy)

    stream = antlr4.InputStream(sympy)
    lex = LaTeXLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = antlr4.CommonTokenStream(lex)
    parser = LaTeXParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    relation = parser.math().relation()
    expr = convert_relation(relation)

    return expr


def convert_relation(rel):
    """Convert a ``relation`` node to a SymPy relational or expression.

    A node with an ``expr`` child is converted directly. Otherwise both
    sub-relations are converted and combined according to the operator
    token present (LT, LTE, GT, GTE, EQUAL, NEQ). Returns None if none of
    those tokens is present.
    """
    if rel.expr():
        return convert_expr(rel.expr())

    lh = convert_relation(rel.relation(0))
    rh = convert_relation(rel.relation(1))
    if rel.LT():
        return sympy.StrictLessThan(lh, rh)
    elif rel.LTE():
        return sympy.LessThan(lh, rh)
    elif rel.GT():
        return sympy.StrictGreaterThan(lh, rh)
    elif rel.GTE():
        return sympy.GreaterThan(lh, rh)
    elif rel.EQUAL():
        return sympy.Eq(lh, rh)
    elif rel.NEQ():
        return sympy.Ne(lh, rh)


def convert_expr(expr):
    """Convert an ``expr`` node by converting its ``additive`` child."""
    return convert_add(expr.additive())


def convert_add(add):
    """Convert an ``additive`` node to an unevaluated SymPy sum.

    Subtraction is built as ``Add(lh, Mul(-1, rh))`` with evaluation
    disabled so the parsed structure is kept.
    """
    if add.ADD():
        lh = convert_add(add.additive(0))
        rh = convert_add(add.additive(1))
        return sympy.Add(lh, rh, evaluate=False)
    elif add.SUB():
        lh = convert_add(add.additive(0))
        rh = convert_add(add.additive(1))
        return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False),
                         evaluate=False)
    else:
        return convert_mp(add.mp())


def convert_mp(mp):
    """Convert an ``mp`` or ``mp_nofunc`` node (multiplicative level).

    MUL/CMD_TIMES/CMD_CDOT become an unevaluated Mul. DIV/CMD_DIV/COLON
    become ``Mul(lh, Pow(rh, -1))``. Any other node is a unary.
    """
    # Both node kinds share this converter; pick the matching child accessors.
    if hasattr(mp, 'mp'):
        mp_left = mp.mp(0)
        mp_right = mp.mp(1)
    else:
        mp_left = mp.mp_nofunc(0)
        mp_right = mp.mp_nofunc(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        return sympy.Mul(lh, rh, evaluate=False)
    elif mp.DIV() or mp.CMD_DIV() or mp.COLON():
        lh = convert_mp(mp_left)
        rh = convert_mp(mp_right)
        return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)
    else:
        if hasattr(mp, 'unary'):
            return convert_unary(mp.unary())
        else:
            return convert_unary(mp.unary_nofunc())


def convert_unary(unary):
    """Convert a ``unary`` or ``unary_nofunc`` node.

    A leading ``+`` is dropped, a leading ``-`` negates the operand, and
    otherwise the postfix chain is converted by convert_postfix_list.
    """
    if hasattr(unary, 'unary'):
        nested_unary = unary.unary()
    else:
        nested_unary = unary.unary_nofunc()
    if hasattr(unary, 'postfix_nofunc'):
        first = unary.postfix()
        tail = unary.postfix_nofunc()
        postfix = [first] + tail
    else:
        postfix = unary.postfix()

    if unary.ADD():
        return convert_unary(nested_unary)
    elif unary.SUB():
        numabs = convert_unary(nested_unary)
        # Use Integer(-n) instead of Mul(-1, n)
        return -numabs
    elif postfix:
        return convert_postfix_list(postfix)


def convert_postfix_list(arr, i=0):
    """Convert the juxtaposed postfix items ``arr[i:]`` to a product.

    Adjacent expressions are multiplied. A symbol ``x`` between two
    symbol-free expressions is read as a multiplication sign. An item that
    converts to a list (a derivative operator ``[wrt]`` from convert_frac)
    is applied as a Derivative to the rest of the list.

    Raises:
        LaTeXParsingError: if ``i`` is out of range or a derivative operator
            has nothing after it.
    """
    if i >= len(arr):
        raise LaTeXParsingError("Index out of bounds")

    res = convert_postfix(arr[i])
    if isinstance(res, sympy.Expr):
        if i == len(arr) - 1:
            return res  # nothing to multiply by
        else:
            if i > 0:
                left = convert_postfix(arr[i - 1])
                right = convert_postfix(arr[i + 1])
                if isinstance(left, sympy.Expr) and isinstance(
                        right, sympy.Expr):
                    # Reuse the conversions above instead of converting the
                    # neighbours a second time.
                    left_syms = left.atoms(sympy.Symbol)
                    right_syms = right.atoms(sympy.Symbol)
                    # if the left and right sides contain no variables and the
                    # symbol in between is 'x', treat as multiplication.
                    if (len(left_syms) == 0 and len(right_syms) == 0
                            and str(res) == "x"):
                        return convert_postfix_list(arr, i + 1)
            # multiply by next
            return sympy.Mul(
                res, convert_postfix_list(arr, i + 1), evaluate=False)
    else:  # must be derivative
        wrt = res[0]
        if i == len(arr) - 1:
            raise LaTeXParsingError("Expected expression for derivative")
        else:
            expr = convert_postfix_list(arr, i + 1)
            return sympy.Derivative(expr, wrt)


def do_subs(expr, at):
    """Apply one bound of an evaluation bar (``|_{...}`` / ``|^{...}``).

    If the bound is an equality ``lhs = rhs``, substitute ``lhs -> rhs`` in
    ``expr``. If it is a plain expression, return ``expr`` unchanged when it
    has no symbols. Otherwise substitute one of its symbols (taken from a
    set, so the choice is arbitrary when there are several) with the whole
    bound expression. Returns None if neither child is present.
    """
    if at.expr():
        at_expr = convert_expr(at.expr())
        syms = at_expr.atoms(sympy.Symbol)
        if len(syms) == 0:
            return expr
        else:
            sym = next(iter(syms))
            return expr.subs(sym, at_expr)
    elif at.equality():
        lh = convert_expr(at.equality().expr(0))
        rh = convert_expr(at.equality().expr(1))
        return expr.subs(lh, rh)


def convert_postfix(postfix):
    """Convert a ``postfix`` node: its base ``exp`` followed by postfix ops.

    ``!`` wraps the value in an unevaluated factorial. An evaluation bar
    substitutes the bounds via do_subs. With both bounds the result is
    ``upper - lower``.

    Raises:
        LaTeXParsingError: if ``!`` is applied to a derivative operator.
    """
    if hasattr(postfix, 'exp'):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise LaTeXParsingError("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = sympy.Add(at_b, -1 * at_a, evaluate=False)
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a

    return exp


def convert_exp(exp):
    """Convert an ``exp`` / ``exp_nofunc`` node.

    A nested exp with an atom or braced expression exponent becomes an
    unevaluated Pow. Otherwise the node's ``comp`` child is converted.

    Raises:
        LaTeXParsingError: if the base is a derivative operator.
    """
    if hasattr(exp, 'exp'):
        exp_nested = exp.exp()
    else:
        exp_nested = exp.exp_nofunc()

    if exp_nested:
        base = convert_exp(exp_nested)
        if isinstance(base, list):
            raise LaTeXParsingError("Cannot raise derivative to power")
        if exp.atom():
            exponent = convert_atom(exp.atom())
        elif exp.expr():
            exponent = convert_expr(exp.expr())
        return sympy.Pow(base, exponent, evaluate=False)
    else:
        if hasattr(exp, 'comp'):
            return convert_comp(exp.comp())
        else:
            return convert_comp(exp.comp_nofunc())


def convert_comp(comp):
    """Dispatch a ``comp`` node to the converter for whichever child is set."""
    if comp.group():
        return convert_expr(comp.group().expr())
    elif comp.abs_group():
        return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False)
    elif comp.atom():
        return convert_atom(comp.atom())
    elif comp.frac():
        return convert_frac(comp.frac())
    elif comp.binom():
        return convert_binom(comp.binom())
    elif comp.floor():
        return convert_floor(comp.floor())
    elif comp.ceil():
        return convert_ceil(comp.ceil())
    elif comp.func():
        return convert_func(comp.func())


def convert_atom(atom):
    """Convert an ``atom`` node to a SymPy object.

    Handles letters and backslash symbols (with an optional subscript folded
    into the name as ``_{...}``), ``\\infty`` -> ``oo``, numbers (commas
    removed), differentials (symbol ``d<var>``), ``mathit`` text, and
    bra/ket states.
    """
    if atom.LETTER():
        subscriptName = ''
        if atom.subexpr():
            subscript = None
            if atom.subexpr().expr():  # subscript is expr
                subscript = convert_expr(atom.subexpr().expr())
            else:  # subscript is atom
                subscript = convert_atom(atom.subexpr().atom())
            subscriptName = '_{' + StrPrinter().doprint(subscript) + '}'
        return sympy.Symbol(atom.LETTER().getText() + subscriptName)
    elif atom.SYMBOL():
        # Drop the leading backslash of the command name.
        s = atom.SYMBOL().getText()[1:]
        if s == "infty":
            return sympy.oo
        else:
            if atom.subexpr():
                subscript = None
                if atom.subexpr().expr():  # subscript is expr
                    subscript = convert_expr(atom.subexpr().expr())
                else:  # subscript is atom
                    subscript = convert_atom(atom.subexpr().atom())
                subscriptName = StrPrinter().doprint(subscript)
                s += '_{' + subscriptName + '}'
            return sympy.Symbol(s)
    elif atom.NUMBER():
        s = atom.NUMBER().getText().replace(",", "")
        return sympy.Number(s)
    elif atom.DIFFERENTIAL():
        var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol('d' + var.name)
    elif atom.mathit():
        text = rule2text(atom.mathit().mathit_text())
        return sympy.Symbol(text)
    elif atom.bra():
        val = convert_expr(atom.bra().expr())
        return Bra(val)
    elif atom.ket():
        val = convert_expr(atom.ket().expr())
        return Ket(val)


def rule2text(ctx):
    """Return the exact source text spanned by the parser rule ``ctx``."""
    stream = ctx.start.getInputStream()
    # starting index of starting token
    startIdx = ctx.start.start
    # stopping index of stopping token
    stopIdx = ctx.stop.stop

    return stream.getText(startIdx, stopIdx)


def convert_frac(frac):
    """Convert a ``\\frac{upper}{lower}`` node.

    If the denominator is a single differential (``dx``) or ``\\partial``
    followed by one letter/symbol, the fraction is a derivative:

    * a bare ``d`` / ``\\partial`` numerator returns the operator ``[wrt]``,
      which convert_postfix_list applies to what follows;
    * a numerator ``d<expr>`` / ``\\partial<expr>`` returns
      ``Derivative(<expr>, wrt)``.

    Anything else becomes ``upper * lower**-1`` (just ``lower**-1`` when the
    numerator equals 1).
    """
    diff_op = False
    partial_op = False
    lower_itv = frac.lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1
    if (frac.lower.start == frac.lower.stop
            and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
        wrt = get_differential_var_str(frac.lower.start.text)
        diff_op = True
    elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
          and frac.lower.start.text == '\\partial'
          and (frac.lower.stop.type == LaTeXLexer.LETTER
               or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
        partial_op = True
        wrt = frac.lower.stop.text
        if frac.lower.stop.type == LaTeXLexer.SYMBOL:
            wrt = wrt[1:]

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt)
        if (diff_op and frac.upper.start == frac.upper.stop
                and frac.upper.start.type == LaTeXLexer.LETTER
                and frac.upper.start.text == 'd'):
            return [wrt]
        elif (partial_op and frac.upper.start == frac.upper.stop
              and frac.upper.start.type == LaTeXLexer.SYMBOL
              and frac.upper.start.text == '\\partial'):
            return [wrt]
        upper_text = rule2text(frac.upper)

        expr_top = None
        if diff_op and upper_text.startswith('d'):
            expr_top = parse_latex(upper_text[1:])
        elif partial_op and frac.upper.start.text == '\\partial':
            expr_top = parse_latex(upper_text[len('\\partial'):])
        # Compare against None: a numerator that parses to a falsy value
        # (e.g. 0) is still a valid derivative target.
        if expr_top is not None:
            return sympy.Derivative(expr_top, wrt)

    expr_top = convert_expr(frac.upper)
    expr_bot = convert_expr(frac.lower)
    inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
    if expr_top == 1:
        return inverse_denom
    else:
        return sympy.Mul(expr_top, inverse_denom, evaluate=False)


def convert_binom(binom):
    """Convert ``\\binom{n}{k}`` to an unevaluated binomial."""
    expr_n = convert_expr(binom.n)
    expr_k = convert_expr(binom.k)
    return sympy.binomial(expr_n, expr_k, evaluate=False)


def convert_floor(floor):
    """Convert a floor node to an unevaluated ``sympy.floor``."""
    val = convert_expr(floor.val)
    return sympy.floor(val, evaluate=False)


def convert_ceil(ceil):
    """Convert a ceil node to an unevaluated ``sympy.ceiling``."""
    val = convert_expr(ceil.val)
    return sympy.ceiling(val, evaluate=False)


def convert_func(func):
    """Convert a ``func`` node (named functions, user functions, integrals,
    roots, conjugates, sums, products and limits).

    For built-in functions, ``\\arc<trig>`` / ``\\ar<hyp>`` map to SymPy's
    ``a<trig>`` names. ``\\log`` defaults to base 10 and ``\\ln`` to base e
    (a subscript overrides either). A trig or hyperbolic function raised to
    -1 becomes its inverse; any other superscript is applied as a power.
    """
    if func.func_normal():
        if func.L_PAREN():  # function called with parenthesis
            arg = convert_func_arg(func.func_arg())
        else:
            arg = convert_func_arg(func.func_arg_noparens())

        name = func.func_normal().start.text[1:]

        # change arc<trig> -> a<trig>
        if name in [
                "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"
        ]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        if name in ["arsinh", "arcosh", "artanh"]:
            name = "a" + name[2:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        if name == "exp":
            expr = sympy.exp(arg, evaluate=False)

        if (name == "log" or name == "ln"):
            if func.subexpr():
                if func.subexpr().expr():
                    base = convert_expr(func.subexpr().expr())
                else:
                    base = convert_atom(func.subexpr().atom())
            elif name == "log":
                base = 10
            elif name == "ln":
                base = sympy.E
            expr = sympy.log(arg, base, evaluate=False)

        func_pow = None
        should_pow = True
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if name in [
                "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh",
                "tanh"
        ]:
            if func_pow == -1:
                name = "a" + name
                should_pow = False
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        # Compare against None so an explicit power of 0 is not dropped.
        if func_pow is not None and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr
    elif func.LETTER() or func.SYMBOL():
        if func.LETTER():
            fname = func.LETTER().getText()
        elif func.SYMBOL():
            fname = func.SYMBOL().getText()[1:]
        fname = str(fname)  # can't be unicode
        if func.subexpr():
            subscript = None
            if func.subexpr().expr():  # subscript is expr
                subscript = convert_expr(func.subexpr().expr())
            else:  # subscript is atom
                subscript = convert_atom(func.subexpr().atom())
            subscriptName = StrPrinter().doprint(subscript)
            fname += '_{' + subscriptName + '}'
        input_args = func.args()
        output_args = []
        while input_args.args():  # handle multiple arguments to function
            output_args.append(convert_expr(input_args.expr()))
            input_args = input_args.args()
        output_args.append(convert_expr(input_args.expr()))
        return sympy.Function(fname)(*output_args)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.root(expr, r, evaluate=False)
        else:
            return sympy.sqrt(expr, evaluate=False)
    elif func.FUNC_OVERLINE():
        expr = convert_expr(func.base)
        return sympy.conjugate(expr, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)


def convert_func_arg(arg):
    """Convert a function argument (parenthesized ``expr`` or bare ``mp_nofunc``)."""
    if hasattr(arg, 'expr'):
        return convert_expr(arg.expr())
    else:
        return convert_mp(arg.mp_nofunc())


def handle_integral(func):
    """Convert an integral node to ``sympy.Integral``.

    The integration variable comes from an explicit DIFFERENTIAL token when
    present. Otherwise the integrand's symbols are scanned for a name of the
    form ``d<var>`` / ``d\\<var>`` (if several match, the last one seen
    wins). That symbol is replaced by 1 and ``<var>`` is used. If nothing
    is found, ``x`` is assumed. Sub/superscripts give definite bounds.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    int_var = None
    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        for sym in integrand.atoms(sympy.Symbol):
            s = str(sym)
            if len(s) > 1 and s[0] == 'd':
                if s[1] == '\\':
                    int_var = sympy.Symbol(s[2:])
                else:
                    int_var = sympy.Symbol(s[1:])
                int_sym = sym
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol('x')

    if func.subexpr():
        if func.subexpr().atom():
            lower = convert_atom(func.subexpr().atom())
        else:
            lower = convert_expr(func.subexpr().expr())
        if func.supexpr().atom():
            upper = convert_atom(func.supexpr().atom())
        else:
            upper = convert_expr(func.supexpr().expr())
        return sympy.Integral(integrand, (int_var, lower, upper))
    else:
        return sympy.Integral(integrand, int_var)


def handle_sum_or_prod(func, name):
    """Convert ``\\sum``/``\\prod`` with ``_{var=start}^{end}`` bounds.

    ``name`` selects the result: ``"summation"`` -> Sum, ``"product"`` ->
    Product. Any other value returns None.
    """
    val = convert_mp(func.mp())
    iter_var = convert_expr(func.subeq().equality().expr(0))
    start = convert_expr(func.subeq().equality().expr(1))
    if func.supexpr().expr():  # ^{expr}
        end = convert_expr(func.supexpr().expr())
    else:  # ^atom
        end = convert_atom(func.supexpr().atom())

    if name == "summation":
        return sympy.Sum(val, (iter_var, start, end))
    elif name == "product":
        return sympy.Product(val, (iter_var, start, end))


def handle_limit(func):
    """Convert ``\\lim_{var \\to value}`` to ``sympy.Limit``.

    The direction is ``"-"`` when the limit subscript carries a SUB token
    and ``"+"`` otherwise.
    """
    sub = func.limit_sub()
    if sub.LETTER():
        var = sympy.Symbol(sub.LETTER().getText())
    elif sub.SYMBOL():
        var = sympy.Symbol(sub.SYMBOL().getText()[1:])
    else:
        var = sympy.Symbol('x')
    if sub.SUB():
        direction = "-"
    else:
        direction = "+"
    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())

    return sympy.Limit(content, var, approaching, direction)


def get_differential_var(d):
    """Return the variable of a DIFFERENTIAL token as a SymPy Symbol."""
    text = get_differential_var_str(d.getText())
    return sympy.Symbol(text)


def get_differential_var_str(text):
    """Return the variable name in differential text such as ``dx``.

    Skips the leading ``d`` and any spaces, tabs, CRs or LFs after it, then
    removes one leading backslash (so ``d\\theta`` gives ``theta``).

    Raises:
        LaTeXParsingError: if no variable follows the ``d``.
    """
    name = text[1:].lstrip(" \r\n\t")
    if not name:
        raise LaTeXParsingError(
            "Expected a variable after the differential in %r" % text)
    if name[0] == "\\":
        name = name[1:]
    return name