1
2 #include "Algorithm.hh"
3 #include "Functional.hh"
4 #include "DisplaySympy.hh"
5 #include "properties/Depends.hh"
6 #include "properties/Accent.hh"
7 #include "properties/Derivative.hh"
8 #include <regex>
9
10 // #define DEBUG 1
11
12 using namespace cadabra;
13
DisplaySympy(const Kernel & kernel,const Ex & e)14 DisplaySympy::DisplaySympy(const Kernel& kernel, const Ex& e)
15 : DisplayBase(kernel, e)
16 {
17 symmap = {
18 {"\\cos", "cos"},
19 {"\\sin", "sin"},
20 {"\\tan", "tan"},
21 {"\\arccos", "acos"},
22 {"\\arcsin", "asin"},
23 {"\\arctan", "atan"},
24 {"\\cosh", "cosh"},
25 {"\\sinh", "sinh"},
26 {"\\tanh", "tanh"},
27 {"\\arccosh", "acosh"},
28 {"\\arcsinh", "asinh"},
29 {"\\arctanh", "atanh"},
30 {"\\log", "log"},
31 {"\\int", "integrate" },
32 {"\\matrix", "Matrix" },
33 {"\\sum", "Sum" },
34 {"\\exp", "exp" },
35 {"\\sqrt", "sqrt" },
36 {"\\equals", "Eq" },
37
38 {"\\infty", "sympy.oo"},
39 {"\\hbar", "hbar"},
40
41 {"\\alpha", "alpha" },
42 {"\\beta", "bbeta" }, // beta seems to be reserved
43 {"\\gamma", "ggamma" }, // gamma seems to be reserved
44 {"\\delta", "delta" },
45 {"\\epsilon", "epsilon" },
46 {"\\zeta", "zeta" },
47 {"\\eta", "eta" },
48 {"\\theta", "theta" },
49 {"\\iota", "iota" },
50 {"\\kappa", "kappa" },
51 {"\\lambda", "lamda" }, // lambda is reserved
52 {"\\mu", "mu" },
53 {"\\nu", "nu" },
54 {"\\xi", "xi" },
55 {"\\omicron", "omicron" },
56 {"\\pi", "pi" },
57 {"\\rho", "rho" },
58 {"\\sigma", "sigma" },
59 {"\\tau", "tau" },
60 {"\\upsilon", "upsilon" },
61 {"\\phi", "phi" },
62 {"\\varphi", "varphi" },
63 {"\\chi", "chi" },
64 {"\\psi", "psi" },
65 {"\\omega", "omega" },
66
67 {"\\Alpha", "Alpha" },
68 {"\\Beta", "Beta" },
69 {"\\Gamma", "Gamma" },
70 {"\\Delta", "Delta" },
71 {"\\Epsilon", "Epsilon" },
72 {"\\Zeta", "Zeta" },
73 {"\\Eta", "Eta" },
74 {"\\Theta", "Theta" },
75 {"\\Iota", "Iota" },
76 {"\\Kappa", "Kappa" },
77 {"\\Lambda", "Lamda" },
78 {"\\Mu", "Mu" },
79 {"\\Nu", "Nu" },
80 {"\\Xi", "Xi" },
81 {"\\Omicron", "Omicron" },
82 {"\\Pi", "Pi" },
83 {"\\Rho", "Rho" },
84 {"\\Sigma", "Sigma" },
85 {"\\Tau", "Tau" },
86 {"\\Upsilon", "Upsilon" },
87 {"\\Phi", "Phi" },
88 {"\\Chi", "Chi" },
89 {"\\Psi", "Psi" },
90 {"\\Omega", "Omega" },
91
92 {"\\partial", "Derivative"},
93 {"\\dot", "dot"},
94 {"\\ddot", "ddot"},
95
96 // A few symbols are reserved by sympy.
97 {"N", "sympyN"},
98 {"O", "sympyO"},
99 {"S", "sympyS"}
100 };
101
102 regex_map = {
103 {"Integral", "\\int" }
104 };
105
106 }
107
108 //TODO: complete this list (take from Sympy)
109
needs_brackets(Ex::iterator it)110 bool DisplaySympy::needs_brackets(Ex::iterator it)
111 {
112 // FIXME: may need looking at properties
113 // FIXME: write as individual parent/current tests
114 if(tree.is_head(it)) return false;
115
116 std::string parent=*tree.parent(it)->name;
117 std::string child =*it->name;
118
119 if(parent=="\\prod" || parent=="\\frac" || parent=="\\pow") {
120 if(parent=="\\pow" && *it->multiplier<0) return true;
121 if(child=="\\sum") return true;
122 if(parent=="\\pow" && ( (tree.index(it)==0 && !it->is_integer()) || child=="\\sum" || child=="\\prod" || child=="\\pow") ) return true;
123 }
124 else if(it->fl.parent_rel==str_node::p_none) {
125 if(*it->name=="\\sum") return false;
126 }
127 else {
128 if(*it->name=="\\sum") return true;
129 if(*it->name=="\\prod") return true;
130 }
131 return false;
132 }
133
134
print_other(std::ostream & str,Ex::iterator it)135 void DisplaySympy::print_other(std::ostream& str, Ex::iterator it)
136 {
137 if(needs_brackets(it))
138 str << "(";
139
140 // print multiplier and object name
141 if(*it->multiplier!=1)
142 print_multiplier(str, it);
143
144 if(*it->name=="1") {
145 if(*it->multiplier==1 || (*it->multiplier==-1)) // this would print nothing altogether.
146 str << "1";
147
148 if(needs_brackets(it))
149 str << ")";
150 return;
151 }
152
153 // const Accent *ac=properties.get<Accent>(it);
154 // if(!ac) { // accents should never get additional curly brackets, {\bar}{g} does not print.
155 // Ex::sibling_iterator sib=tree.begin(it);
156 // while(sib!=tree.end(it)) {
157 // if(sib->is_index())
158 // needs_extra_brackets=true;
159 // ++sib;
160 // }
161 // }
162
163 auto rn = symmap.find(*it->name);
164 if(rn!=symmap.end())
165 str << rn->second;
166 else
167 str << *it->name;
168
169 print_children(str, it);
170
171 if(needs_brackets(it))
172 str << ")";
173 }
174
print_children(std::ostream & str,Ex::iterator it,int)175 void DisplaySympy::print_children(std::ostream& str, Ex::iterator it, int )
176 {
177 // Sympy has no notion of children with different parent relations; it's all
178 // functions of functions kind of stuff. What we will do is print upper and
179 // lower indices as 'UP(..)' and 'DN(..)' type arguments, and then convert
180 // them back later.
181
182 // We need to know if the symbol has implicit dependence on other symbols,
183 // as this needs to be made explicit for sympy. We need to strip this
184 // dependence off later again.
185
186 const Depends *dep=kernel.properties.get<Depends>(it);
187 if(dep) {
188 depsyms[it->name]=dep->dependencies(kernel, it);
189 // std::cerr << *it->name << "depends on " << depsyms[it->name] << std::endl;
190 }
191
192 Ex::sibling_iterator ch=tree.begin(it);
193 if(ch!=tree.end(it) || dep!=0) {
194 str << "(";
195 bool first=true;
196 while(ch!=tree.end(it)) {
197 if(first) first=false;
198 else str << ", ";
199 if(ch->fl.parent_rel==str_node::p_super)
200 str << "UP";
201 if(ch->fl.parent_rel==str_node::p_sub)
202 str << "DN";
203
204 dispatch(str, ch);
205
206 // if(ch->fl.parent_rel==str_node::p_super || ch->fl.parent_rel==str_node::p_sub)
207 // str << ")";
208 ++ch;
209 }
210 if(dep) {
211 if(!first) str << ", ";
212 Ex deplist=dep->dependencies(kernel, it);
213 // deplist is always a \comma node
214 auto sib=tree.begin(deplist.begin());
215 while(sib!=tree.end(deplist.begin())) {
216 const Derivative *dep_is_derivative=kernel.properties.get<Derivative>(sib);
217 if(dep_is_derivative)
218 throw RuntimeException("Dependencies on derivatives are not yet handled in the SymPy bridge");
219 dispatch(str, sib);
220 ++sib;
221 if(sib!=tree.end(deplist.begin()))
222 str << ", ";
223 }
224 //
225 // DisplaySympy ds(kernel, deplist);
226 // ds.output(str);
227 }
228 str << ")";
229 }
230 }
231
print_multiplier(std::ostream & str,Ex::iterator it)232 void DisplaySympy::print_multiplier(std::ostream& str, Ex::iterator it)
233 {
234 bool suppress_star=false;
235 mpz_class denom=it->multiplier->get_den();
236
237 if(denom!=1) {
238 if(false && it->multiplier->get_num()<0)
239 str << "(" << it->multiplier->get_num() << ")";
240 else
241 str << it->multiplier->get_num();
242 str << "/" << it->multiplier->get_den();
243 }
244 else if(*it->multiplier==-1) {
245 str << "-";
246 suppress_star=true;
247 }
248 else {
249 str << *it->multiplier;
250 }
251
252 if(!suppress_star && !(*it->name=="1"))
253 str << "*";
254 }
255
print_opening_bracket(std::ostream & str,str_node::bracket_t br)256 void DisplaySympy::print_opening_bracket(std::ostream& str, str_node::bracket_t br)
257 {
258 switch(br) {
259 case str_node::b_none:
260 str << ")";
261 break;
262 case str_node::b_pointy:
263 str << "\\<";
264 break;
265 case str_node::b_curly:
266 str << "\\{";
267 break;
268 case str_node::b_round:
269 str << "(";
270 break;
271 case str_node::b_square:
272 str << "[";
273 break;
274 default :
275 return;
276 }
277 }
278
print_closing_bracket(std::ostream & str,str_node::bracket_t br)279 void DisplaySympy::print_closing_bracket(std::ostream& str, str_node::bracket_t br)
280 {
281 switch(br) {
282 case str_node::b_none:
283 str << ")";
284 break;
285 case str_node::b_pointy:
286 str << "\\>";
287 break;
288 case str_node::b_curly:
289 str << "\\}";
290 break;
291 case str_node::b_round:
292 str << ")";
293 break;
294 case str_node::b_square:
295 str << "]";
296 break;
297 default :
298 return;
299 }
300 }
301
print_parent_rel(std::ostream & str,str_node::parent_rel_t pr,bool)302 void DisplaySympy::print_parent_rel(std::ostream& str, str_node::parent_rel_t pr, bool )
303 {
304 switch(pr) {
305 case str_node::p_super:
306 str << "^";
307 break;
308 case str_node::p_sub:
309 str << "_";
310 break;
311 case str_node::p_property:
312 str << "$";
313 break;
314 case str_node::p_exponent:
315 str << "**";
316 break;
317 case str_node::p_none:
318 break;
319 case str_node::p_components:
320 break;
321 case str_node::p_invalid:
322 throw std::logic_error("DisplaySympy: p_invalid not handled.");
323 }
324 }
325
dispatch(std::ostream & str,Ex::iterator it)326 void DisplaySympy::dispatch(std::ostream& str, Ex::iterator it)
327 {
328 // The node names below should only be reserved node names; all others
329 // should be looked up using properties. FIXME
330 if(*it->name=="\\prod") print_productlike(str, it, "*");
331 else if(*it->name=="\\sum") print_sumlike(str, it);
332 else if(*it->name=="\\frac") print_fraclike(str, it);
333 else if(*it->name=="\\comma") print_commalike(str, it);
334 else if(*it->name=="\\arrow") print_arrowlike(str, it);
335 else if(*it->name=="\\pow") print_powlike(str, it);
336 else if(*it->name=="\\int") print_intlike(str, it);
337 else if(*it->name=="\\sum") print_intlike(str, it);
338 else if(*it->name=="\\equals") print_equalitylike(str, it);
339 else if(*it->name=="\\components") print_components(str, it);
340 else if(*it->name=="\\partial") print_partial(str, it);
341 else if(*it->name=="\\matrix") print_matrix(str, it);
342 else print_other(str, it);
343 }
344
print_commalike(std::ostream & str,Ex::iterator it)345 void DisplaySympy::print_commalike(std::ostream& str, Ex::iterator it)
346 {
347 Ex::sibling_iterator sib=tree.begin(it);
348 bool first=true;
349 str << "[";
350 while(sib!=tree.end(it)) {
351 if(first)
352 first=false;
353 else
354 str << ", ";
355 dispatch(str, sib);
356 ++sib;
357 }
358 str << "]";
359 //print_closing_bracket(str, (*it).fl.bracket, str_node::p_none);
360 }
361
print_arrowlike(std::ostream & str,Ex::iterator it)362 void DisplaySympy::print_arrowlike(std::ostream& str, Ex::iterator it)
363 {
364 Ex::sibling_iterator sib=tree.begin(it);
365 str << "rule(";
366 dispatch(str, sib);
367 str << ", ";
368 ++sib;
369 dispatch(str, sib);
370 str << ")";
371 }
372
print_fraclike(std::ostream & str,Ex::iterator it)373 void DisplaySympy::print_fraclike(std::ostream& str, Ex::iterator it)
374 {
375 Ex::sibling_iterator num=tree.begin(it), den=num;
376 ++den;
377
378 if(*it->multiplier!=1) {
379 print_multiplier(str, it);
380 }
381 dispatch(str, num);
382
383 str << "/(";
384
385 dispatch(str, den);
386
387 str << ")";
388 }
389
print_productlike(std::ostream & str,Ex::iterator it,const std::string & inbetween)390 void DisplaySympy::print_productlike(std::ostream& str, Ex::iterator it, const std::string& inbetween)
391 {
392 if(needs_brackets(it))
393 str << "(";
394
395 if(*it->multiplier!=1) {
396 print_multiplier(str, it);
397 // Ex::sibling_iterator st=tree.begin(it);
398 }
399
400 // To print \prod{\sum{a}{b}}{\sum{c}{d}} correctly:
401 // If there is any sum as child, and if the sum children do not
402 // all have the same bracket type (different from b_none or b_no),
403 // then print brackets.
404
405 str_node::bracket_t previous_bracket_=str_node::b_invalid;
406 // bool beginning_of_group=true;
407 Ex::sibling_iterator ch=tree.begin(it);
408 while(ch!=tree.end(it)) {
409 str_node::bracket_t current_bracket_=(*ch).fl.bracket;
410 if(previous_bracket_!=current_bracket_) {
411 if(current_bracket_!=str_node::b_none) {
412 print_opening_bracket(str, current_bracket_);
413 // beginning_of_group=true;
414 }
415 }
416 dispatch(str, ch);
417 ++ch;
418 if(ch==tree.end(it)) {
419 if(current_bracket_!=str_node::b_none)
420 print_closing_bracket(str, current_bracket_);
421 }
422
423 if(ch!=tree.end(it)) {
424 str << inbetween;
425 }
426 previous_bracket_=current_bracket_;
427 }
428
429 if(needs_brackets(it))
430 str << ")";
431 // if(close_bracket) str << ")";
432 }
433
print_sumlike(std::ostream & str,Ex::iterator it)434 void DisplaySympy::print_sumlike(std::ostream& str, Ex::iterator it)
435 {
436 assert(*it->multiplier==1);
437
438 if(needs_brackets(it))
439 str << "(";
440
441 unsigned int steps=0;
442
443 Ex::sibling_iterator ch=tree.begin(it);
444 while(ch!=tree.end(it)) {
445 if(++steps==20) {
446 steps=0;
447 }
448 if(*ch->multiplier>=0 && ch!=tree.begin(it))
449 str << "+";
450
451 dispatch(str, ch);
452 ++ch;
453 }
454
455 if(needs_brackets(it))
456 str << ")";
457 str << std::flush;
458 }
459
print_powlike(std::ostream & str,Ex::iterator it)460 void DisplaySympy::print_powlike(std::ostream& str, Ex::iterator it)
461 {
462 if(needs_brackets(it))
463 str << "(";
464
465 Ex::sibling_iterator sib=tree.begin(it);
466 if(*it->multiplier!=1)
467 print_multiplier(str, it);
468 dispatch(str, sib);
469 str << "**(";
470 ++sib;
471 dispatch(str, sib);
472 str << ")";
473
474 if(needs_brackets(it))
475 str << ")";
476 }
477
print_intlike(std::ostream & str,Ex::iterator it)478 void DisplaySympy::print_intlike(std::ostream& str, Ex::iterator it)
479 {
480 if(*it->multiplier!=1)
481 print_multiplier(str, it);
482 str << symmap[*it->name] << "(";
483 Ex::sibling_iterator sib=tree.begin(it);
484 dispatch(str, sib);
485 ++sib;
486 while(tree.is_valid(sib)) {
487 str << ", ";
488 dispatch(str, sib);
489 ++sib;
490 }
491 str << ")";
492 }
493
print_equalitylike(std::ostream & str,Ex::iterator it)494 void DisplaySympy::print_equalitylike(std::ostream& str, Ex::iterator it)
495 {
496 str << "Eq(";
497 Ex::sibling_iterator sib=tree.begin(it);
498 dispatch(str, sib);
499 str << ", ";
500 ++sib;
501 if(sib==tree.end(it))
502 throw ConsistencyException("Found equals node with only one child node.");
503 dispatch(str, sib);
504 str << ")";
505 }
506
print_components(std::ostream & str,Ex::iterator it)507 void DisplaySympy::print_components(std::ostream& str, Ex::iterator it)
508 {
509 str << *it->name;
510 auto sib=tree.begin(it);
511 auto end=tree.end(it);
512 --end;
513 while(sib!=end) {
514 dispatch(str, sib);
515 ++sib;
516 }
517 str << "\n";
518 sib=tree.begin(end);
519 while(sib!=tree.end(end)) {
520 str << " ";
521 dispatch(str, sib);
522 str << "\n";
523 ++sib;
524 }
525 }
526
print_partial(std::ostream & str,Ex::iterator it)527 void DisplaySympy::print_partial(std::ostream& str, Ex::iterator it)
528 {
529 if(*it->multiplier!=1)
530 print_multiplier(str, it);
531
532 str << "diff(";
533 Ex::sibling_iterator sib=tree.begin(it);
534 while(sib!=tree.end(it)) {
535 if(sib->fl.parent_rel==str_node::p_none) {
536 dispatch(str, sib);
537 break;
538 }
539 ++sib;
540 }
541 // write the implicit direction of the derivative, if any.
542 const Derivative *derivative = kernel.properties.get<Derivative>(it);
543 if(derivative) {
544 if(derivative->with_respect_to.size()>0) {
545 str << ", ";
546 dispatch(str, derivative->with_respect_to.begin());
547 }
548 }
549
550 // write the explicit direction(s) of the derivative.
551 sib=tree.begin(it);
552 while(sib!=tree.end(it)) {
553 if(sib->fl.parent_rel!=str_node::p_none) {
554 str << ", ";
555 dispatch(str, sib);
556 }
557 ++sib;
558 }
559 str << ")";
560 }
561
print_matrix(std::ostream & str,Ex::iterator it)562 void DisplaySympy::print_matrix(std::ostream& str, Ex::iterator it)
563 {
564 str << "Matrix([";
565 auto comma=tree.begin(it);
566 Ex::sibling_iterator row_it = tree.begin(comma);
567 while(row_it!=tree.end(comma)) {
568 if(row_it!=tree.begin(comma)) str << ", ";
569 Ex::sibling_iterator col_it = tree.begin(row_it);
570 str << "[";
571 while(col_it!=tree.end(row_it)) {
572 if(col_it!=tree.begin(row_it)) str << ", ";
573 dispatch(str, col_it);
574 ++col_it;
575 }
576 str << "]";
577 ++row_it;
578 }
579 str << "])";
580 }
581
children_have_brackets(Ex::iterator ch) const582 bool DisplaySympy::children_have_brackets(Ex::iterator ch) const
583 {
584 Ex::sibling_iterator chlds=tree.begin(ch);
585 str_node::bracket_t childbr=chlds->fl.bracket;
586 if(childbr==str_node::b_none || childbr==str_node::b_no)
587 return false;
588 else return true;
589 }
590
preparse_import(const std::string & in)591 std::string DisplaySympy::preparse_import(const std::string& in)
592 {
593 #ifdef DEBUG
594 std::cerr << "DisplaySympy::preparse_import" << std::endl;
595 #endif
596 std::string ret = in;
597 for(auto& r: regex_map) {
598 #ifdef DEBUG
599 std::cerr << "Replacing " << r.first << " with " << r.second << std::endl;
600 #endif
601 ret = std::regex_replace(ret, std::regex(r.first), r.second);
602 }
603 return ret;
604 }
605
import(Ex & ex)606 void DisplaySympy::import(Ex& ex)
607 {
608 cadabra::do_subtree(ex, ex.begin(), [&](Ex::iterator it) -> Ex::iterator {
609 for(auto& m: symmap)
610 {
611 // If we have converted the name of this symbol, convert back.
612 if(m.second==*it->name) {
613 it->name=name_set.insert(m.first).first;
614 break;
615 }
616 }
617 // See if we have added dependencies to this symbol (lookup in map).
618 // If yes, strip them off again.
619 auto fnd = depsyms.find(it->name);
620 if(fnd!=depsyms.end())
621 {
622 auto args=ex.begin(it);
623 // Strip out only those symbols which have been added.
624 while(args!=ex.end(it)) {
625 if(args->fl.parent_rel==str_node::p_none) {
626 auto findsib=fnd->second.begin(fnd->second.begin());
627 bool removed=false;
628 while(findsib!=fnd->second.end(fnd->second.begin())) {
629 if(subtree_equal(0, findsib, args)) {
630 args=ex.erase(args);
631 removed=true;
632 break;
633 }
634 ++findsib;
635 }
636 if(!removed)
637 ++args;
638 }
639 else
640 ++args;
641 }
642 // std::cerr << "stripping from " << *it->name << std::endl;
643 //// if(*ex.begin(it)->name=="\\comma")
644 // ex.erase(ex.begin(it));
645 }
646
647 // Move child nodes of partial to the right place.
648 if(*it->name=="\\partial")
649 {
650 auto args = ex.begin(it);
651 ++args;
652 while(args!=ex.end(it)) {
653 auto nxt=args;
654 ++nxt;
655 auto loc = ex.move_before(ex.begin(it), args);
656 loc->fl.parent_rel=str_node::p_sub;
657
658 // If the argument is \comma{x}{n} expand this to 'n' arguments 'x'.
659 // This is to handle Sympy returning 'Derivative(f(x), (x,2))' for the
660 // 2nd order derivative.
661
662 if(*loc->name=="\\comma") {
663 #ifdef DEBUG
664 std::cerr << loc << std::endl;
665 #endif
666 auto x=ex.begin(loc);
667 auto n=x;
668 ++n;
669 if(! n->is_integer())
670 throw RuntimeException("DisplaySympy::import received un-parseable Derivative expression.");
671 int nn=to_long(*n->multiplier);
672 for(int k=0; k<nn; ++k)
673 ex.insert_subtree(loc, x)->fl.parent_rel=str_node::p_sub;
674 ex.erase(loc);
675 #ifdef DEBUG
676 std::cerr << it << std::endl;
677 #endif
678 }
679
680
681 args=nxt;
682 }
683
684 // Strip subscripts which are the same as the 'with_respect_to' member of the
685 // derivative (if any), as these are implicit in Cadabra. This is tricky, because
686 // a multiple derivative with respect to this argument needs to be replaced
687 // with a multiple nesting of the derivative operator itself, e.g.
688 // \partial{\partial{r}} -> diff(diff(r(t),t),t) -> diff(r(t),t,t)
689
690 const Derivative *derivative = kernel.properties.get<Derivative>(it);
691 if(derivative) {
692 #ifdef DEBUG
693 std::cerr << "is proper derivative" << std::endl;
694 #endif
695 if(derivative->with_respect_to.size()>0) {
696 auto it_copy=it;
697 args=ex.begin(it_copy);
698 bool first=true;
699 while(args!=ex.end(it_copy)) {
700 #ifdef DEBUG
701 std::cerr << "Comparing: " << args << std::endl
702 << "and " << derivative->with_respect_to.begin()
703 << std::endl;
704 #endif
705
706 if(subtree_equal(0, args, derivative->with_respect_to.begin(), 0) ) {
707 args=ex.erase(args);
708 if(first) {
709 first=false;
710 }
711 else {
712 it=ex.wrap(it, str_node(it->name));
713 }
714 }
715 else
716 ++args;
717 }
718 }
719 }
720 else {
721 #ifdef DEBUG
722 std::cerr << it << " is not a proper derivative" << std::endl;
723 #endif
724
725 }
726
727 // ex.flatten(comma);
728 // ex.erase(comma);
729 // }
730 }
731
732 return it;
733 });
734 }
735