1 #include <symengine/visitor.h>
2 #include <symengine/symengine_exception.h>
3 
4 namespace SymEngine
5 {
6 
7 extern RCP<const Basic> i2;
8 extern RCP<const Basic> i3;
9 extern RCP<const Basic> i5;
10 extern RCP<const Basic> im2;
11 extern RCP<const Basic> im3;
12 extern RCP<const Basic> im5;
13 
sqrt(RCP<const Basic> & arg)14 RCP<const Basic> sqrt(RCP<const Basic> &arg)
15 {
16     return pow(arg, div(one, i2));
17 }
cbrt(RCP<const Basic> & arg)18 RCP<const Basic> cbrt(RCP<const Basic> &arg)
19 {
20     return pow(arg, div(one, i3));
21 }
22 
23 extern RCP<const Basic> sq3;
24 extern RCP<const Basic> sq2;
25 extern RCP<const Basic> sq5;
26 
27 extern RCP<const Basic> C0;
28 extern RCP<const Basic> C1;
29 extern RCP<const Basic> C2;
30 extern RCP<const Basic> C3;
31 extern RCP<const Basic> C4;
32 extern RCP<const Basic> C5;
33 extern RCP<const Basic> C6;
34 
35 extern RCP<const Basic> mC0;
36 extern RCP<const Basic> mC1;
37 extern RCP<const Basic> mC2;
38 extern RCP<const Basic> mC3;
39 extern RCP<const Basic> mC4;
40 extern RCP<const Basic> mC5;
41 extern RCP<const Basic> mC6;
42 
43 extern RCP<const Basic> sin_table[];
44 
45 extern umap_basic_basic inverse_cst;
46 
47 extern umap_basic_basic inverse_tct;
48 
Conjugate(const RCP<const Basic> & arg)49 Conjugate::Conjugate(const RCP<const Basic> &arg) : OneArgFunction(arg)
50 {
51     SYMENGINE_ASSIGN_TYPEID()
52     SYMENGINE_ASSERT(is_canonical(arg))
53 }
54 
is_canonical(const RCP<const Basic> & arg) const55 bool Conjugate::is_canonical(const RCP<const Basic> &arg) const
56 {
57     if (is_a_Number(*arg)) {
58         if (eq(*arg, *ComplexInf)) {
59             return true;
60         }
61         return false;
62     }
63     if (is_a<Constant>(*arg)) {
64         return false;
65     }
66     if (is_a<Mul>(*arg)) {
67         return false;
68     }
69     if (is_a<Pow>(*arg)) {
70         if (is_a<Integer>(*down_cast<const Pow &>(*arg).get_exp())) {
71             return false;
72         }
73     }
74     // OneArgFunction classes
75     if (is_a<Sign>(*arg) or is_a<Conjugate>(*arg) or is_a<Erf>(*arg)
76         or is_a<Erfc>(*arg) or is_a<Gamma>(*arg) or is_a<LogGamma>(*arg)
77         or is_a<Abs>(*arg)) {
78         return false;
79     }
80     if (is_a<Sin>(*arg) or is_a<Cos>(*arg) or is_a<Tan>(*arg) or is_a<Cot>(*arg)
81         or is_a<Sec>(*arg) or is_a<Csc>(*arg)) {
82         return false;
83     }
84     if (is_a<Sinh>(*arg) or is_a<Cosh>(*arg) or is_a<Tanh>(*arg)
85         or is_a<Coth>(*arg) or is_a<Sech>(*arg) or is_a<Csch>(*arg)) {
86         return false;
87     }
88     // TwoArgFunction classes
89     if (is_a<KroneckerDelta>(*arg) or is_a<ATan2>(*arg)
90         or is_a<LowerGamma>(*arg) or is_a<UpperGamma>(*arg)
91         or is_a<Beta>(*arg)) {
92         return false;
93     }
94     // MultiArgFunction class
95     if (is_a<LeviCivita>(*arg)) {
96         return false;
97     }
98     return true;
99 }
100 
create(const RCP<const Basic> & arg) const101 RCP<const Basic> Conjugate::create(const RCP<const Basic> &arg) const
102 {
103     return conjugate(arg);
104 }
105 
conjugate(const RCP<const Basic> & arg)106 RCP<const Basic> conjugate(const RCP<const Basic> &arg)
107 {
108     if (is_a_Number(*arg)) {
109         return down_cast<const Number &>(*arg).conjugate();
110     }
111     if (is_a<Constant>(*arg) or is_a<Abs>(*arg) or is_a<KroneckerDelta>(*arg)
112         or is_a<LeviCivita>(*arg)) {
113         return arg;
114     }
115     if (is_a<Mul>(*arg)) {
116         const map_basic_basic &dict = down_cast<const Mul &>(*arg).get_dict();
117         map_basic_basic new_dict;
118         RCP<const Number> coef = rcp_static_cast<const Number>(
119             conjugate(down_cast<const Mul &>(*arg).get_coef()));
120         for (const auto &p : dict) {
121             if (is_a<Integer>(*p.second)) {
122                 Mul::dict_add_term_new(outArg(coef), new_dict, p.second,
123                                        conjugate(p.first));
124             } else {
125                 Mul::dict_add_term_new(
126                     outArg(coef), new_dict, one,
127                     conjugate(Mul::from_dict(one, {{p.first, p.second}})));
128             }
129         }
130         return Mul::from_dict(coef, std::move(new_dict));
131     }
132     if (is_a<Pow>(*arg)) {
133         RCP<const Basic> base = down_cast<const Pow &>(*arg).get_base();
134         RCP<const Basic> exp = down_cast<const Pow &>(*arg).get_exp();
135         if (is_a<Integer>(*exp)) {
136             return pow(conjugate(base), exp);
137         }
138     }
139     if (is_a<Conjugate>(*arg)) {
140         return down_cast<const Conjugate &>(*arg).get_arg();
141     }
142     if (is_a<Sign>(*arg) or is_a<Erf>(*arg) or is_a<Erfc>(*arg)
143         or is_a<Gamma>(*arg) or is_a<LogGamma>(*arg) or is_a<Sin>(*arg)
144         or is_a<Cos>(*arg) or is_a<Tan>(*arg) or is_a<Cot>(*arg)
145         or is_a<Sec>(*arg) or is_a<Csc>(*arg) or is_a<Sinh>(*arg)
146         or is_a<Cosh>(*arg) or is_a<Tanh>(*arg) or is_a<Coth>(*arg)
147         or is_a<Sech>(*arg) or is_a<Csch>(*arg)) {
148         const OneArgFunction &func = down_cast<const OneArgFunction &>(*arg);
149         return func.create(conjugate(func.get_arg()));
150     }
151     if (is_a<ATan2>(*arg) or is_a<LowerGamma>(*arg) or is_a<UpperGamma>(*arg)
152         or is_a<Beta>(*arg)) {
153         const TwoArgFunction &func = down_cast<const TwoArgFunction &>(*arg);
154         return func.create(conjugate(func.get_arg1()),
155                            conjugate(func.get_arg2()));
156     }
157     return make_rcp<const Conjugate>(arg);
158 }
159 
get_pi_shift(const RCP<const Basic> & arg,const Ptr<RCP<const Number>> & n,const Ptr<RCP<const Basic>> & x)160 bool get_pi_shift(const RCP<const Basic> &arg, const Ptr<RCP<const Number>> &n,
161                   const Ptr<RCP<const Basic>> &x)
162 {
163     if (is_a<Add>(*arg)) {
164         const Add &s = down_cast<const Add &>(*arg);
165         RCP<const Basic> coef = s.get_coef();
166         auto size = s.get_dict().size();
167         if (size > 1) {
168             // arg should be of form `x + n*pi`
169             // `n` is an integer
170             // `x` is an `Expression`
171             bool check_pi = false;
172             RCP<const Basic> temp;
173             *x = coef;
174             for (const auto &p : s.get_dict()) {
175                 if (eq(*p.first, *pi)
176                     and (is_a<Integer>(*p.second)
177                          or is_a<Rational>(*p.second))) {
178                     check_pi = true;
179                     *n = p.second;
180                 } else {
181                     *x = add(mul(p.first, p.second), *x);
182                 }
183             }
184             if (check_pi)
185                 return true;
186             else // No term with `pi` found
187                 return false;
188         } else if (size == 1) {
189             // arg should be of form `a + n*pi`
190             // where `a` is a `Number`.
191             auto p = s.get_dict().begin();
192             if (eq(*p->first, *pi)
193                 and (is_a<Integer>(*p->second) or is_a<Rational>(*p->second))) {
194                 *n = p->second;
195                 *x = coef;
196                 return true;
197             } else {
198                 return false;
199             }
200         } else { // Should never reach here though!
201             // Dict of size < 1
202             return false;
203         }
204     } else if (is_a<Mul>(*arg)) {
205         // `arg` is of the form `k*pi/12`
206         const Mul &s = down_cast<const Mul &>(*arg);
207         auto p = s.get_dict().begin();
208         // dict should contain symbol `pi` only
209         if (s.get_dict().size() == 1 and eq(*p->first, *pi)
210             and eq(*p->second, *one)
211             and (is_a<Integer>(*s.get_coef())
212                  or is_a<Rational>(*s.get_coef()))) {
213             *n = s.get_coef();
214             *x = zero;
215             return true;
216         } else {
217             return false;
218         }
219     } else if (eq(*arg, *pi)) {
220         *n = one;
221         *x = zero;
222         return true;
223     } else if (eq(*arg, *zero)) {
224         *n = zero;
225         *x = zero;
226         return true;
227     } else {
228         return false;
229     }
230 }
231 
232 // Return true if arg is of form a+b*pi, with b integer or rational
233 // with denominator 2. The a may be zero or any expression.
trig_has_basic_shift(const RCP<const Basic> & arg)234 bool trig_has_basic_shift(const RCP<const Basic> &arg)
235 {
236     if (is_a<Add>(*arg)) {
237         const Add &s = down_cast<const Add &>(*arg);
238         for (const auto &p : s.get_dict()) {
239             const auto &temp = mul(p.second, integer(2));
240             if (eq(*p.first, *pi)) {
241                 if (is_a<Integer>(*temp)) {
242                     return true;
243                 }
244                 if (is_a<Rational>(*temp)) {
245                     auto m = down_cast<const Rational &>(*temp)
246                                  .as_rational_class();
247                     return (m < 0) or (m > 1);
248                 }
249                 return false;
250             }
251         }
252         return false;
253     } else if (is_a<Mul>(*arg)) {
254         // is `arg` of the form `k*pi/2`?
255         // dict should contain symbol `pi` only
256         // and `k` should be a rational s.t. 0 < k < 1
257         const Mul &s = down_cast<const Mul &>(*arg);
258         RCP<const Basic> coef = mul(s.get_coef(), integer(2));
259         auto p = s.get_dict().begin();
260         if (s.get_dict().size() == 1 and eq(*p->first, *pi)
261             and eq(*p->second, *one)) {
262             if (is_a<Integer>(*coef)) {
263                 return true;
264             }
265             if (is_a<Rational>(*coef)) {
266                 auto m = down_cast<const Rational &>(*coef).as_rational_class();
267                 return (m < 0) or (m > 1);
268             }
269             return false;
270         } else {
271             return false;
272         }
273     } else if (eq(*arg, *pi)) {
274         return true;
275     } else if (eq(*arg, *zero)) {
276         return true;
277     } else {
278         return false;
279     }
280 }
281 
could_extract_minus(const Basic & arg)282 bool could_extract_minus(const Basic &arg)
283 {
284     if (is_a_Number(arg)) {
285         if (down_cast<const Number &>(arg).is_negative()) {
286             return true;
287         } else if (is_a_Complex(arg)) {
288             const ComplexBase &c = down_cast<const ComplexBase &>(arg);
289             RCP<const Number> real_part = c.real_part();
290             return (real_part->is_negative())
291                    or (eq(*real_part, *zero)
292                        and c.imaginary_part()->is_negative());
293         } else {
294             return false;
295         }
296     } else if (is_a<Mul>(arg)) {
297         const Mul &s = down_cast<const Mul &>(arg);
298         return could_extract_minus(*s.get_coef());
299     } else if (is_a<Add>(arg)) {
300         const Add &s = down_cast<const Add &>(arg);
301         if (s.get_coef()->is_zero()) {
302             map_basic_num d(s.get_dict().begin(), s.get_dict().end());
303             return could_extract_minus(*d.begin()->second);
304         } else {
305             return could_extract_minus(*s.get_coef());
306         }
307     } else {
308         return false;
309     }
310 }
311 
handle_minus(const RCP<const Basic> & arg,const Ptr<RCP<const Basic>> & rarg)312 bool handle_minus(const RCP<const Basic> &arg,
313                   const Ptr<RCP<const Basic>> &rarg)
314 {
315     if (is_a<Mul>(*arg)) {
316         const Mul &s = down_cast<const Mul &>(*arg);
317         // Check for -Add instances to transform -(-x + 2*y) to (x - 2*y)
318         if (s.get_coef()->is_minus_one() && s.get_dict().size() == 1
319             && eq(*s.get_dict().begin()->second, *one)) {
320             return not handle_minus(mul(minus_one, arg), rarg);
321         } else if (could_extract_minus(*s.get_coef())) {
322             *rarg = mul(minus_one, arg);
323             return true;
324         }
325     } else if (is_a<Add>(*arg)) {
326         if (could_extract_minus(*arg)) {
327             const Add &s = down_cast<const Add &>(*arg);
328             umap_basic_num d = s.get_dict();
329             for (auto &p : d) {
330                 p.second = p.second->mul(*minus_one);
331             }
332             *rarg = Add::from_dict(s.get_coef()->mul(*minus_one), std::move(d));
333             return true;
334         }
335     } else if (could_extract_minus(*arg)) {
336         *rarg = mul(minus_one, arg);
337         return true;
338     }
339     *rarg = arg;
340     return false;
341 }
342 
343 // \return true if conjugate has to be returned finally else false
trig_simplify(const RCP<const Basic> & arg,unsigned period,bool odd,bool conj_odd,const Ptr<RCP<const Basic>> & rarg,int & index,int & sign)344 bool trig_simplify(const RCP<const Basic> &arg, unsigned period, bool odd,
345                    bool conj_odd, // input
346                    const Ptr<RCP<const Basic>> &rarg, int &index,
347                    int &sign) // output
348 {
349     bool check;
350     RCP<const Number> n;
351     RCP<const Basic> r;
352     RCP<const Basic> ret_arg;
353     check = get_pi_shift(arg, outArg(n), outArg(r));
354     if (check) {
355         RCP<const Number> t = mulnum(n, integer(12));
356         sign = 1;
357         if (is_a<Integer>(*t)) {
358             int m = numeric_cast<int>(
359                 mod_f(down_cast<const Integer &>(*t), *integer(12 * period))
360                     ->as_int());
361             if (eq(*r, *zero)) {
362                 index = m;
363                 *rarg = zero;
364                 return false;
365             } else if (m == 0) {
366                 index = 0;
367                 bool b = handle_minus(r, outArg(ret_arg));
368                 *rarg = ret_arg;
369                 if (odd and b)
370                     sign = -1;
371                 return false;
372             }
373         }
374 
375         rational_class m;
376         if (is_a<Integer>(*n)) {
377             // 2*pi periodic => f(r + pi * n) = f(r - pi * n)
378             m = mp_abs(down_cast<const Integer &>(*n).as_integer_class());
379             m /= period;
380         } else {
381             SYMENGINE_ASSERT(is_a<Rational>(*n));
382             m = down_cast<const Rational &>(*n).as_rational_class() / period;
383             integer_class t;
384 #if SYMENGINE_INTEGER_CLASS != SYMENGINE_BOOSTMP
385             mp_fdiv_r(t, get_num(m), get_den(m));
386             get_num(m) = t;
387 #else
388             integer_class quo;
389             mp_fdiv_qr(quo, t, get_num(m), get_den(m));
390             m -= rational_class(quo);
391 #endif
392             // m = a / b => m = (a % b / b)
393         }
394         // Now, arg = r + 2 * pi * m  where 0 <= m < 1
395         m *= 2 * period;
396         // Now, arg = r + pi * m / 2  where 0 <= m < 4
397         if (m >= 2 and m < 3) {
398             sign = -1;
399             r = add(r, mul(pi, Rational::from_mpq((m - 2) / 2)));
400             bool b = handle_minus(r, outArg(ret_arg));
401             *rarg = ret_arg;
402             if (odd and b)
403                 sign = -1 * sign;
404             return false;
405         } else if (m >= 1) {
406             if (m < 2) {
407                 // 1 <= m < 2
408                 sign = 1;
409                 r = add(r, mul(pi, Rational::from_mpq((m - 1) / 2)));
410             } else {
411                 // 3 <= m < 4
412                 sign = -1;
413                 r = add(r, mul(pi, Rational::from_mpq((m - 3) / 2)));
414             }
415             bool b = handle_minus(r, outArg(ret_arg));
416             *rarg = ret_arg;
417             if (not b and conj_odd)
418                 sign = -sign;
419             return true;
420         } else {
421             *rarg = add(r, mul(pi, Rational::from_mpq(m / 2)));
422             index = -1;
423             return false;
424         }
425     } else {
426         bool b = handle_minus(arg, outArg(ret_arg));
427         *rarg = ret_arg;
428         index = -1;
429         if (odd and b)
430             sign = -1;
431         else
432             sign = 1;
433         return false;
434     }
435 }
436 
inverse_lookup(umap_basic_basic & d,const RCP<const Basic> & t,const Ptr<RCP<const Basic>> & index)437 bool inverse_lookup(umap_basic_basic &d, const RCP<const Basic> &t,
438                     const Ptr<RCP<const Basic>> &index)
439 {
440     auto it = d.find(t);
441     if (it == d.end()) {
442         // Not found in lookup
443         return false;
444     } else {
445         *index = (it->second);
446         return true;
447     }
448 }
449 
Sign(const RCP<const Basic> & arg)450 Sign::Sign(const RCP<const Basic> &arg) : OneArgFunction(arg)
451 {
452     SYMENGINE_ASSIGN_TYPEID()
453     SYMENGINE_ASSERT(is_canonical(arg))
454 }
455 
is_canonical(const RCP<const Basic> & arg) const456 bool Sign::is_canonical(const RCP<const Basic> &arg) const
457 {
458     if (is_a_Number(*arg)) {
459         if (eq(*arg, *ComplexInf)) {
460             return true;
461         }
462         return false;
463     }
464     if (is_a<Constant>(*arg)) {
465         return false;
466     }
467     if (is_a<Sign>(*arg)) {
468         return false;
469     }
470     if (is_a<Mul>(*arg)) {
471         if (neq(*down_cast<const Mul &>(*arg).get_coef(), *one)
472             and neq(*down_cast<const Mul &>(*arg).get_coef(), *minus_one)) {
473             return false;
474         }
475     }
476     return true;
477 }
478 
create(const RCP<const Basic> & arg) const479 RCP<const Basic> Sign::create(const RCP<const Basic> &arg) const
480 {
481     return sign(arg);
482 }
483 
sign(const RCP<const Basic> & arg)484 RCP<const Basic> sign(const RCP<const Basic> &arg)
485 {
486     if (is_a_Number(*arg)) {
487         if (is_a<NaN>(*arg)) {
488             return Nan;
489         }
490         if (down_cast<const Number &>(*arg).is_zero()) {
491             return zero;
492         }
493         if (down_cast<const Number &>(*arg).is_positive()) {
494             return one;
495         }
496         if (down_cast<const Number &>(*arg).is_negative()) {
497             return minus_one;
498         }
499         if (is_a_Complex(*arg)
500             and down_cast<const ComplexBase &>(*arg).is_re_zero()) {
501             RCP<const Number> r
502                 = down_cast<const ComplexBase &>(*arg).imaginary_part();
503             if (down_cast<const Number &>(*r).is_positive()) {
504                 return I;
505             }
506             if (down_cast<const Number &>(*r).is_negative()) {
507                 return mul(minus_one, I);
508             }
509         }
510     }
511     if (is_a<Constant>(*arg)) {
512         if (eq(*arg, *pi) or eq(*arg, *E) or eq(*arg, *EulerGamma)
513             or eq(*arg, *Catalan) or eq(*arg, *GoldenRatio))
514             return one;
515     }
516     if (is_a<Sign>(*arg)) {
517         return arg;
518     }
519     if (is_a<Mul>(*arg)) {
520         RCP<const Basic> s = sign(down_cast<const Mul &>(*arg).get_coef());
521         map_basic_basic dict = down_cast<const Mul &>(*arg).get_dict();
522         return mul(s,
523                    make_rcp<const Sign>(Mul::from_dict(one, std::move(dict))));
524     }
525     return make_rcp<const Sign>(arg);
526 }
527 
Floor(const RCP<const Basic> & arg)528 Floor::Floor(const RCP<const Basic> &arg) : OneArgFunction(arg)
529 {
530     SYMENGINE_ASSIGN_TYPEID()
531     SYMENGINE_ASSERT(is_canonical(arg))
532 }
533 
is_canonical(const RCP<const Basic> & arg) const534 bool Floor::is_canonical(const RCP<const Basic> &arg) const
535 {
536     if (is_a_Number(*arg)) {
537         return false;
538     }
539     if (is_a<Constant>(*arg)) {
540         return false;
541     }
542     if (is_a<Floor>(*arg)) {
543         return false;
544     }
545     if (is_a<Ceiling>(*arg)) {
546         return false;
547     }
548     if (is_a<Truncate>(*arg)) {
549         return false;
550     }
551     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
552         return false;
553     }
554     if (is_a<Add>(*arg)) {
555         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
556         if (neq(*zero, *s) and is_a<Integer>(*s)) {
557             return false;
558         }
559     }
560     return true;
561 }
562 
create(const RCP<const Basic> & arg) const563 RCP<const Basic> Floor::create(const RCP<const Basic> &arg) const
564 {
565     return floor(arg);
566 }
567 
floor(const RCP<const Basic> & arg)568 RCP<const Basic> floor(const RCP<const Basic> &arg)
569 {
570     if (is_a_Number(*arg)) {
571         if (down_cast<const Number &>(*arg).is_exact()) {
572             if (is_a<Rational>(*arg)) {
573                 const Rational &s = down_cast<const Rational &>(*arg);
574                 integer_class quotient;
575                 mp_fdiv_q(quotient, SymEngine::get_num(s.as_rational_class()),
576                           SymEngine::get_den(s.as_rational_class()));
577                 return integer(std::move(quotient));
578             }
579             return arg;
580         }
581         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
582         return _arg->get_eval().floor(*_arg);
583     }
584     if (is_a<Constant>(*arg)) {
585         if (eq(*arg, *pi)) {
586             return integer(3);
587         }
588         if (eq(*arg, *E)) {
589             return integer(2);
590         }
591         if (eq(*arg, *GoldenRatio)) {
592             return integer(1);
593         }
594         if (eq(*arg, *Catalan) or eq(*arg, *EulerGamma)) {
595             return integer(0);
596         }
597     }
598     if (is_a<Floor>(*arg)) {
599         return arg;
600     }
601     if (is_a<Ceiling>(*arg)) {
602         return arg;
603     }
604     if (is_a<Truncate>(*arg)) {
605         return arg;
606     }
607     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
608         throw SymEngineException(
609             "Boolean objects not allowed in this context.");
610     }
611     if (is_a<Add>(*arg)) {
612         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
613         umap_basic_num d = down_cast<const Add &>(*arg).get_dict();
614         if (is_a<Integer>(*s)
615             and not down_cast<const Integer &>(*s).is_zero()) {
616             return add(s, floor(Add::from_dict(zero, std::move(d))));
617         }
618     }
619     return make_rcp<const Floor>(arg);
620 }
621 
Ceiling(const RCP<const Basic> & arg)622 Ceiling::Ceiling(const RCP<const Basic> &arg) : OneArgFunction(arg)
623 {
624     SYMENGINE_ASSIGN_TYPEID()
625     SYMENGINE_ASSERT(is_canonical(arg))
626 }
627 
is_canonical(const RCP<const Basic> & arg) const628 bool Ceiling::is_canonical(const RCP<const Basic> &arg) const
629 {
630     if (is_a_Number(*arg)) {
631         return false;
632     }
633     if (is_a<Constant>(*arg)) {
634         return false;
635     }
636     if (is_a<Floor>(*arg)) {
637         return false;
638     }
639     if (is_a<Ceiling>(*arg)) {
640         return false;
641     }
642     if (is_a<Truncate>(*arg)) {
643         return false;
644     }
645     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
646         return false;
647     }
648     if (is_a<Add>(*arg)) {
649         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
650         if (neq(*zero, *s) and is_a<Integer>(*s)) {
651             return false;
652         }
653     }
654     return true;
655 }
656 
create(const RCP<const Basic> & arg) const657 RCP<const Basic> Ceiling::create(const RCP<const Basic> &arg) const
658 {
659     return ceiling(arg);
660 }
661 
ceiling(const RCP<const Basic> & arg)662 RCP<const Basic> ceiling(const RCP<const Basic> &arg)
663 {
664     if (is_a_Number(*arg)) {
665         if (down_cast<const Number &>(*arg).is_exact()) {
666             if (is_a<Rational>(*arg)) {
667                 const Rational &s = down_cast<const Rational &>(*arg);
668                 integer_class quotient;
669                 mp_cdiv_q(quotient, SymEngine::get_num(s.as_rational_class()),
670                           SymEngine::get_den(s.as_rational_class()));
671                 return integer(std::move(quotient));
672             }
673             return arg;
674         }
675         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
676         return _arg->get_eval().ceiling(*_arg);
677     }
678     if (is_a<Constant>(*arg)) {
679         if (eq(*arg, *pi)) {
680             return integer(4);
681         }
682         if (eq(*arg, *E)) {
683             return integer(3);
684         }
685         if (eq(*arg, *GoldenRatio)) {
686             return integer(2);
687         }
688         if (eq(*arg, *Catalan) or eq(*arg, *EulerGamma)) {
689             return integer(1);
690         }
691     }
692     if (is_a<Floor>(*arg)) {
693         return arg;
694     }
695     if (is_a<Ceiling>(*arg)) {
696         return arg;
697     }
698     if (is_a<Truncate>(*arg)) {
699         return arg;
700     }
701     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
702         throw SymEngineException(
703             "Boolean objects not allowed in this context.");
704     }
705     if (is_a<Add>(*arg)) {
706         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
707         umap_basic_num d = down_cast<const Add &>(*arg).get_dict();
708         if (is_a<Integer>(*s)) {
709             return add(
710                 s, make_rcp<const Ceiling>(Add::from_dict(zero, std::move(d))));
711         }
712     }
713     return make_rcp<const Ceiling>(arg);
714 }
715 
Truncate(const RCP<const Basic> & arg)716 Truncate::Truncate(const RCP<const Basic> &arg) : OneArgFunction(arg)
717 {
718     SYMENGINE_ASSIGN_TYPEID()
719     SYMENGINE_ASSERT(is_canonical(arg))
720 }
721 
is_canonical(const RCP<const Basic> & arg) const722 bool Truncate::is_canonical(const RCP<const Basic> &arg) const
723 {
724     if (is_a_Number(*arg)) {
725         return false;
726     }
727     if (is_a<Constant>(*arg)) {
728         return false;
729     }
730     if (is_a<Floor>(*arg)) {
731         return false;
732     }
733     if (is_a<Ceiling>(*arg)) {
734         return false;
735     }
736     if (is_a<Truncate>(*arg)) {
737         return false;
738     }
739     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
740         return false;
741     }
742     if (is_a<Add>(*arg)) {
743         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
744         if (neq(*zero, *s) and is_a<Integer>(*s)) {
745             return false;
746         }
747     }
748     return true;
749 }
750 
create(const RCP<const Basic> & arg) const751 RCP<const Basic> Truncate::create(const RCP<const Basic> &arg) const
752 {
753     return truncate(arg);
754 }
755 
truncate(const RCP<const Basic> & arg)756 RCP<const Basic> truncate(const RCP<const Basic> &arg)
757 {
758     if (is_a_Number(*arg)) {
759         if (down_cast<const Number &>(*arg).is_exact()) {
760             if (is_a<Rational>(*arg)) {
761                 const Rational &s = down_cast<const Rational &>(*arg);
762                 integer_class quotient;
763                 mp_tdiv_q(quotient, SymEngine::get_num(s.as_rational_class()),
764                           SymEngine::get_den(s.as_rational_class()));
765                 return integer(std::move(quotient));
766             }
767             return arg;
768         }
769         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
770         return _arg->get_eval().truncate(*_arg);
771     }
772     if (is_a<Constant>(*arg)) {
773         if (eq(*arg, *pi)) {
774             return integer(3);
775         }
776         if (eq(*arg, *E)) {
777             return integer(2);
778         }
779         if (eq(*arg, *GoldenRatio)) {
780             return integer(1);
781         }
782         if (eq(*arg, *Catalan) or eq(*arg, *EulerGamma)) {
783             return integer(0);
784         }
785     }
786     if (is_a<Floor>(*arg)) {
787         return arg;
788     }
789     if (is_a<Ceiling>(*arg)) {
790         return arg;
791     }
792     if (is_a<Truncate>(*arg)) {
793         return arg;
794     }
795     if (is_a<BooleanAtom>(*arg) or is_a_Relational(*arg)) {
796         throw SymEngineException(
797             "Boolean objects not allowed in this context.");
798     }
799     if (is_a<Add>(*arg)) {
800         RCP<const Number> s = down_cast<const Add &>(*arg).get_coef();
801         umap_basic_num d = down_cast<const Add &>(*arg).get_dict();
802         if (is_a<Integer>(*s)) {
803             return add(s, make_rcp<const Truncate>(
804                               Add::from_dict(zero, std::move(d))));
805         }
806     }
807     return make_rcp<const Truncate>(arg);
808 }
809 
Sin(const RCP<const Basic> & arg)810 Sin::Sin(const RCP<const Basic> &arg) : TrigFunction(arg)
811 {
812     SYMENGINE_ASSIGN_TYPEID()
813     SYMENGINE_ASSERT(is_canonical(arg))
814 }
815 
is_canonical(const RCP<const Basic> & arg) const816 bool Sin::is_canonical(const RCP<const Basic> &arg) const
817 {
818     // e.g. sin(0)
819     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
820         return false;
821     // e.g sin(7*pi/2+y)
822     if (trig_has_basic_shift(arg)) {
823         return false;
824     }
825     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
826         return false;
827     }
828     return true;
829 }
830 
sin(const RCP<const Basic> & arg)831 RCP<const Basic> sin(const RCP<const Basic> &arg)
832 {
833     if (eq(*arg, *zero))
834         return zero;
835     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
836         return down_cast<const Number &>(*arg).get_eval().sin(*arg);
837     }
838 
839     if (is_a<ASin>(*arg)) {
840         return down_cast<const ASin &>(*arg).get_arg();
841     } else if (is_a<ACsc>(*arg)) {
842         return div(one, down_cast<const ACsc &>(*arg).get_arg());
843     }
844 
845     RCP<const Basic> ret_arg;
846     int index, sign;
847     bool conjugate = trig_simplify(arg, 2, true, false,           // input
848                                    outArg(ret_arg), index, sign); // output
849 
850     if (conjugate) {
851         // cos has to be returned
852         if (sign == 1) {
853             return cos(ret_arg);
854         } else {
855             return mul(minus_one, cos(ret_arg));
856         }
857     } else {
858         if (eq(*ret_arg, *zero)) {
859             return mul(integer(sign), sin_table[index]);
860         } else {
861             // If ret_arg is the same as arg, a `Sin` instance is returned
862             // Or else `sin` is called again.
863             if (sign == 1) {
864                 if (neq(*ret_arg, *arg)) {
865                     return sin(ret_arg);
866                 } else {
867                     return make_rcp<const Sin>(arg);
868                 }
869             } else {
870                 return mul(minus_one, sin(ret_arg));
871             }
872         }
873     }
874 }
875 
876 /* ---------------------------- */
877 
Cos(const RCP<const Basic> & arg)878 Cos::Cos(const RCP<const Basic> &arg) : TrigFunction(arg)
879 {
880     SYMENGINE_ASSIGN_TYPEID()
881     SYMENGINE_ASSERT(is_canonical(arg))
882 }
883 
is_canonical(const RCP<const Basic> & arg) const884 bool Cos::is_canonical(const RCP<const Basic> &arg) const
885 {
886     // e.g. cos(0)
887     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
888         return false;
889     // e.g cos(k*pi/2)
890     if (trig_has_basic_shift(arg)) {
891         return false;
892     }
893     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
894         return false;
895     }
896     return true;
897 }
898 
cos(const RCP<const Basic> & arg)899 RCP<const Basic> cos(const RCP<const Basic> &arg)
900 {
901     if (eq(*arg, *zero))
902         return one;
903     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
904         return down_cast<const Number &>(*arg).get_eval().cos(*arg);
905     }
906 
907     if (is_a<ACos>(*arg)) {
908         return down_cast<const ACos &>(*arg).get_arg();
909     } else if (is_a<ASec>(*arg)) {
910         return div(one, down_cast<const ASec &>(*arg).get_arg());
911     }
912 
913     RCP<const Basic> ret_arg;
914     int index, sign;
915     bool conjugate = trig_simplify(arg, 2, false, true,           // input
916                                    outArg(ret_arg), index, sign); // output
917 
918     if (conjugate) {
919         // sin has to be returned
920         if (sign == 1) {
921             return sin(ret_arg);
922         } else {
923             return mul(minus_one, sin(ret_arg));
924         }
925     } else {
926         if (eq(*ret_arg, *zero)) {
927             return mul(integer(sign), sin_table[(index + 6) % 24]);
928         } else {
929             if (sign == 1) {
930                 if (neq(*ret_arg, *arg)) {
931                     return cos(ret_arg);
932                 } else {
933                     return make_rcp<const Cos>(ret_arg);
934                 }
935             } else {
936                 return mul(minus_one, cos(ret_arg));
937             }
938         }
939     }
940 }
941 
942 /* ---------------------------- */
943 
Tan(const RCP<const Basic> & arg)944 Tan::Tan(const RCP<const Basic> &arg) : TrigFunction(arg)
945 {
946     SYMENGINE_ASSIGN_TYPEID()
947     SYMENGINE_ASSERT(is_canonical(arg))
948 }
949 
is_canonical(const RCP<const Basic> & arg) const950 bool Tan::is_canonical(const RCP<const Basic> &arg) const
951 {
952     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
953         return false;
954     // e.g tan(k*pi/2)
955     if (trig_has_basic_shift(arg)) {
956         return false;
957     }
958     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
959         return false;
960     }
961     return true;
962 }
963 
tan(const RCP<const Basic> & arg)964 RCP<const Basic> tan(const RCP<const Basic> &arg)
965 {
966     if (eq(*arg, *zero))
967         return zero;
968     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
969         return down_cast<const Number &>(*arg).get_eval().tan(*arg);
970     }
971 
972     if (is_a<ATan>(*arg)) {
973         return down_cast<const ATan &>(*arg).get_arg();
974     } else if (is_a<ACot>(*arg)) {
975         return div(one, down_cast<const ACot &>(*arg).get_arg());
976     }
977 
978     RCP<const Basic> ret_arg;
979     int index, sign;
980     bool conjugate = trig_simplify(arg, 1, true, true,            // input
981                                    outArg(ret_arg), index, sign); // output
982 
983     if (conjugate) {
984         // cot has to be returned
985         if (sign == 1) {
986             return cot(ret_arg);
987         } else {
988             return mul(minus_one, cot(ret_arg));
989         }
990     } else {
991         if (eq(*ret_arg, *zero)) {
992             return mul(integer(sign),
993                        div(sin_table[index], sin_table[(index + 6) % 24]));
994         } else {
995             if (sign == 1) {
996                 if (neq(*ret_arg, *arg)) {
997                     return tan(ret_arg);
998                 } else {
999                     return make_rcp<const Tan>(ret_arg);
1000                 }
1001             } else {
1002                 return mul(minus_one, tan(ret_arg));
1003             }
1004         }
1005     }
1006 }
1007 
1008 /* ---------------------------- */
1009 
Cot(const RCP<const Basic> & arg)1010 Cot::Cot(const RCP<const Basic> &arg) : TrigFunction(arg)
1011 {
1012     SYMENGINE_ASSIGN_TYPEID()
1013     SYMENGINE_ASSERT(is_canonical(arg))
1014 }
1015 
is_canonical(const RCP<const Basic> & arg) const1016 bool Cot::is_canonical(const RCP<const Basic> &arg) const
1017 {
1018     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
1019         return false;
1020     // e.g cot(k*pi/2)
1021     if (trig_has_basic_shift(arg)) {
1022         return false;
1023     }
1024     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1025         return false;
1026     }
1027     return true;
1028 }
1029 
cot(const RCP<const Basic> & arg)1030 RCP<const Basic> cot(const RCP<const Basic> &arg)
1031 {
1032     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1033         return down_cast<const Number &>(*arg).get_eval().cot(*arg);
1034     }
1035 
1036     if (is_a<ACot>(*arg)) {
1037         return down_cast<const ACot &>(*arg).get_arg();
1038     } else if (is_a<ATan>(*arg)) {
1039         return div(one, down_cast<const ATan &>(*arg).get_arg());
1040     }
1041 
1042     RCP<const Basic> ret_arg;
1043     int index, sign;
1044     bool conjugate = trig_simplify(arg, 1, true, true,            // input
1045                                    outArg(ret_arg), index, sign); // output
1046 
1047     if (conjugate) {
1048         // tan has to be returned
1049         if (sign == 1) {
1050             return tan(ret_arg);
1051         } else {
1052             return mul(minus_one, tan(ret_arg));
1053         }
1054     } else {
1055         if (eq(*ret_arg, *zero)) {
1056             return mul(integer(sign),
1057                        div(sin_table[(index + 6) % 24], sin_table[index]));
1058         } else {
1059             if (sign == 1) {
1060                 if (neq(*ret_arg, *arg)) {
1061                     return cot(ret_arg);
1062                 } else {
1063                     return make_rcp<const Cot>(ret_arg);
1064                 }
1065             } else {
1066                 return mul(minus_one, cot(ret_arg));
1067             }
1068         }
1069     }
1070 }
1071 
1072 /* ---------------------------- */
1073 
Csc(const RCP<const Basic> & arg)1074 Csc::Csc(const RCP<const Basic> &arg) : TrigFunction(arg)
1075 {
1076     SYMENGINE_ASSIGN_TYPEID()
1077     SYMENGINE_ASSERT(is_canonical(arg))
1078 }
1079 
is_canonical(const RCP<const Basic> & arg) const1080 bool Csc::is_canonical(const RCP<const Basic> &arg) const
1081 {
1082     // e.g. Csc(0)
1083     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
1084         return false;
1085     // e.g csc(k*pi/2)
1086     if (trig_has_basic_shift(arg)) {
1087         return false;
1088     }
1089     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1090         return false;
1091     }
1092     return true;
1093 }
1094 
csc(const RCP<const Basic> & arg)1095 RCP<const Basic> csc(const RCP<const Basic> &arg)
1096 {
1097     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1098         return down_cast<const Number &>(*arg).get_eval().csc(*arg);
1099     }
1100 
1101     if (is_a<ACsc>(*arg)) {
1102         return down_cast<const ACsc &>(*arg).get_arg();
1103     } else if (is_a<ASin>(*arg)) {
1104         return div(one, down_cast<const ASin &>(*arg).get_arg());
1105     }
1106 
1107     RCP<const Basic> ret_arg;
1108     int index, sign;
1109     bool conjugate = trig_simplify(arg, 2, true, false,           // input
1110                                    outArg(ret_arg), index, sign); // output
1111 
1112     if (conjugate) {
1113         // cos has to be returned
1114         if (sign == 1) {
1115             return sec(ret_arg);
1116         } else {
1117             return mul(minus_one, sec(ret_arg));
1118         }
1119     } else {
1120         if (eq(*ret_arg, *zero)) {
1121             return mul(integer(sign), div(one, sin_table[index]));
1122         } else {
1123             if (sign == 1) {
1124                 if (neq(*ret_arg, *arg)) {
1125                     return csc(ret_arg);
1126                 } else {
1127                     return make_rcp<const Csc>(ret_arg);
1128                 }
1129             } else {
1130                 return mul(minus_one, csc(ret_arg));
1131             }
1132         }
1133     }
1134 }
1135 
1136 /* ---------------------------- */
1137 
Sec(const RCP<const Basic> & arg)1138 Sec::Sec(const RCP<const Basic> &arg) : TrigFunction(arg)
1139 {
1140     SYMENGINE_ASSIGN_TYPEID()
1141     SYMENGINE_ASSERT(is_canonical(arg))
1142 }
1143 
is_canonical(const RCP<const Basic> & arg) const1144 bool Sec::is_canonical(const RCP<const Basic> &arg) const
1145 {
1146     // e.g. Sec(0)
1147     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
1148         return false;
1149     // e.g sec(k*pi/2)
1150     if (trig_has_basic_shift(arg)) {
1151         return false;
1152     }
1153     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1154         return false;
1155     }
1156     return true;
1157 }
1158 
sec(const RCP<const Basic> & arg)1159 RCP<const Basic> sec(const RCP<const Basic> &arg)
1160 {
1161     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1162         return down_cast<const Number &>(*arg).get_eval().sec(*arg);
1163     }
1164 
1165     if (is_a<ASec>(*arg)) {
1166         return down_cast<const ASec &>(*arg).get_arg();
1167     } else if (is_a<ACos>(*arg)) {
1168         return div(one, down_cast<const ACos &>(*arg).get_arg());
1169     }
1170 
1171     RCP<const Basic> ret_arg;
1172     int index, sign;
1173     bool conjugate = trig_simplify(arg, 2, false, true,           // input
1174                                    outArg(ret_arg), index, sign); // output
1175 
1176     if (conjugate) {
1177         // csc has to be returned
1178         if (sign == 1) {
1179             return csc(ret_arg);
1180         } else {
1181             return mul(minus_one, csc(ret_arg));
1182         }
1183     } else {
1184         if (eq(*ret_arg, *zero)) {
1185             return mul(integer(sign), div(one, sin_table[(index + 6) % 24]));
1186         } else {
1187             if (sign == 1) {
1188                 if (neq(*ret_arg, *arg)) {
1189                     return sec(ret_arg);
1190                 } else {
1191                     return make_rcp<const Sec>(ret_arg);
1192                 }
1193             } else {
1194                 return mul(minus_one, sec(ret_arg));
1195             }
1196         }
1197     }
1198 }
1199 /* ---------------------------- */
1200 
1201 // simplifies trigonometric functions wherever possible
1202 // currently deals with simplifications of type sin(acos())
trig_to_sqrt(const RCP<const Basic> & arg)1203 RCP<const Basic> trig_to_sqrt(const RCP<const Basic> &arg)
1204 {
1205     RCP<const Basic> i_arg;
1206 
1207     if (is_a<Sin>(*arg)) {
1208         if (is_a<ACos>(*arg->get_args()[0])) {
1209             i_arg = down_cast<const ACos &>(*(arg->get_args()[0])).get_arg();
1210             return sqrt(sub(one, pow(i_arg, i2)));
1211         } else if (is_a<ATan>(*arg->get_args()[0])) {
1212             i_arg = down_cast<const ATan &>(*(arg->get_args()[0])).get_arg();
1213             return div(i_arg, sqrt(add(one, pow(i_arg, i2))));
1214         } else if (is_a<ASec>(*arg->get_args()[0])) {
1215             i_arg = down_cast<const ASec &>(*(arg->get_args()[0])).get_arg();
1216             return sqrt(sub(one, pow(i_arg, im2)));
1217         } else if (is_a<ACot>(*arg->get_args()[0])) {
1218             i_arg = down_cast<const ACot &>(*(arg->get_args()[0])).get_arg();
1219             return div(one, mul(i_arg, sqrt(add(one, pow(i_arg, im2)))));
1220         }
1221     } else if (is_a<Cos>(*arg)) {
1222         if (is_a<ASin>(*arg->get_args()[0])) {
1223             i_arg = down_cast<const ASin &>(*(arg->get_args()[0])).get_arg();
1224             return sqrt(sub(one, pow(i_arg, i2)));
1225         } else if (is_a<ATan>(*arg->get_args()[0])) {
1226             i_arg = down_cast<const ATan &>(*(arg->get_args()[0])).get_arg();
1227             return div(one, sqrt(add(one, pow(i_arg, i2))));
1228         } else if (is_a<ACsc>(*arg->get_args()[0])) {
1229             i_arg = down_cast<const ACsc &>(*(arg->get_args()[0])).get_arg();
1230             return sqrt(sub(one, pow(i_arg, im2)));
1231         } else if (is_a<ACot>(*arg->get_args()[0])) {
1232             i_arg = down_cast<const ACot &>(*(arg->get_args()[0])).get_arg();
1233             return div(one, sqrt(add(one, pow(i_arg, im2))));
1234         }
1235     } else if (is_a<Tan>(*arg)) {
1236         if (is_a<ASin>(*arg->get_args()[0])) {
1237             i_arg = down_cast<const ASin &>(*(arg->get_args()[0])).get_arg();
1238             return div(i_arg, sqrt(sub(one, pow(i_arg, i2))));
1239         } else if (is_a<ACos>(*arg->get_args()[0])) {
1240             i_arg = down_cast<const ACos &>(*(arg->get_args()[0])).get_arg();
1241             return div(sqrt(sub(one, pow(i_arg, i2))), i_arg);
1242         } else if (is_a<ACsc>(*arg->get_args()[0])) {
1243             i_arg = down_cast<const ACsc &>(*(arg->get_args()[0])).get_arg();
1244             return div(one, mul(i_arg, sqrt(sub(one, pow(i_arg, im2)))));
1245         } else if (is_a<ASec>(*arg->get_args()[0])) {
1246             i_arg = down_cast<const ASec &>(*(arg->get_args()[0])).get_arg();
1247             return mul(i_arg, sqrt(sub(one, pow(i_arg, im2))));
1248         }
1249     } else if (is_a<Csc>(*arg)) {
1250         if (is_a<ACos>(*arg->get_args()[0])) {
1251             i_arg = down_cast<const ACos &>(*(arg->get_args()[0])).get_arg();
1252             return div(one, sqrt(sub(one, pow(i_arg, i2))));
1253         } else if (is_a<ATan>(*arg->get_args()[0])) {
1254             i_arg = down_cast<const ATan &>(*(arg->get_args()[0])).get_arg();
1255             return div(sqrt(add(one, pow(i_arg, i2))), i_arg);
1256         } else if (is_a<ASec>(*arg->get_args()[0])) {
1257             i_arg = down_cast<const ASec &>(*(arg->get_args()[0])).get_arg();
1258             return div(one, sqrt(sub(one, pow(i_arg, im2))));
1259         } else if (is_a<ACot>(*arg->get_args()[0])) {
1260             i_arg = down_cast<const ACot &>(*(arg->get_args()[0])).get_arg();
1261             return mul(i_arg, sqrt(add(one, pow(i_arg, im2))));
1262         }
1263     } else if (is_a<Sec>(*arg)) {
1264         if (is_a<ASin>(*arg->get_args()[0])) {
1265             i_arg = down_cast<const ASin &>(*(arg->get_args()[0])).get_arg();
1266             return div(one, sqrt(sub(one, pow(i_arg, i2))));
1267         } else if (is_a<ATan>(*arg->get_args()[0])) {
1268             i_arg = down_cast<const ATan &>(*(arg->get_args()[0])).get_arg();
1269             return sqrt(add(one, pow(i_arg, i2)));
1270         } else if (is_a<ACsc>(*arg->get_args()[0])) {
1271             i_arg = down_cast<const ACsc &>(*(arg->get_args()[0])).get_arg();
1272             return div(one, sqrt(sub(one, pow(i_arg, im2))));
1273         } else if (is_a<ACot>(*arg->get_args()[0])) {
1274             i_arg = down_cast<const ACot &>(*(arg->get_args()[0])).get_arg();
1275             return sqrt(add(one, pow(i_arg, im2)));
1276         }
1277     } else if (is_a<Cot>(*arg)) {
1278         if (is_a<ASin>(*arg->get_args()[0])) {
1279             i_arg = down_cast<const ASin &>(*(arg->get_args()[0])).get_arg();
1280             return div(sqrt(sub(one, pow(i_arg, i2))), i_arg);
1281         } else if (is_a<ACos>(*arg->get_args()[0])) {
1282             i_arg = down_cast<const ACos &>(*(arg->get_args()[0])).get_arg();
1283             return div(i_arg, sqrt(sub(one, pow(i_arg, i2))));
1284         } else if (is_a<ACsc>(*arg->get_args()[0])) {
1285             i_arg = down_cast<const ACsc &>(*(arg->get_args()[0])).get_arg();
1286             return mul(i_arg, sqrt(sub(one, pow(i_arg, im2))));
1287         } else if (is_a<ASec>(*arg->get_args()[0])) {
1288             i_arg = down_cast<const ASec &>(*(arg->get_args()[0])).get_arg();
1289             return div(one, mul(i_arg, sqrt(sub(one, pow(i_arg, im2)))));
1290         }
1291     }
1292 
1293     return arg;
1294 }
1295 
1296 /* ---------------------------- */
ASin(const RCP<const Basic> & arg)1297 ASin::ASin(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1298 {
1299     SYMENGINE_ASSIGN_TYPEID()
1300     SYMENGINE_ASSERT(is_canonical(arg))
1301 }
1302 
is_canonical(const RCP<const Basic> & arg) const1303 bool ASin::is_canonical(const RCP<const Basic> &arg) const
1304 {
1305     if (eq(*arg, *zero) or eq(*arg, *one) or eq(*arg, *minus_one))
1306         return false;
1307     RCP<const Basic> index;
1308     if (inverse_lookup(inverse_cst, get_arg(), outArg(index))) {
1309         return false;
1310     }
1311     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1312         return false;
1313     }
1314     return true;
1315 }
1316 
asin(const RCP<const Basic> & arg)1317 RCP<const Basic> asin(const RCP<const Basic> &arg)
1318 {
1319     if (eq(*arg, *zero))
1320         return zero;
1321     else if (eq(*arg, *one))
1322         return div(pi, i2);
1323     else if (eq(*arg, *minus_one))
1324         return mul(minus_one, div(pi, i2));
1325     else if (is_a_Number(*arg)
1326              and not down_cast<const Number &>(*arg).is_exact()) {
1327         return down_cast<const Number &>(*arg).get_eval().asin(*arg);
1328     }
1329 
1330     RCP<const Basic> index;
1331     bool b = inverse_lookup(inverse_cst, arg, outArg(index));
1332     if (b) {
1333         return div(pi, index);
1334     } else {
1335         return make_rcp<const ASin>(arg);
1336     }
1337 }
1338 
ACos(const RCP<const Basic> & arg)1339 ACos::ACos(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1340 {
1341     SYMENGINE_ASSIGN_TYPEID()
1342     SYMENGINE_ASSERT(is_canonical(arg))
1343 }
1344 
is_canonical(const RCP<const Basic> & arg) const1345 bool ACos::is_canonical(const RCP<const Basic> &arg) const
1346 {
1347     if (eq(*arg, *zero) or eq(*arg, *one) or eq(*arg, *minus_one))
1348         return false;
1349     RCP<const Basic> index;
1350     if (inverse_lookup(inverse_cst, get_arg(), outArg(index))) {
1351         return false;
1352     }
1353     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1354         return false;
1355     }
1356     return true;
1357 }
1358 
acos(const RCP<const Basic> & arg)1359 RCP<const Basic> acos(const RCP<const Basic> &arg)
1360 {
1361     if (eq(*arg, *zero))
1362         return div(pi, i2);
1363     else if (eq(*arg, *one))
1364         return zero;
1365     else if (eq(*arg, *minus_one))
1366         return pi;
1367     else if (is_a_Number(*arg)
1368              and not down_cast<const Number &>(*arg).is_exact()) {
1369         return down_cast<const Number &>(*arg).get_eval().acos(*arg);
1370     }
1371 
1372     RCP<const Basic> index;
1373     bool b = inverse_lookup(inverse_cst, arg, outArg(index));
1374     if (b) {
1375         return sub(div(pi, i2), div(pi, index));
1376     } else {
1377         return make_rcp<const ACos>(arg);
1378     }
1379 }
1380 
ASec(const RCP<const Basic> & arg)1381 ASec::ASec(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1382 {
1383     SYMENGINE_ASSIGN_TYPEID()
1384     SYMENGINE_ASSERT(is_canonical(arg))
1385 }
1386 
is_canonical(const RCP<const Basic> & arg) const1387 bool ASec::is_canonical(const RCP<const Basic> &arg) const
1388 {
1389     if (eq(*arg, *one) or eq(*arg, *minus_one))
1390         return false;
1391     RCP<const Basic> index;
1392     if (inverse_lookup(inverse_cst, div(one, get_arg()), outArg(index))) {
1393         return false;
1394     } else if (is_a_Number(*arg)
1395                and not down_cast<const Number &>(*arg).is_exact()) {
1396         return false;
1397     }
1398     return true;
1399 }
1400 
asec(const RCP<const Basic> & arg)1401 RCP<const Basic> asec(const RCP<const Basic> &arg)
1402 {
1403     if (eq(*arg, *one))
1404         return zero;
1405     else if (eq(*arg, *minus_one))
1406         return pi;
1407     else if (is_a_Number(*arg)
1408              and not down_cast<const Number &>(*arg).is_exact()) {
1409         return down_cast<const Number &>(*arg).get_eval().asec(*arg);
1410     }
1411 
1412     RCP<const Basic> index;
1413     bool b = inverse_lookup(inverse_cst, div(one, arg), outArg(index));
1414     if (b) {
1415         return sub(div(pi, i2), div(pi, index));
1416     } else {
1417         return make_rcp<const ASec>(arg);
1418     }
1419 }
1420 
ACsc(const RCP<const Basic> & arg)1421 ACsc::ACsc(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1422 {
1423     SYMENGINE_ASSIGN_TYPEID()
1424     SYMENGINE_ASSERT(is_canonical(arg))
1425 }
1426 
is_canonical(const RCP<const Basic> & arg) const1427 bool ACsc::is_canonical(const RCP<const Basic> &arg) const
1428 {
1429     if (eq(*arg, *one) or eq(*arg, *minus_one))
1430         return false;
1431     RCP<const Basic> index;
1432     if (inverse_lookup(inverse_cst, div(one, arg), outArg(index))) {
1433         return false;
1434     }
1435     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1436         return false;
1437     }
1438     return true;
1439 }
1440 
acsc(const RCP<const Basic> & arg)1441 RCP<const Basic> acsc(const RCP<const Basic> &arg)
1442 {
1443     if (eq(*arg, *one))
1444         return div(pi, i2);
1445     else if (eq(*arg, *minus_one))
1446         return div(pi, im2);
1447     else if (is_a_Number(*arg)
1448              and not down_cast<const Number &>(*arg).is_exact()) {
1449         return down_cast<const Number &>(*arg).get_eval().acsc(*arg);
1450     }
1451 
1452     RCP<const Basic> index;
1453     bool b = inverse_lookup(inverse_cst, div(one, arg), outArg(index));
1454     if (b) {
1455         return div(pi, index);
1456     } else {
1457         return make_rcp<const ACsc>(arg);
1458     }
1459 }
1460 
ATan(const RCP<const Basic> & arg)1461 ATan::ATan(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1462 {
1463     SYMENGINE_ASSIGN_TYPEID()
1464     SYMENGINE_ASSERT(is_canonical(arg))
1465 }
1466 
is_canonical(const RCP<const Basic> & arg) const1467 bool ATan::is_canonical(const RCP<const Basic> &arg) const
1468 {
1469     if (eq(*arg, *zero) or eq(*arg, *one) or eq(*arg, *minus_one))
1470         return false;
1471     RCP<const Basic> index;
1472     if (inverse_lookup(inverse_tct, get_arg(), outArg(index))) {
1473         return false;
1474     }
1475     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1476         return false;
1477     }
1478     return true;
1479 }
1480 
atan(const RCP<const Basic> & arg)1481 RCP<const Basic> atan(const RCP<const Basic> &arg)
1482 {
1483     if (eq(*arg, *zero))
1484         return zero;
1485     else if (eq(*arg, *one))
1486         return div(pi, mul(i2, i2));
1487     else if (eq(*arg, *minus_one))
1488         return mul(minus_one, div(pi, mul(i2, i2)));
1489     else if (is_a_Number(*arg)
1490              and not down_cast<const Number &>(*arg).is_exact()) {
1491         return down_cast<const Number &>(*arg).get_eval().atan(*arg);
1492     }
1493 
1494     RCP<const Basic> index;
1495     bool b = inverse_lookup(inverse_tct, arg, outArg(index));
1496     if (b) {
1497         return div(pi, index);
1498     } else {
1499         return make_rcp<const ATan>(arg);
1500     }
1501 }
1502 
ACot(const RCP<const Basic> & arg)1503 ACot::ACot(const RCP<const Basic> &arg) : InverseTrigFunction(arg)
1504 {
1505     SYMENGINE_ASSIGN_TYPEID()
1506     SYMENGINE_ASSERT(is_canonical(arg))
1507 }
1508 
is_canonical(const RCP<const Basic> & arg) const1509 bool ACot::is_canonical(const RCP<const Basic> &arg) const
1510 {
1511     if (eq(*arg, *zero) or eq(*arg, *one) or eq(*arg, *minus_one))
1512         return false;
1513     RCP<const Basic> index;
1514     if (inverse_lookup(inverse_tct, arg, outArg(index))) {
1515         return false;
1516     }
1517     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
1518         return false;
1519     }
1520     return true;
1521 }
1522 
acot(const RCP<const Basic> & arg)1523 RCP<const Basic> acot(const RCP<const Basic> &arg)
1524 {
1525     if (eq(*arg, *zero))
1526         return div(pi, i2);
1527     else if (eq(*arg, *one))
1528         return div(pi, mul(i2, i2));
1529     else if (eq(*arg, *minus_one))
1530         return mul(i3, div(pi, mul(i2, i2)));
1531     else if (is_a_Number(*arg)
1532              and not down_cast<const Number &>(*arg).is_exact()) {
1533         return down_cast<const Number &>(*arg).get_eval().acot(*arg);
1534     }
1535 
1536     RCP<const Basic> index;
1537     bool b = inverse_lookup(inverse_tct, arg, outArg(index));
1538     if (b) {
1539         return sub(div(pi, i2), div(pi, index));
1540     } else {
1541         return make_rcp<const ACot>(arg);
1542     }
1543 }
1544 
ATan2(const RCP<const Basic> & num,const RCP<const Basic> & den)1545 ATan2::ATan2(const RCP<const Basic> &num, const RCP<const Basic> &den)
1546     : TwoArgFunction(num, den)
1547 {
1548     SYMENGINE_ASSIGN_TYPEID()
1549     SYMENGINE_ASSERT(is_canonical(num, den))
1550 }
1551 
is_canonical(const RCP<const Basic> & num,const RCP<const Basic> & den) const1552 bool ATan2::is_canonical(const RCP<const Basic> &num,
1553                          const RCP<const Basic> &den) const
1554 {
1555     if (eq(*num, *zero) or eq(*num, *den) or eq(*num, *mul(minus_one, den)))
1556         return false;
1557     RCP<const Basic> index;
1558     bool b = inverse_lookup(inverse_tct, div(num, den), outArg(index));
1559     if (b)
1560         return false;
1561     else
1562         return true;
1563 }
1564 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const1565 RCP<const Basic> ATan2::create(const RCP<const Basic> &a,
1566                                const RCP<const Basic> &b) const
1567 {
1568     return atan2(a, b);
1569 }
1570 
atan2(const RCP<const Basic> & num,const RCP<const Basic> & den)1571 RCP<const Basic> atan2(const RCP<const Basic> &num, const RCP<const Basic> &den)
1572 {
1573     if (eq(*num, *zero)) {
1574         if (is_a_Number(*den)) {
1575             RCP<const Number> den_new = rcp_static_cast<const Number>(den);
1576             if (den_new->is_negative())
1577                 return pi;
1578             else if (den_new->is_positive())
1579                 return zero;
1580             else {
1581                 return Nan;
1582             }
1583         }
1584     } else if (eq(*den, *zero)) {
1585         if (is_a_Number(*num)) {
1586             RCP<const Number> num_new = rcp_static_cast<const Number>(num);
1587             if (num_new->is_negative())
1588                 return div(pi, im2);
1589             else
1590                 return div(pi, i2);
1591         }
1592     }
1593     RCP<const Basic> index;
1594     bool b = inverse_lookup(inverse_tct, div(num, den), outArg(index));
1595     if (b) {
1596         // Ideally the answer should depend on the signs of `num` and `den`
1597         // Currently is_positive() and is_negative() is not implemented for
1598         // types other than `Number`
1599         // Hence this will give exact answers in case when num and den are
1600         // numbers in SymEngine sense and when num and den are positive.
1601         // for the remaining cases in which we just return the value from
1602         // the lookup table.
1603         // TODO: update once is_positive() and is_negative() is implemented
1604         // in `Basic`
1605         if (is_a_Number(*den) and is_a_Number(*num)) {
1606             RCP<const Number> den_new = rcp_static_cast<const Number>(den);
1607             RCP<const Number> num_new = rcp_static_cast<const Number>(num);
1608 
1609             if (den_new->is_positive()) {
1610                 return div(pi, index);
1611             } else if (den_new->is_negative()) {
1612                 if (num_new->is_negative()) {
1613                     return sub(div(pi, index), pi);
1614                 } else {
1615                     return add(div(pi, index), pi);
1616                 }
1617             } else {
1618                 return div(pi, index);
1619             }
1620         } else {
1621             return div(pi, index);
1622         }
1623     } else {
1624         return make_rcp<const ATan2>(num, den);
1625     }
1626 }
1627 
1628 /* ---------------------------- */
1629 
create(const RCP<const Basic> & arg) const1630 RCP<const Basic> Sin::create(const RCP<const Basic> &arg) const
1631 {
1632     return sin(arg);
1633 }
1634 
create(const RCP<const Basic> & arg) const1635 RCP<const Basic> Cos::create(const RCP<const Basic> &arg) const
1636 {
1637     return cos(arg);
1638 }
1639 
create(const RCP<const Basic> & arg) const1640 RCP<const Basic> Tan::create(const RCP<const Basic> &arg) const
1641 {
1642     return tan(arg);
1643 }
1644 
create(const RCP<const Basic> & arg) const1645 RCP<const Basic> Cot::create(const RCP<const Basic> &arg) const
1646 {
1647     return cot(arg);
1648 }
1649 
create(const RCP<const Basic> & arg) const1650 RCP<const Basic> Sec::create(const RCP<const Basic> &arg) const
1651 {
1652     return sec(arg);
1653 }
1654 
create(const RCP<const Basic> & arg) const1655 RCP<const Basic> Csc::create(const RCP<const Basic> &arg) const
1656 {
1657     return csc(arg);
1658 }
1659 
create(const RCP<const Basic> & arg) const1660 RCP<const Basic> ASin::create(const RCP<const Basic> &arg) const
1661 {
1662     return asin(arg);
1663 }
1664 
create(const RCP<const Basic> & arg) const1665 RCP<const Basic> ACos::create(const RCP<const Basic> &arg) const
1666 {
1667     return acos(arg);
1668 }
1669 
create(const RCP<const Basic> & arg) const1670 RCP<const Basic> ATan::create(const RCP<const Basic> &arg) const
1671 {
1672     return atan(arg);
1673 }
1674 
create(const RCP<const Basic> & arg) const1675 RCP<const Basic> ACot::create(const RCP<const Basic> &arg) const
1676 {
1677     return acot(arg);
1678 }
1679 
create(const RCP<const Basic> & arg) const1680 RCP<const Basic> ASec::create(const RCP<const Basic> &arg) const
1681 {
1682     return asec(arg);
1683 }
1684 
create(const RCP<const Basic> & arg) const1685 RCP<const Basic> ACsc::create(const RCP<const Basic> &arg) const
1686 {
1687     return acsc(arg);
1688 }
1689 
1690 /* ---------------------------- */
1691 
Log(const RCP<const Basic> & arg)1692 Log::Log(const RCP<const Basic> &arg) : OneArgFunction(arg)
1693 {
1694     SYMENGINE_ASSIGN_TYPEID()
1695     SYMENGINE_ASSERT(is_canonical(arg))
1696 }
1697 
is_canonical(const RCP<const Basic> & arg) const1698 bool Log::is_canonical(const RCP<const Basic> &arg) const
1699 {
1700     //  log(0)
1701     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
1702         return false;
1703     //  log(1)
1704     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_one())
1705         return false;
1706     // log(E)
1707     if (eq(*arg, *E))
1708         return false;
1709 
1710     if (is_a_Number(*arg) and down_cast<const Number &>(*arg).is_negative())
1711         return false;
1712 
1713     // log(Inf) is also handled here.
1714     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact())
1715         return false;
1716 
1717     // log(3I) should be expanded to log(3) + I*pi/2
1718     if (is_a<Complex>(*arg) and down_cast<const Complex &>(*arg).is_re_zero())
1719         return false;
1720     // log(num/den) = log(num) - log(den)
1721     if (is_a<Rational>(*arg))
1722         return false;
1723     return true;
1724 }
1725 
create(const RCP<const Basic> & a) const1726 RCP<const Basic> Log::create(const RCP<const Basic> &a) const
1727 {
1728     return log(a);
1729 }
1730 
log(const RCP<const Basic> & arg)1731 RCP<const Basic> log(const RCP<const Basic> &arg)
1732 {
1733     if (eq(*arg, *zero))
1734         return ComplexInf;
1735     if (eq(*arg, *one))
1736         return zero;
1737     if (eq(*arg, *E))
1738         return one;
1739 
1740     if (is_a_Number(*arg)) {
1741         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
1742         if (not _arg->is_exact()) {
1743             return _arg->get_eval().log(*_arg);
1744         } else if (_arg->is_negative()) {
1745             return add(log(mul(minus_one, _arg)), mul(pi, I));
1746         }
1747     }
1748 
1749     if (is_a<Rational>(*arg)) {
1750         RCP<const Integer> num, den;
1751         get_num_den(down_cast<const Rational &>(*arg), outArg(num),
1752                     outArg(den));
1753         return sub(log(num), log(den));
1754     }
1755 
1756     if (is_a<Complex>(*arg)) {
1757         RCP<const Complex> _arg = rcp_static_cast<const Complex>(arg);
1758         if (_arg->is_re_zero()) {
1759             RCP<const Number> arg_img = _arg->imaginary_part();
1760             if (arg_img->is_negative()) {
1761                 return sub(log(mul(minus_one, arg_img)),
1762                            mul(I, div(pi, integer(2))));
1763             } else if (arg_img->is_zero()) {
1764                 return ComplexInf;
1765             } else if (arg_img->is_positive()) {
1766                 return add(log(arg_img), mul(I, div(pi, integer(2))));
1767             }
1768         }
1769     }
1770 
1771     return make_rcp<const Log>(arg);
1772 }
1773 
log(const RCP<const Basic> & arg,const RCP<const Basic> & base)1774 RCP<const Basic> log(const RCP<const Basic> &arg, const RCP<const Basic> &base)
1775 {
1776     return div(log(arg), log(base));
1777 }
1778 
LambertW(const RCP<const Basic> & arg)1779 LambertW::LambertW(const RCP<const Basic> &arg) : OneArgFunction{arg}
1780 {
1781     SYMENGINE_ASSIGN_TYPEID()
1782     SYMENGINE_ASSERT(is_canonical(arg))
1783 }
1784 
is_canonical(const RCP<const Basic> & arg) const1785 bool LambertW::is_canonical(const RCP<const Basic> &arg) const
1786 {
1787     if (eq(*arg, *zero))
1788         return false;
1789     if (eq(*arg, *E))
1790         return false;
1791     if (eq(*arg, *div(neg(one), E)))
1792         return false;
1793     if (eq(*arg, *div(log(i2), im2)))
1794         return false;
1795     return true;
1796 }
1797 
create(const RCP<const Basic> & arg) const1798 RCP<const Basic> LambertW::create(const RCP<const Basic> &arg) const
1799 {
1800     return lambertw(arg);
1801 }
1802 
lambertw(const RCP<const Basic> & arg)1803 RCP<const Basic> lambertw(const RCP<const Basic> &arg)
1804 {
1805     if (eq(*arg, *zero))
1806         return zero;
1807     if (eq(*arg, *E))
1808         return one;
1809     if (eq(*arg, *div(neg(one), E)))
1810         return minus_one;
1811     if (eq(*arg, *div(log(i2), im2)))
1812         return mul(minus_one, log(i2));
1813     return make_rcp<const LambertW>(arg);
1814 }
1815 
FunctionSymbol(std::string name,const RCP<const Basic> & arg)1816 FunctionSymbol::FunctionSymbol(std::string name, const RCP<const Basic> &arg)
1817     : MultiArgFunction({arg}), name_{name} {SYMENGINE_ASSIGN_TYPEID()
1818                                                 SYMENGINE_ASSERT(
1819                                                     is_canonical(get_vec()))}
1820 
FunctionSymbol(std::string name,const vec_basic & arg)1821       FunctionSymbol::FunctionSymbol(std::string name, const vec_basic &arg)
1822     : MultiArgFunction(arg), name_{name}
1823 {
1824     SYMENGINE_ASSIGN_TYPEID()
1825     SYMENGINE_ASSERT(is_canonical(get_vec()))
1826 }
1827 
is_canonical(const vec_basic & arg) const1828 bool FunctionSymbol::is_canonical(const vec_basic &arg) const
1829 {
1830     return true;
1831 }
1832 
__hash__() const1833 hash_t FunctionSymbol::__hash__() const
1834 {
1835     hash_t seed = SYMENGINE_FUNCTIONSYMBOL;
1836     for (const auto &a : get_vec())
1837         hash_combine<Basic>(seed, *a);
1838     hash_combine<std::string>(seed, name_);
1839     return seed;
1840 }
1841 
__eq__(const Basic & o) const1842 bool FunctionSymbol::__eq__(const Basic &o) const
1843 {
1844     if (is_a<FunctionSymbol>(o)
1845         and name_ == down_cast<const FunctionSymbol &>(o).name_
1846         and unified_eq(get_vec(),
1847                        down_cast<const FunctionSymbol &>(o).get_vec()))
1848         return true;
1849     return false;
1850 }
1851 
compare(const Basic & o) const1852 int FunctionSymbol::compare(const Basic &o) const
1853 {
1854     SYMENGINE_ASSERT(is_a<FunctionSymbol>(o))
1855     const FunctionSymbol &s = down_cast<const FunctionSymbol &>(o);
1856     if (name_ == s.name_)
1857         return unified_compare(get_vec(), s.get_vec());
1858     else
1859         return name_ < s.name_ ? -1 : 1;
1860 }
1861 
create(const vec_basic & x) const1862 RCP<const Basic> FunctionSymbol::create(const vec_basic &x) const
1863 {
1864     return make_rcp<const FunctionSymbol>(name_, x);
1865 }
1866 
function_symbol(std::string name,const vec_basic & arg)1867 RCP<const Basic> function_symbol(std::string name, const vec_basic &arg)
1868 {
1869     return make_rcp<const FunctionSymbol>(name, arg);
1870 }
1871 
function_symbol(std::string name,const RCP<const Basic> & arg)1872 RCP<const Basic> function_symbol(std::string name, const RCP<const Basic> &arg)
1873 {
1874     return make_rcp<const FunctionSymbol>(name, arg);
1875 }
1876 
FunctionWrapper(std::string name,const RCP<const Basic> & arg)1877 FunctionWrapper::FunctionWrapper(std::string name, const RCP<const Basic> &arg)
1878     : FunctionSymbol(name, arg){SYMENGINE_ASSIGN_TYPEID()}
1879 
FunctionWrapper(std::string name,const vec_basic & vec)1880       FunctionWrapper::FunctionWrapper(std::string name, const vec_basic &vec)
1881     : FunctionSymbol(name, vec){SYMENGINE_ASSIGN_TYPEID()}
1882 
1883       /* ---------------------------- */
1884 
Derivative(const RCP<const Basic> & arg,const multiset_basic & x)1885       Derivative::Derivative(const RCP<const Basic> &arg,
1886                              const multiset_basic &x)
1887     : arg_{arg}, x_{x}
1888 {
1889     SYMENGINE_ASSIGN_TYPEID()
1890     SYMENGINE_ASSERT(is_canonical(arg, x))
1891 }
1892 
is_canonical(const RCP<const Basic> & arg,const multiset_basic & x) const1893 bool Derivative::is_canonical(const RCP<const Basic> &arg,
1894                               const multiset_basic &x) const
1895 {
1896     // Check that 'x' are Symbols:
1897     for (const auto &a : x)
1898         if (not is_a<Symbol>(*a))
1899             return false;
1900     if (is_a<FunctionSymbol>(*arg) or is_a<LeviCivita>(*arg)) {
1901         for (auto &p : x) {
1902             RCP<const Symbol> s = rcp_static_cast<const Symbol>(p);
1903             RCP<const MultiArgFunction> f
1904                 = rcp_static_cast<const MultiArgFunction>(arg);
1905             bool found_s = false;
1906             // 's' should be one of the args of the function
1907             // and should not appear anywhere else.
1908             for (const auto &a : f->get_args()) {
1909                 if (eq(*a, *s)) {
1910                     if (found_s) {
1911                         return false;
1912                     } else {
1913                         found_s = true;
1914                     }
1915                 } else if (neq(*a->diff(s), *zero)) {
1916                     return false;
1917                 }
1918             }
1919             if (!found_s) {
1920                 return false;
1921             }
1922         }
1923         return true;
1924     } else if (is_a<Abs>(*arg)) {
1925         return true;
1926     } else if (is_a<FunctionWrapper>(*arg)) {
1927         return true;
1928     } else if (is_a<PolyGamma>(*arg) or is_a<Zeta>(*arg)
1929                or is_a<UpperGamma>(*arg) or is_a<LowerGamma>(*arg)
1930                or is_a<Dirichlet_eta>(*arg)) {
1931         bool found = false;
1932         auto v = arg->get_args();
1933         for (auto &p : x) {
1934             if (has_symbol(*v[0], *rcp_static_cast<const Symbol>(p))) {
1935                 found = true;
1936                 break;
1937             }
1938         }
1939         return found;
1940     } else if (is_a<KroneckerDelta>(*arg)) {
1941         bool found = false;
1942         auto v = arg->get_args();
1943         for (auto &p : x) {
1944             if (has_symbol(*v[0], *rcp_static_cast<const Symbol>(p))
1945                 or has_symbol(*v[1], *rcp_static_cast<const Symbol>(p))) {
1946                 found = true;
1947                 break;
1948             }
1949         }
1950         return found;
1951     }
1952     return false;
1953 }
1954 
__hash__() const1955 hash_t Derivative::__hash__() const
1956 {
1957     hash_t seed = SYMENGINE_DERIVATIVE;
1958     hash_combine<Basic>(seed, *arg_);
1959     for (auto &p : x_) {
1960         hash_combine<Basic>(seed, *p);
1961     }
1962     return seed;
1963 }
1964 
__eq__(const Basic & o) const1965 bool Derivative::__eq__(const Basic &o) const
1966 {
1967     if (is_a<Derivative>(o)
1968         and eq(*arg_, *(down_cast<const Derivative &>(o).arg_))
1969         and unified_eq(x_, down_cast<const Derivative &>(o).x_))
1970         return true;
1971     return false;
1972 }
1973 
compare(const Basic & o) const1974 int Derivative::compare(const Basic &o) const
1975 {
1976     SYMENGINE_ASSERT(is_a<Derivative>(o))
1977     const Derivative &s = down_cast<const Derivative &>(o);
1978     int cmp = arg_->__cmp__(*(s.arg_));
1979     if (cmp != 0)
1980         return cmp;
1981     cmp = unified_compare(x_, s.x_);
1982     return cmp;
1983 }
1984 
1985 // Subs class
Subs(const RCP<const Basic> & arg,const map_basic_basic & dict)1986 Subs::Subs(const RCP<const Basic> &arg, const map_basic_basic &dict)
1987     : arg_{arg}, dict_{dict}
1988 {
1989     SYMENGINE_ASSIGN_TYPEID()
1990     SYMENGINE_ASSERT(is_canonical(arg, dict))
1991 }
1992 
is_canonical(const RCP<const Basic> & arg,const map_basic_basic & dict) const1993 bool Subs::is_canonical(const RCP<const Basic> &arg,
1994                         const map_basic_basic &dict) const
1995 {
1996     if (is_a<Derivative>(*arg)) {
1997         return true;
1998     }
1999     return false;
2000 }
2001 
__hash__() const2002 hash_t Subs::__hash__() const
2003 {
2004     hash_t seed = SYMENGINE_SUBS;
2005     hash_combine<Basic>(seed, *arg_);
2006     for (const auto &p : dict_) {
2007         hash_combine<Basic>(seed, *p.first);
2008         hash_combine<Basic>(seed, *p.second);
2009     }
2010     return seed;
2011 }
2012 
__eq__(const Basic & o) const2013 bool Subs::__eq__(const Basic &o) const
2014 {
2015     if (is_a<Subs>(o) and eq(*arg_, *(down_cast<const Subs &>(o).arg_))
2016         and unified_eq(dict_, down_cast<const Subs &>(o).dict_))
2017         return true;
2018     return false;
2019 }
2020 
compare(const Basic & o) const2021 int Subs::compare(const Basic &o) const
2022 {
2023     SYMENGINE_ASSERT(is_a<Subs>(o))
2024     const Subs &s = down_cast<const Subs &>(o);
2025     int cmp = arg_->__cmp__(*(s.arg_));
2026     if (cmp != 0)
2027         return cmp;
2028     cmp = unified_compare(dict_, s.dict_);
2029     return cmp;
2030 }
2031 
get_variables() const2032 vec_basic Subs::get_variables() const
2033 {
2034     vec_basic v;
2035     for (const auto &p : dict_) {
2036         v.push_back(p.first);
2037     }
2038     return v;
2039 }
2040 
get_point() const2041 vec_basic Subs::get_point() const
2042 {
2043     vec_basic v;
2044     for (const auto &p : dict_) {
2045         v.push_back(p.second);
2046     }
2047     return v;
2048 }
2049 
get_args() const2050 vec_basic Subs::get_args() const
2051 {
2052     vec_basic v = {arg_};
2053     for (const auto &p : dict_) {
2054         v.push_back(p.first);
2055     }
2056     for (const auto &p : dict_) {
2057         v.push_back(p.second);
2058     }
2059     return v;
2060 }
2061 
Sinh(const RCP<const Basic> & arg)2062 Sinh::Sinh(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2063 {
2064     SYMENGINE_ASSIGN_TYPEID()
2065     SYMENGINE_ASSERT(is_canonical(arg))
2066 }
2067 
is_canonical(const RCP<const Basic> & arg) const2068 bool Sinh::is_canonical(const RCP<const Basic> &arg) const
2069 {
2070     if (eq(*arg, *zero))
2071         return false;
2072     if (is_a_Number(*arg)) {
2073         if (down_cast<const Number &>(*arg).is_negative()) {
2074             return false;
2075         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2076             return false;
2077         }
2078     }
2079     if (could_extract_minus(*arg))
2080         return false;
2081     return true;
2082 }
2083 
sinh(const RCP<const Basic> & arg)2084 RCP<const Basic> sinh(const RCP<const Basic> &arg)
2085 {
2086     if (eq(*arg, *zero))
2087         return zero;
2088     if (is_a_Number(*arg)) {
2089         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2090         if (not _arg->is_exact()) {
2091             return _arg->get_eval().sinh(*_arg);
2092         } else if (_arg->is_negative()) {
2093             return neg(sinh(zero->sub(*_arg)));
2094         }
2095     }
2096     RCP<const Basic> d;
2097     bool b = handle_minus(arg, outArg(d));
2098     if (b) {
2099         return neg(sinh(d));
2100     }
2101     return make_rcp<const Sinh>(d);
2102 }
2103 
Csch(const RCP<const Basic> & arg)2104 Csch::Csch(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2105 {
2106     SYMENGINE_ASSIGN_TYPEID()
2107     SYMENGINE_ASSERT(is_canonical(arg))
2108 }
2109 
is_canonical(const RCP<const Basic> & arg) const2110 bool Csch::is_canonical(const RCP<const Basic> &arg) const
2111 {
2112     if (eq(*arg, *zero))
2113         return false;
2114     if (is_a_Number(*arg)) {
2115         if (down_cast<const Number &>(*arg).is_negative()) {
2116             return false;
2117         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2118             return false;
2119         }
2120     }
2121     if (could_extract_minus(*arg))
2122         return false;
2123     return true;
2124 }
2125 
csch(const RCP<const Basic> & arg)2126 RCP<const Basic> csch(const RCP<const Basic> &arg)
2127 {
2128     if (eq(*arg, *zero)) {
2129         return ComplexInf;
2130     }
2131     if (is_a_Number(*arg)) {
2132         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2133         if (not _arg->is_exact()) {
2134             return _arg->get_eval().csch(*_arg);
2135         } else if (_arg->is_negative()) {
2136             return neg(csch(zero->sub(*_arg)));
2137         }
2138     }
2139     RCP<const Basic> d;
2140     bool b = handle_minus(arg, outArg(d));
2141     if (b) {
2142         return neg(csch(d));
2143     }
2144     return make_rcp<const Csch>(d);
2145 }
2146 
Cosh(const RCP<const Basic> & arg)2147 Cosh::Cosh(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2148 {
2149     SYMENGINE_ASSIGN_TYPEID()
2150     SYMENGINE_ASSERT(is_canonical(arg))
2151 }
2152 
is_canonical(const RCP<const Basic> & arg) const2153 bool Cosh::is_canonical(const RCP<const Basic> &arg) const
2154 {
2155     if (eq(*arg, *zero))
2156         return false;
2157     if (is_a_Number(*arg)) {
2158         if (down_cast<const Number &>(*arg).is_negative()) {
2159             return false;
2160         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2161             return false;
2162         }
2163     }
2164     if (could_extract_minus(*arg))
2165         return false;
2166     return true;
2167 }
2168 
cosh(const RCP<const Basic> & arg)2169 RCP<const Basic> cosh(const RCP<const Basic> &arg)
2170 {
2171     if (eq(*arg, *zero))
2172         return one;
2173     if (is_a_Number(*arg)) {
2174         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2175         if (not _arg->is_exact()) {
2176             return _arg->get_eval().cosh(*_arg);
2177         } else if (_arg->is_negative()) {
2178             return cosh(zero->sub(*_arg));
2179         }
2180     }
2181     RCP<const Basic> d;
2182     handle_minus(arg, outArg(d));
2183     return make_rcp<const Cosh>(d);
2184 }
2185 
Sech(const RCP<const Basic> & arg)2186 Sech::Sech(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2187 {
2188     SYMENGINE_ASSIGN_TYPEID()
2189     SYMENGINE_ASSERT(is_canonical(arg))
2190 }
2191 
is_canonical(const RCP<const Basic> & arg) const2192 bool Sech::is_canonical(const RCP<const Basic> &arg) const
2193 {
2194     if (eq(*arg, *zero))
2195         return false;
2196     if (is_a_Number(*arg)) {
2197         if (down_cast<const Number &>(*arg).is_negative()) {
2198             return false;
2199         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2200             return false;
2201         }
2202     }
2203     if (could_extract_minus(*arg))
2204         return false;
2205     return true;
2206 }
2207 
sech(const RCP<const Basic> & arg)2208 RCP<const Basic> sech(const RCP<const Basic> &arg)
2209 {
2210     if (eq(*arg, *zero))
2211         return one;
2212     if (is_a_Number(*arg)) {
2213         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2214         if (not _arg->is_exact()) {
2215             return _arg->get_eval().sech(*_arg);
2216         } else if (_arg->is_negative()) {
2217             return sech(zero->sub(*_arg));
2218         }
2219     }
2220     RCP<const Basic> d;
2221     handle_minus(arg, outArg(d));
2222     return make_rcp<const Sech>(d);
2223 }
2224 
Tanh(const RCP<const Basic> & arg)2225 Tanh::Tanh(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2226 {
2227     SYMENGINE_ASSIGN_TYPEID()
2228     SYMENGINE_ASSERT(is_canonical(arg))
2229 }
2230 
is_canonical(const RCP<const Basic> & arg) const2231 bool Tanh::is_canonical(const RCP<const Basic> &arg) const
2232 {
2233     if (eq(*arg, *zero))
2234         return false;
2235     if (is_a_Number(*arg)) {
2236         if (down_cast<const Number &>(*arg).is_negative()) {
2237             return false;
2238         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2239             return false;
2240         }
2241     }
2242     if (could_extract_minus(*arg))
2243         return false;
2244     return true;
2245 }
2246 
tanh(const RCP<const Basic> & arg)2247 RCP<const Basic> tanh(const RCP<const Basic> &arg)
2248 {
2249     if (eq(*arg, *zero))
2250         return zero;
2251     if (is_a_Number(*arg)) {
2252         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2253         if (not _arg->is_exact()) {
2254             return _arg->get_eval().tanh(*_arg);
2255         } else if (_arg->is_negative()) {
2256             return neg(tanh(zero->sub(*_arg)));
2257         }
2258     }
2259 
2260     RCP<const Basic> d;
2261     bool b = handle_minus(arg, outArg(d));
2262     if (b) {
2263         return neg(tanh(d));
2264     }
2265     return make_rcp<const Tanh>(d);
2266 }
2267 
Coth(const RCP<const Basic> & arg)2268 Coth::Coth(const RCP<const Basic> &arg) : HyperbolicFunction(arg)
2269 {
2270     SYMENGINE_ASSIGN_TYPEID()
2271     SYMENGINE_ASSERT(is_canonical(arg))
2272 }
2273 
is_canonical(const RCP<const Basic> & arg) const2274 bool Coth::is_canonical(const RCP<const Basic> &arg) const
2275 {
2276     if (eq(*arg, *zero))
2277         return false;
2278     if (is_a_Number(*arg)) {
2279         if (down_cast<const Number &>(*arg).is_negative()) {
2280             return false;
2281         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2282             return false;
2283         }
2284     }
2285     if (could_extract_minus(*arg))
2286         return false;
2287     return true;
2288 }
2289 
coth(const RCP<const Basic> & arg)2290 RCP<const Basic> coth(const RCP<const Basic> &arg)
2291 {
2292     if (eq(*arg, *zero)) {
2293         return ComplexInf;
2294     }
2295     if (is_a_Number(*arg)) {
2296         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2297         if (not _arg->is_exact()) {
2298             return _arg->get_eval().coth(*_arg);
2299         } else if (_arg->is_negative()) {
2300             return neg(coth(zero->sub(*_arg)));
2301         }
2302     }
2303     RCP<const Basic> d;
2304     bool b = handle_minus(arg, outArg(d));
2305     if (b) {
2306         return neg(coth(d));
2307     }
2308     return make_rcp<const Coth>(d);
2309 }
2310 
ASinh(const RCP<const Basic> & arg)2311 ASinh::ASinh(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2312 {
2313     SYMENGINE_ASSIGN_TYPEID()
2314     SYMENGINE_ASSERT(is_canonical(arg))
2315 }
2316 
is_canonical(const RCP<const Basic> & arg) const2317 bool ASinh::is_canonical(const RCP<const Basic> &arg) const
2318 {
2319     if (eq(*arg, *zero) or eq(*arg, *one) or eq(*arg, *minus_one))
2320         return false;
2321     if (is_a_Number(*arg)) {
2322         if (down_cast<const Number &>(*arg).is_negative()) {
2323             return false;
2324         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2325             return false;
2326         }
2327     }
2328     if (could_extract_minus(*arg))
2329         return false;
2330     return true;
2331 }
2332 
asinh(const RCP<const Basic> & arg)2333 RCP<const Basic> asinh(const RCP<const Basic> &arg)
2334 {
2335     if (eq(*arg, *zero))
2336         return zero;
2337     if (eq(*arg, *one))
2338         return log(add(one, sq2));
2339     if (eq(*arg, *minus_one))
2340         return log(sub(sq2, one));
2341     if (is_a_Number(*arg)) {
2342         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2343         if (not _arg->is_exact()) {
2344             return _arg->get_eval().asinh(*_arg);
2345         } else if (_arg->is_negative()) {
2346             return neg(asinh(zero->sub(*_arg)));
2347         }
2348     }
2349     RCP<const Basic> d;
2350     bool b = handle_minus(arg, outArg(d));
2351     if (b) {
2352         return neg(asinh(d));
2353     }
2354     return make_rcp<const ASinh>(d);
2355 }
2356 
ACsch(const RCP<const Basic> & arg)2357 ACsch::ACsch(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2358 {
2359     SYMENGINE_ASSIGN_TYPEID()
2360     SYMENGINE_ASSERT(is_canonical(arg))
2361 }
2362 
is_canonical(const RCP<const Basic> & arg) const2363 bool ACsch::is_canonical(const RCP<const Basic> &arg) const
2364 {
2365     if (eq(*arg, *one) or eq(*arg, *minus_one))
2366         return false;
2367     if (is_a_Number(*arg)) {
2368         if (down_cast<const Number &>(*arg).is_negative()) {
2369             return false;
2370         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2371             return false;
2372         }
2373     }
2374     if (could_extract_minus(*arg))
2375         return false;
2376     return true;
2377 }
2378 
acsch(const RCP<const Basic> & arg)2379 RCP<const Basic> acsch(const RCP<const Basic> &arg)
2380 {
2381     if (eq(*arg, *one))
2382         return log(add(one, sq2));
2383     if (eq(*arg, *minus_one))
2384         return log(sub(sq2, one));
2385 
2386     if (is_a_Number(*arg)) {
2387         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2388         if (not _arg->is_exact()) {
2389             return _arg->get_eval().acsch(*_arg);
2390         }
2391     }
2392 
2393     RCP<const Basic> d;
2394     bool b = handle_minus(arg, outArg(d));
2395     if (b) {
2396         return neg(acsch(d));
2397     }
2398     return make_rcp<const ACsch>(d);
2399 }
2400 
ACosh(const RCP<const Basic> & arg)2401 ACosh::ACosh(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2402 {
2403     SYMENGINE_ASSIGN_TYPEID()
2404     SYMENGINE_ASSERT(is_canonical(arg))
2405 }
2406 
is_canonical(const RCP<const Basic> & arg) const2407 bool ACosh::is_canonical(const RCP<const Basic> &arg) const
2408 {
2409     // TODO: Lookup into a cst table once complex is implemented
2410     if (eq(*arg, *one))
2411         return false;
2412     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2413         return false;
2414     }
2415     return true;
2416 }
2417 
acosh(const RCP<const Basic> & arg)2418 RCP<const Basic> acosh(const RCP<const Basic> &arg)
2419 {
2420     // TODO: Lookup into a cst table once complex is implemented
2421     if (eq(*arg, *one))
2422         return zero;
2423     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2424         return down_cast<const Number &>(*arg).get_eval().acosh(*arg);
2425     }
2426     return make_rcp<const ACosh>(arg);
2427 }
2428 
ATanh(const RCP<const Basic> & arg)2429 ATanh::ATanh(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2430 {
2431     SYMENGINE_ASSIGN_TYPEID()
2432     SYMENGINE_ASSERT(is_canonical(arg))
2433 }
2434 
is_canonical(const RCP<const Basic> & arg) const2435 bool ATanh::is_canonical(const RCP<const Basic> &arg) const
2436 {
2437     if (eq(*arg, *zero))
2438         return false;
2439     if (is_a_Number(*arg)) {
2440         if (down_cast<const Number &>(*arg).is_negative()) {
2441             return false;
2442         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2443             return false;
2444         }
2445     }
2446     if (could_extract_minus(*arg))
2447         return false;
2448     return true;
2449 }
2450 
atanh(const RCP<const Basic> & arg)2451 RCP<const Basic> atanh(const RCP<const Basic> &arg)
2452 {
2453     if (eq(*arg, *zero))
2454         return zero;
2455     if (is_a_Number(*arg)) {
2456         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2457         if (not _arg->is_exact()) {
2458             return _arg->get_eval().atanh(*_arg);
2459         } else if (_arg->is_negative()) {
2460             return neg(atanh(zero->sub(*_arg)));
2461         }
2462     }
2463     RCP<const Basic> d;
2464     bool b = handle_minus(arg, outArg(d));
2465     if (b) {
2466         return neg(atanh(d));
2467     }
2468     return make_rcp<const ATanh>(d);
2469 }
2470 
ACoth(const RCP<const Basic> & arg)2471 ACoth::ACoth(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2472 {
2473     SYMENGINE_ASSIGN_TYPEID()
2474     SYMENGINE_ASSERT(is_canonical(arg))
2475 }
2476 
is_canonical(const RCP<const Basic> & arg) const2477 bool ACoth::is_canonical(const RCP<const Basic> &arg) const
2478 {
2479     if (is_a_Number(*arg)) {
2480         if (down_cast<const Number &>(*arg).is_negative()) {
2481             return false;
2482         } else if (not down_cast<const Number &>(*arg).is_exact()) {
2483             return false;
2484         }
2485     }
2486     if (could_extract_minus(*arg))
2487         return false;
2488     return true;
2489 }
2490 
acoth(const RCP<const Basic> & arg)2491 RCP<const Basic> acoth(const RCP<const Basic> &arg)
2492 {
2493     if (is_a_Number(*arg)) {
2494         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2495         if (not _arg->is_exact()) {
2496             return _arg->get_eval().acoth(*_arg);
2497         } else if (_arg->is_negative()) {
2498             return neg(acoth(zero->sub(*_arg)));
2499         }
2500     }
2501     RCP<const Basic> d;
2502     bool b = handle_minus(arg, outArg(d));
2503     if (b) {
2504         return neg(acoth(d));
2505     }
2506     return make_rcp<const ACoth>(d);
2507 }
2508 
ASech(const RCP<const Basic> & arg)2509 ASech::ASech(const RCP<const Basic> &arg) : InverseHyperbolicFunction(arg)
2510 {
2511     SYMENGINE_ASSIGN_TYPEID()
2512     SYMENGINE_ASSERT(is_canonical(arg))
2513 }
2514 
is_canonical(const RCP<const Basic> & arg) const2515 bool ASech::is_canonical(const RCP<const Basic> &arg) const
2516 {
2517     // TODO: Lookup into a cst table once complex is implemented
2518     if (eq(*arg, *one))
2519         return false;
2520     if (eq(*arg, *zero))
2521         return false;
2522     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2523         return false;
2524     }
2525     return true;
2526 }
2527 
asech(const RCP<const Basic> & arg)2528 RCP<const Basic> asech(const RCP<const Basic> &arg)
2529 {
2530     // TODO: Lookup into a cst table once complex is implemented
2531     if (eq(*arg, *one))
2532         return zero;
2533     if (eq(*arg, *zero))
2534         return Inf;
2535     if (is_a_Number(*arg)) {
2536         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2537         if (not _arg->is_exact()) {
2538             return _arg->get_eval().asech(*_arg);
2539         }
2540     }
2541     return make_rcp<const ASech>(arg);
2542 }
2543 
create(const RCP<const Basic> & arg) const2544 RCP<const Basic> Sinh::create(const RCP<const Basic> &arg) const
2545 {
2546     return sinh(arg);
2547 }
2548 
create(const RCP<const Basic> & arg) const2549 RCP<const Basic> Csch::create(const RCP<const Basic> &arg) const
2550 {
2551     return csch(arg);
2552 }
2553 
create(const RCP<const Basic> & arg) const2554 RCP<const Basic> Cosh::create(const RCP<const Basic> &arg) const
2555 {
2556     return cosh(arg);
2557 }
2558 
create(const RCP<const Basic> & arg) const2559 RCP<const Basic> Sech::create(const RCP<const Basic> &arg) const
2560 {
2561     return sech(arg);
2562 }
2563 
create(const RCP<const Basic> & arg) const2564 RCP<const Basic> Tanh::create(const RCP<const Basic> &arg) const
2565 {
2566     return tanh(arg);
2567 }
2568 
create(const RCP<const Basic> & arg) const2569 RCP<const Basic> Coth::create(const RCP<const Basic> &arg) const
2570 {
2571     return coth(arg);
2572 }
2573 
create(const RCP<const Basic> & arg) const2574 RCP<const Basic> ASinh::create(const RCP<const Basic> &arg) const
2575 {
2576     return asinh(arg);
2577 }
2578 
create(const RCP<const Basic> & arg) const2579 RCP<const Basic> ACsch::create(const RCP<const Basic> &arg) const
2580 {
2581     return acsch(arg);
2582 }
2583 
create(const RCP<const Basic> & arg) const2584 RCP<const Basic> ACosh::create(const RCP<const Basic> &arg) const
2585 {
2586     return acosh(arg);
2587 }
2588 
create(const RCP<const Basic> & arg) const2589 RCP<const Basic> ATanh::create(const RCP<const Basic> &arg) const
2590 {
2591     return atanh(arg);
2592 }
2593 
create(const RCP<const Basic> & arg) const2594 RCP<const Basic> ACoth::create(const RCP<const Basic> &arg) const
2595 {
2596     return acoth(arg);
2597 }
2598 
create(const RCP<const Basic> & arg) const2599 RCP<const Basic> ASech::create(const RCP<const Basic> &arg) const
2600 {
2601     return asech(arg);
2602 }
2603 
KroneckerDelta(const RCP<const Basic> & i,const RCP<const Basic> & j)2604 KroneckerDelta::KroneckerDelta(const RCP<const Basic> &i,
2605                                const RCP<const Basic> &j)
2606     : TwoArgFunction(i, j)
2607 {
2608     SYMENGINE_ASSIGN_TYPEID()
2609     SYMENGINE_ASSERT(is_canonical(i, j))
2610 }
2611 
is_canonical(const RCP<const Basic> & i,const RCP<const Basic> & j) const2612 bool KroneckerDelta::is_canonical(const RCP<const Basic> &i,
2613                                   const RCP<const Basic> &j) const
2614 {
2615     RCP<const Basic> diff = expand(sub(i, j));
2616     if (eq(*diff, *zero)) {
2617         return false;
2618     } else if (is_a_Number(*diff)) {
2619         return false;
2620     } else {
2621         // TODO: SymPy uses default key sorting to return in order
2622         return true;
2623     }
2624 }
2625 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const2626 RCP<const Basic> KroneckerDelta::create(const RCP<const Basic> &a,
2627                                         const RCP<const Basic> &b) const
2628 {
2629     return kronecker_delta(a, b);
2630 }
2631 
kronecker_delta(const RCP<const Basic> & i,const RCP<const Basic> & j)2632 RCP<const Basic> kronecker_delta(const RCP<const Basic> &i,
2633                                  const RCP<const Basic> &j)
2634 {
2635     // Expand is needed to simplify things like `i-(i+1)` to `-1`
2636     RCP<const Basic> diff = expand(sub(i, j));
2637     if (eq(*diff, *zero)) {
2638         return one;
2639     } else if (is_a_Number(*diff)) {
2640         return zero;
2641     } else {
2642         // SymPy uses default key sorting to return in order
2643         return make_rcp<const KroneckerDelta>(i, j);
2644     }
2645 }
2646 
has_dup(const vec_basic & arg)2647 bool has_dup(const vec_basic &arg)
2648 {
2649     map_basic_basic d;
2650     auto it = d.end();
2651     for (const auto &p : arg) {
2652         it = d.find(p);
2653         if (it == d.end()) {
2654             insert(d, p, one);
2655         } else {
2656             return true;
2657         }
2658     }
2659     return false;
2660 }
2661 
LeviCivita(const vec_basic && arg)2662 LeviCivita::LeviCivita(const vec_basic &&arg) : MultiArgFunction(std::move(arg))
2663 {
2664     SYMENGINE_ASSIGN_TYPEID()
2665     SYMENGINE_ASSERT(is_canonical(get_vec()))
2666 }
2667 
is_canonical(const vec_basic & arg) const2668 bool LeviCivita::is_canonical(const vec_basic &arg) const
2669 {
2670     bool are_int = true;
2671     for (const auto &p : arg) {
2672         if (not(is_a_Number(*p))) {
2673             are_int = false;
2674             break;
2675         }
2676     }
2677     if (are_int) {
2678         return false;
2679     } else if (has_dup(arg)) {
2680         return false;
2681     } else {
2682         return true;
2683     }
2684 }
2685 
create(const vec_basic & a) const2686 RCP<const Basic> LeviCivita::create(const vec_basic &a) const
2687 {
2688     return levi_civita(a);
2689 }
2690 
eval_levicivita(const vec_basic & arg,int len)2691 RCP<const Basic> eval_levicivita(const vec_basic &arg, int len)
2692 {
2693     int i, j;
2694     RCP<const Basic> res = one;
2695     for (i = 0; i < len; i++) {
2696         for (j = i + 1; j < len; j++) {
2697             res = mul(sub(arg[j], arg[i]), res);
2698         }
2699         res = div(res, factorial(i));
2700     }
2701     return res;
2702 }
2703 
levi_civita(const vec_basic & arg)2704 RCP<const Basic> levi_civita(const vec_basic &arg)
2705 {
2706     bool are_int = true;
2707     int len = 0;
2708     for (const auto &p : arg) {
2709         if (not(is_a_Number(*p))) {
2710             are_int = false;
2711             break;
2712         } else {
2713             len++;
2714         }
2715     }
2716     if (are_int) {
2717         return eval_levicivita(arg, len);
2718     } else if (has_dup(arg)) {
2719         return zero;
2720     } else {
2721         return make_rcp<const LeviCivita>(std::move(arg));
2722     }
2723 }
2724 
Zeta(const RCP<const Basic> & s,const RCP<const Basic> & a)2725 Zeta::Zeta(const RCP<const Basic> &s, const RCP<const Basic> &a)
2726     : TwoArgFunction(s, a){SYMENGINE_ASSIGN_TYPEID()
2727                                SYMENGINE_ASSERT(is_canonical(s, a))}
2728 
Zeta(const RCP<const Basic> & s)2729       Zeta::Zeta(const RCP<const Basic> &s)
2730     : TwoArgFunction(s, one)
2731 {
2732     SYMENGINE_ASSIGN_TYPEID()
2733     SYMENGINE_ASSERT(is_canonical(s, one))
2734 }
2735 
is_canonical(const RCP<const Basic> & s,const RCP<const Basic> & a) const2736 bool Zeta::is_canonical(const RCP<const Basic> &s,
2737                         const RCP<const Basic> &a) const
2738 {
2739     if (eq(*s, *zero))
2740         return false;
2741     if (eq(*s, *one))
2742         return false;
2743     if (is_a<Integer>(*s) and is_a<Integer>(*a)) {
2744         auto s_ = down_cast<const Integer &>(*s).as_int();
2745         if (s_ < 0 || s_ % 2 == 0)
2746             return false;
2747     }
2748     return true;
2749 }
2750 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const2751 RCP<const Basic> Zeta::create(const RCP<const Basic> &a,
2752                               const RCP<const Basic> &b) const
2753 {
2754     return zeta(a, b);
2755 }
2756 
zeta(const RCP<const Basic> & s,const RCP<const Basic> & a)2757 RCP<const Basic> zeta(const RCP<const Basic> &s, const RCP<const Basic> &a)
2758 {
2759     if (is_a_Number(*s)) {
2760         if (down_cast<const Number &>(*s).is_zero()) {
2761             return sub(div(one, i2), a);
2762         } else if (down_cast<const Number &>(*s).is_one()) {
2763             return infty(0);
2764         } else if (is_a<Integer>(*s) and is_a<Integer>(*a)) {
2765             auto s_ = down_cast<const Integer &>(*s).as_int();
2766             auto a_ = down_cast<const Integer &>(*a).as_int();
2767             RCP<const Basic> zeta;
2768             if (s_ < 0) {
2769                 RCP<const Number> res = (s_ % 2 == 0) ? one : minus_one;
2770                 zeta
2771                     = mulnum(res, divnum(bernoulli(-s_ + 1), integer(-s_ + 1)));
2772             } else if (s_ % 2 == 0) {
2773                 RCP<const Number> b = bernoulli(s_);
2774                 RCP<const Number> f = factorial(s_);
2775                 zeta = divnum(pownum(integer(2), integer(s_ - 1)), f);
2776                 zeta = mul(zeta, mul(pow(pi, s), abs(b)));
2777             } else {
2778                 return make_rcp<const Zeta>(s, a);
2779             }
2780             if (a_ < 0)
2781                 return add(zeta, harmonic(-a_, s_));
2782             return sub(zeta, harmonic(a_ - 1, s_));
2783         }
2784     }
2785     return make_rcp<const Zeta>(s, a);
2786 }
2787 
zeta(const RCP<const Basic> & s)2788 RCP<const Basic> zeta(const RCP<const Basic> &s)
2789 {
2790     return zeta(s, one);
2791 }
2792 
Dirichlet_eta(const RCP<const Basic> & s)2793 Dirichlet_eta::Dirichlet_eta(const RCP<const Basic> &s) : OneArgFunction(s)
2794 {
2795     SYMENGINE_ASSIGN_TYPEID()
2796     SYMENGINE_ASSERT(is_canonical(s))
2797 }
2798 
is_canonical(const RCP<const Basic> & s) const2799 bool Dirichlet_eta::is_canonical(const RCP<const Basic> &s) const
2800 {
2801     if (eq(*s, *one))
2802         return false;
2803     if (not(is_a<Zeta>(*zeta(s))))
2804         return false;
2805     return true;
2806 }
2807 
rewrite_as_zeta() const2808 RCP<const Basic> Dirichlet_eta::rewrite_as_zeta() const
2809 {
2810     return mul(sub(one, pow(i2, sub(one, get_arg()))), zeta(get_arg()));
2811 }
2812 
create(const RCP<const Basic> & arg) const2813 RCP<const Basic> Dirichlet_eta::create(const RCP<const Basic> &arg) const
2814 {
2815     return dirichlet_eta(arg);
2816 }
2817 
dirichlet_eta(const RCP<const Basic> & s)2818 RCP<const Basic> dirichlet_eta(const RCP<const Basic> &s)
2819 {
2820     if (is_a_Number(*s) and down_cast<const Number &>(*s).is_one()) {
2821         return log(i2);
2822     }
2823     RCP<const Basic> z = zeta(s);
2824     if (is_a<Zeta>(*z)) {
2825         return make_rcp<const Dirichlet_eta>(s);
2826     } else {
2827         return mul(sub(one, pow(i2, sub(one, s))), z);
2828     }
2829 }
2830 
is_canonical(const RCP<const Basic> & arg) const2831 bool Erf::is_canonical(const RCP<const Basic> &arg) const
2832 {
2833     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
2834         return false;
2835     if (could_extract_minus(*arg))
2836         return false;
2837     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2838         return false;
2839     }
2840     return true;
2841 }
2842 
create(const RCP<const Basic> & arg) const2843 RCP<const Basic> Erf::create(const RCP<const Basic> &arg) const
2844 {
2845     return erf(arg);
2846 }
2847 
erf(const RCP<const Basic> & arg)2848 RCP<const Basic> erf(const RCP<const Basic> &arg)
2849 {
2850     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero()) {
2851         return zero;
2852     }
2853     if (is_a_Number(*arg)) {
2854         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2855         if (not _arg->is_exact()) {
2856             return _arg->get_eval().erf(*_arg);
2857         }
2858     }
2859     RCP<const Basic> d;
2860     bool b = handle_minus(arg, outArg(d));
2861     if (b) {
2862         return neg(erf(d));
2863     }
2864     return make_rcp<const Erf>(d);
2865 }
2866 
is_canonical(const RCP<const Basic> & arg) const2867 bool Erfc::is_canonical(const RCP<const Basic> &arg) const
2868 {
2869     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero())
2870         return false;
2871     if (could_extract_minus(*arg))
2872         return false;
2873     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2874         return false;
2875     }
2876     return true;
2877 }
2878 
create(const RCP<const Basic> & arg) const2879 RCP<const Basic> Erfc::create(const RCP<const Basic> &arg) const
2880 {
2881     return erfc(arg);
2882 }
2883 
erfc(const RCP<const Basic> & arg)2884 RCP<const Basic> erfc(const RCP<const Basic> &arg)
2885 {
2886     if (is_a<Integer>(*arg) and down_cast<const Integer &>(*arg).is_zero()) {
2887         return one;
2888     }
2889     if (is_a_Number(*arg)) {
2890         RCP<const Number> _arg = rcp_static_cast<const Number>(arg);
2891         if (not _arg->is_exact()) {
2892             return _arg->get_eval().erfc(*_arg);
2893         }
2894     }
2895 
2896     RCP<const Basic> d;
2897     bool b = handle_minus(arg, outArg(d));
2898     if (b) {
2899         return add(integer(2), neg(erfc(d)));
2900     }
2901     return make_rcp<const Erfc>(d);
2902 }
2903 
Gamma(const RCP<const Basic> & arg)2904 Gamma::Gamma(const RCP<const Basic> &arg) : OneArgFunction{arg}
2905 {
2906     SYMENGINE_ASSIGN_TYPEID()
2907     SYMENGINE_ASSERT(is_canonical(arg))
2908 }
2909 
is_canonical(const RCP<const Basic> & arg) const2910 bool Gamma::is_canonical(const RCP<const Basic> &arg) const
2911 {
2912     if (is_a<Integer>(*arg))
2913         return false;
2914     if (is_a<Rational>(*arg)
2915         and (get_den(down_cast<const Rational &>(*arg).as_rational_class()))
2916                 == 2) {
2917         return false;
2918     }
2919     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
2920         return false;
2921     }
2922     return true;
2923 }
2924 
create(const RCP<const Basic> & arg) const2925 RCP<const Basic> Gamma::create(const RCP<const Basic> &arg) const
2926 {
2927     return gamma(arg);
2928 }
2929 
gamma_positive_int(const RCP<const Basic> & arg)2930 RCP<const Basic> gamma_positive_int(const RCP<const Basic> &arg)
2931 {
2932     SYMENGINE_ASSERT(is_a<Integer>(*arg))
2933     RCP<const Integer> arg_ = rcp_static_cast<const Integer>(arg);
2934     SYMENGINE_ASSERT(arg_->is_positive())
2935     return factorial((arg_->subint(*one))->as_int());
2936 }
2937 
gamma_multiple_2(const RCP<const Basic> & arg)2938 RCP<const Basic> gamma_multiple_2(const RCP<const Basic> &arg)
2939 {
2940     SYMENGINE_ASSERT(is_a<Rational>(*arg))
2941     RCP<const Rational> arg_ = rcp_static_cast<const Rational>(arg);
2942     SYMENGINE_ASSERT(get_den(arg_->as_rational_class()) == 2)
2943     RCP<const Integer> n, k;
2944     RCP<const Number> coeff;
2945     n = quotient_f(*(integer(mp_abs(get_num(arg_->as_rational_class())))),
2946                    *(integer(get_den(arg_->as_rational_class()))));
2947     if (arg_->is_positive()) {
2948         k = n;
2949         coeff = one;
2950     } else {
2951         n = n->addint(*one);
2952         k = n;
2953         if ((n->as_int() & 1) == 0) {
2954             coeff = one;
2955         } else {
2956             coeff = minus_one;
2957         }
2958     }
2959     int j = 1;
2960     for (int i = 3; i < 2 * k->as_int(); i = i + 2) {
2961         j = j * i;
2962     }
2963     coeff = mulnum(coeff, integer(j));
2964     if (arg_->is_positive()) {
2965         return div(mul(coeff, sqrt(pi)), pow(i2, n));
2966     } else {
2967         return div(mul(pow(i2, n), sqrt(pi)), coeff);
2968     }
2969 }
2970 
gamma(const RCP<const Basic> & arg)2971 RCP<const Basic> gamma(const RCP<const Basic> &arg)
2972 {
2973     if (is_a<Integer>(*arg)) {
2974         RCP<const Integer> arg_ = rcp_static_cast<const Integer>(arg);
2975         if (arg_->is_positive()) {
2976             return gamma_positive_int(arg);
2977         } else {
2978             return ComplexInf;
2979         }
2980     } else if (is_a<Rational>(*arg)) {
2981         RCP<const Rational> arg_ = rcp_static_cast<const Rational>(arg);
2982         if ((get_den(arg_->as_rational_class())) == 2) {
2983             return gamma_multiple_2(arg);
2984         } else {
2985             return make_rcp<const Gamma>(arg);
2986         }
2987     } else if (is_a_Number(*arg)
2988                and not down_cast<const Number &>(*arg).is_exact()) {
2989         return down_cast<const Number &>(*arg).get_eval().gamma(*arg);
2990     }
2991     return make_rcp<const Gamma>(arg);
2992 }
2993 
LowerGamma(const RCP<const Basic> & s,const RCP<const Basic> & x)2994 LowerGamma::LowerGamma(const RCP<const Basic> &s, const RCP<const Basic> &x)
2995     : TwoArgFunction(s, x)
2996 {
2997     SYMENGINE_ASSIGN_TYPEID()
2998     SYMENGINE_ASSERT(is_canonical(s, x))
2999 }
3000 
is_canonical(const RCP<const Basic> & s,const RCP<const Basic> & x) const3001 bool LowerGamma::is_canonical(const RCP<const Basic> &s,
3002                               const RCP<const Basic> &x) const
3003 {
3004     // Only special values are evaluated
3005     if (eq(*s, *one))
3006         return false;
3007     if (is_a<Integer>(*s)
3008         and down_cast<const Integer &>(*s).as_integer_class() > 1)
3009         return false;
3010     if (is_a<Integer>(*mul(i2, s)))
3011         return false;
3012 #ifdef HAVE_SYMENGINE_MPFR
3013 #if MPFR_VERSION_MAJOR > 3
3014     if (is_a<RealMPFR>(*s) && is_a<RealMPFR>(*x))
3015         return false;
3016 #endif
3017 #endif
3018     return true;
3019 }
3020 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const3021 RCP<const Basic> LowerGamma::create(const RCP<const Basic> &a,
3022                                     const RCP<const Basic> &b) const
3023 {
3024     return lowergamma(a, b);
3025 }
3026 
lowergamma(const RCP<const Basic> & s,const RCP<const Basic> & x)3027 RCP<const Basic> lowergamma(const RCP<const Basic> &s,
3028                             const RCP<const Basic> &x)
3029 {
3030     // Only special values are being evaluated
3031     if (is_a<Integer>(*s)) {
3032         RCP<const Integer> s_int = rcp_static_cast<const Integer>(s);
3033         if (s_int->is_one()) {
3034             return sub(one, exp(mul(minus_one, x)));
3035         } else if (s_int->as_integer_class() > 1) {
3036             s_int = s_int->subint(*one);
3037             return sub(mul(s_int, lowergamma(s_int, x)),
3038                        mul(pow(x, s_int), exp(mul(minus_one, x))));
3039         } else {
3040             return make_rcp<const LowerGamma>(s, x);
3041         }
3042     } else if (is_a<Integer>(*(mul(i2, s)))) {
3043         RCP<const Number> s_num = rcp_static_cast<const Number>(s);
3044         s_num = subnum(s_num, one);
3045         if (eq(*s, *div(one, integer(2)))) {
3046             return mul(sqrt(pi),
3047                        erf(sqrt(x))); // base case for s of the form n/2
3048         } else if (s_num->is_positive()) {
3049             return sub(mul(s_num, lowergamma(s_num, x)),
3050                        mul(pow(x, s_num), exp(mul(minus_one, x))));
3051         } else {
3052             return div(add(lowergamma(add(s, one), x),
3053                            mul(pow(x, s), exp(mul(minus_one, x)))),
3054                        s);
3055         }
3056 #ifdef HAVE_SYMENGINE_MPFR
3057 #if MPFR_VERSION_MAJOR > 3
3058     } else if (is_a<RealMPFR>(*s) && is_a<RealMPFR>(*x)) {
3059         const auto &s_ = down_cast<const RealMPFR &>(*s).i.get_mpfr_t();
3060         const auto &x_ = down_cast<const RealMPFR &>(*x).i.get_mpfr_t();
3061         if (mpfr_cmp_si(x_, 0) >= 0) {
3062             mpfr_class t(std::max(mpfr_get_prec(s_), mpfr_get_prec(x_)));
3063             mpfr_class u(std::max(mpfr_get_prec(s_), mpfr_get_prec(x_)));
3064             mpfr_gamma_inc(t.get_mpfr_t(), s_, x_, MPFR_RNDN);
3065             mpfr_gamma(u.get_mpfr_t(), s_, MPFR_RNDN);
3066             mpfr_sub(t.get_mpfr_t(), u.get_mpfr_t(), t.get_mpfr_t(), MPFR_RNDN);
3067             return real_mpfr(std::move(t));
3068         } else {
3069             throw NotImplementedError("Not implemented.");
3070         }
3071 #endif
3072 #endif
3073     }
3074     return make_rcp<const LowerGamma>(s, x);
3075 }
3076 
UpperGamma(const RCP<const Basic> & s,const RCP<const Basic> & x)3077 UpperGamma::UpperGamma(const RCP<const Basic> &s, const RCP<const Basic> &x)
3078     : TwoArgFunction(s, x)
3079 {
3080     SYMENGINE_ASSIGN_TYPEID()
3081     SYMENGINE_ASSERT(is_canonical(s, x))
3082 }
3083 
is_canonical(const RCP<const Basic> & s,const RCP<const Basic> & x) const3084 bool UpperGamma::is_canonical(const RCP<const Basic> &s,
3085                               const RCP<const Basic> &x) const
3086 {
3087     // Only special values are evaluated
3088     if (eq(*s, *one))
3089         return false;
3090     if (is_a<Integer>(*s)
3091         and down_cast<const Integer &>(*s).as_integer_class() > 1)
3092         return false;
3093     if (is_a<Integer>(*mul(i2, s)))
3094         return false;
3095 #ifdef HAVE_SYMENGINE_MPFR
3096 #if MPFR_VERSION_MAJOR > 3
3097     if (is_a<RealMPFR>(*s) && is_a<RealMPFR>(*x))
3098         return false;
3099 #endif
3100 #endif
3101     return true;
3102 }
3103 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const3104 RCP<const Basic> UpperGamma::create(const RCP<const Basic> &a,
3105                                     const RCP<const Basic> &b) const
3106 {
3107     return uppergamma(a, b);
3108 }
3109 
uppergamma(const RCP<const Basic> & s,const RCP<const Basic> & x)3110 RCP<const Basic> uppergamma(const RCP<const Basic> &s,
3111                             const RCP<const Basic> &x)
3112 {
3113     // Only special values are being evaluated
3114     if (is_a<Integer>(*s)) {
3115         RCP<const Integer> s_int = rcp_static_cast<const Integer>(s);
3116         if (s_int->is_one()) {
3117             return exp(mul(minus_one, x));
3118         } else if (s_int->as_integer_class() > 1) {
3119             s_int = s_int->subint(*one);
3120             return add(mul(s_int, uppergamma(s_int, x)),
3121                        mul(pow(x, s_int), exp(mul(minus_one, x))));
3122         } else {
3123             // TODO: implement unpolarfy to handle this case
3124             return make_rcp<const LowerGamma>(s, x);
3125         }
3126     } else if (is_a<Integer>(*(mul(i2, s)))) {
3127         RCP<const Number> s_num = rcp_static_cast<const Number>(s);
3128         s_num = subnum(s_num, one);
3129         if (eq(*s, *div(one, integer(2)))) {
3130             return mul(sqrt(pi),
3131                        erfc(sqrt(x))); // base case for s of the form n/2
3132         } else if (s_num->is_positive()) {
3133             return add(mul(s_num, uppergamma(s_num, x)),
3134                        mul(pow(x, s_num), exp(mul(minus_one, x))));
3135         } else {
3136             return div(sub(uppergamma(add(s, one), x),
3137                            mul(pow(x, s), exp(mul(minus_one, x)))),
3138                        s);
3139         }
3140 #ifdef HAVE_SYMENGINE_MPFR
3141 #if MPFR_VERSION_MAJOR > 3
3142     } else if (is_a<RealMPFR>(*s) && is_a<RealMPFR>(*x)) {
3143         const auto &s_ = down_cast<const RealMPFR &>(*s).i.get_mpfr_t();
3144         const auto &x_ = down_cast<const RealMPFR &>(*x).i.get_mpfr_t();
3145         if (mpfr_cmp_si(x_, 0) >= 0) {
3146             mpfr_class t(std::max(mpfr_get_prec(s_), mpfr_get_prec(x_)));
3147             mpfr_gamma_inc(t.get_mpfr_t(), s_, x_, MPFR_RNDN);
3148             return real_mpfr(std::move(t));
3149         } else {
3150             throw NotImplementedError("Not implemented.");
3151         }
3152 #endif
3153 #endif
3154     }
3155     return make_rcp<const UpperGamma>(s, x);
3156 }
3157 
is_canonical(const RCP<const Basic> & arg) const3158 bool LogGamma::is_canonical(const RCP<const Basic> &arg) const
3159 {
3160     if (is_a<Integer>(*arg)) {
3161         RCP<const Integer> arg_int = rcp_static_cast<const Integer>(arg);
3162         if (not arg_int->is_positive()) {
3163             return false;
3164         }
3165         if (eq(*integer(1), *arg_int) or eq(*integer(2), *arg_int)
3166             or eq(*integer(3), *arg_int)) {
3167             return false;
3168         }
3169     }
3170     return true;
3171 }
3172 
rewrite_as_gamma() const3173 RCP<const Basic> LogGamma::rewrite_as_gamma() const
3174 {
3175     return log(gamma(get_arg()));
3176 }
3177 
create(const RCP<const Basic> & arg) const3178 RCP<const Basic> LogGamma::create(const RCP<const Basic> &arg) const
3179 {
3180     return loggamma(arg);
3181 }
3182 
loggamma(const RCP<const Basic> & arg)3183 RCP<const Basic> loggamma(const RCP<const Basic> &arg)
3184 {
3185     if (is_a<Integer>(*arg)) {
3186         RCP<const Integer> arg_int = rcp_static_cast<const Integer>(arg);
3187         if (not arg_int->is_positive()) {
3188             return Inf;
3189         }
3190         if (eq(*integer(1), *arg_int) or eq(*integer(2), *arg_int)) {
3191             return zero;
3192         } else if (eq(*integer(3), *arg_int)) {
3193             return log(integer(2));
3194         }
3195     }
3196     return make_rcp<const LogGamma>(arg);
3197 }
3198 
from_two_basic(const RCP<const Basic> & x,const RCP<const Basic> & y)3199 RCP<const Beta> Beta::from_two_basic(const RCP<const Basic> &x,
3200                                      const RCP<const Basic> &y)
3201 {
3202     if (x->__cmp__(*y) == -1) {
3203         return make_rcp<const Beta>(y, x);
3204     }
3205     return make_rcp<const Beta>(x, y);
3206 }
3207 
is_canonical(const RCP<const Basic> & x,const RCP<const Basic> & y)3208 bool Beta::is_canonical(const RCP<const Basic> &x, const RCP<const Basic> &y)
3209 {
3210     if (x->__cmp__(*y) == -1) {
3211         return false;
3212     }
3213     if (is_a<Integer>(*x)
3214         or (is_a<Rational>(*x)
3215             and (get_den(down_cast<const Rational &>(*x).as_rational_class()))
3216                     == 2)) {
3217         if (is_a<Integer>(*y)
3218             or (is_a<Rational>(*y)
3219                 and (get_den(
3220                         down_cast<const Rational &>(*y).as_rational_class()))
3221                         == 2)) {
3222             return false;
3223         }
3224     }
3225     return true;
3226 }
3227 
rewrite_as_gamma() const3228 RCP<const Basic> Beta::rewrite_as_gamma() const
3229 {
3230     return div(mul(gamma(get_arg1()), gamma(get_arg2())),
3231                gamma(add(get_arg1(), get_arg2())));
3232 }
3233 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const3234 RCP<const Basic> Beta::create(const RCP<const Basic> &a,
3235                               const RCP<const Basic> &b) const
3236 {
3237     return beta(a, b);
3238 }
3239 
beta(const RCP<const Basic> & x,const RCP<const Basic> & y)3240 RCP<const Basic> beta(const RCP<const Basic> &x, const RCP<const Basic> &y)
3241 {
3242     // Only special values are being evaluated
3243     if (eq(*add(x, y), *one)) {
3244         return ComplexInf;
3245     }
3246 
3247     if (is_a<Integer>(*x)) {
3248         RCP<const Integer> x_int = rcp_static_cast<const Integer>(x);
3249         if (x_int->is_positive()) {
3250             if (is_a<Integer>(*y)) {
3251                 RCP<const Integer> y_int = rcp_static_cast<const Integer>(y);
3252                 if (y_int->is_positive()) {
3253                     return div(
3254                         mul(gamma_positive_int(x), gamma_positive_int(y)),
3255                         gamma_positive_int(add(x, y)));
3256                 } else {
3257                     return ComplexInf;
3258                 }
3259             } else if (is_a<Rational>(*y)) {
3260                 RCP<const Rational> y_ = rcp_static_cast<const Rational>(y);
3261                 if (get_den(y_->as_rational_class()) == 2) {
3262                     return div(mul(gamma_positive_int(x), gamma_multiple_2(y)),
3263                                gamma_multiple_2(add(x, y)));
3264                 } else {
3265                     return Beta::from_two_basic(x, y);
3266                 }
3267             }
3268         } else {
3269             return ComplexInf;
3270         }
3271     }
3272 
3273     if (is_a<Integer>(*y)) {
3274         RCP<const Integer> y_int = rcp_static_cast<const Integer>(y);
3275         if (y_int->is_positive()) {
3276             if (is_a<Rational>(*x)) {
3277                 RCP<const Rational> x_ = rcp_static_cast<const Rational>(x);
3278                 if (get_den(x_->as_rational_class()) == 2) {
3279                     return div(mul(gamma_positive_int(y), gamma_multiple_2(x)),
3280                                gamma_multiple_2(add(x, y)));
3281                 } else {
3282                     return Beta::from_two_basic(x, y);
3283                 }
3284             }
3285         } else {
3286             return ComplexInf;
3287         }
3288     }
3289 
3290     if (is_a<const Rational>(*x)
3291         and get_den(down_cast<const Rational &>(*x).as_rational_class()) == 2) {
3292         if (is_a<Integer>(*y)) {
3293             RCP<const Integer> y_int = rcp_static_cast<const Integer>(y);
3294             if (y_int->is_positive()) {
3295                 return div(mul(gamma_multiple_2(x), gamma_positive_int(y)),
3296                            gamma_multiple_2(add(x, y)));
3297             } else {
3298                 return ComplexInf;
3299             }
3300         }
3301         if (is_a<const Rational>(*y)
3302             and get_den((down_cast<const Rational &>(*y)).as_rational_class())
3303                     == 2) {
3304             return div(mul(gamma_multiple_2(x), gamma_multiple_2(y)),
3305                        gamma_positive_int(add(x, y)));
3306         }
3307     }
3308     return Beta::from_two_basic(x, y);
3309 }
3310 
is_canonical(const RCP<const Basic> & n,const RCP<const Basic> & x)3311 bool PolyGamma::is_canonical(const RCP<const Basic> &n,
3312                              const RCP<const Basic> &x)
3313 {
3314     if (is_a_Number(*x) and not(down_cast<const Number &>(*x)).is_positive()) {
3315         return false;
3316     }
3317     if (eq(*n, *zero)) {
3318         if (eq(*x, *one)) {
3319             return false;
3320         }
3321         if (is_a<Rational>(*x)) {
3322             auto x_ = rcp_static_cast<const Rational>(x);
3323             auto den = get_den(x_->as_rational_class());
3324             if (den == 2 or den == 3 or den == 4) {
3325                 return false;
3326             }
3327         }
3328     }
3329     return true;
3330 }
3331 
rewrite_as_zeta() const3332 RCP<const Basic> PolyGamma::rewrite_as_zeta() const
3333 {
3334     if (not is_a<Integer>(*get_arg1())) {
3335         return rcp_from_this();
3336     }
3337     RCP<const Integer> n = rcp_static_cast<const Integer>(get_arg1());
3338     if (not(n->is_positive())) {
3339         return rcp_from_this();
3340     }
3341     if ((n->as_int() & 1) == 0) {
3342         return neg(mul(factorial(n->as_int()), zeta(add(n, one), get_arg2())));
3343     } else {
3344         return mul(factorial(n->as_int()), zeta(add(n, one), get_arg2()));
3345     }
3346 }
3347 
create(const RCP<const Basic> & a,const RCP<const Basic> & b) const3348 RCP<const Basic> PolyGamma::create(const RCP<const Basic> &a,
3349                                    const RCP<const Basic> &b) const
3350 {
3351     return polygamma(a, b);
3352 }
3353 
polygamma(const RCP<const Basic> & n_,const RCP<const Basic> & x_)3354 RCP<const Basic> polygamma(const RCP<const Basic> &n_,
3355                            const RCP<const Basic> &x_)
3356 {
3357     // Only special values are being evaluated
3358     if (is_a_Number(*x_)
3359         and not(down_cast<const Number &>(*x_)).is_positive()) {
3360         return ComplexInf;
3361     }
3362     if (is_a<Integer>(*n_) and is_a<Integer>(*x_)) {
3363         auto n = down_cast<const Integer &>(*n_).as_int();
3364         auto x = down_cast<const Integer &>(*x_).as_int();
3365         if (n == 0) {
3366             return sub(harmonic(x - 1, 1), EulerGamma);
3367         } else if (n % 2 == 1) {
3368             return mul(factorial(n), zeta(add(n_, one), x_));
3369         }
3370     }
3371     if (eq(*n_, *zero)) {
3372         if (eq(*x_, *one)) {
3373             return neg(EulerGamma);
3374         }
3375         if (is_a<Rational>(*x_)) {
3376             RCP<const Rational> x = rcp_static_cast<const Rational>(x_);
3377             const auto den = get_den(x->as_rational_class());
3378             const auto num = get_num(x->as_rational_class());
3379             const integer_class r = num % den;
3380             RCP<const Basic> res;
3381             if (den == 2) {
3382                 res = sub(mul(im2, log(i2)), EulerGamma);
3383             } else if (den == 3) {
3384                 if (num == 1) {
3385                     res = add(neg(div(div(pi, i2), sqrt(i3))),
3386                               sub(div(mul(im3, log(i3)), i2), EulerGamma));
3387                 } else {
3388                     res = add(div(div(pi, i2), sqrt(i3)),
3389                               sub(div(mul(im3, log(i3)), i2), EulerGamma));
3390                 }
3391             } else if (den == 4) {
3392                 if (num == 1) {
3393                     res = add(div(pi, im2), sub(mul(im3, log(i2)), EulerGamma));
3394                 } else {
3395                     res = add(div(pi, i2), sub(mul(im3, log(i2)), EulerGamma));
3396                 }
3397             } else {
3398                 return make_rcp<const PolyGamma>(n_, x_);
3399             }
3400             rational_class a(0), f(r, den);
3401             for (unsigned long i = 0; i < (num - r) / den; ++i) {
3402                 a += 1 / (f + i);
3403             }
3404             return add(Rational::from_mpq(a), res);
3405         }
3406     }
3407     return make_rcp<const PolyGamma>(n_, x_);
3408 }
3409 
digamma(const RCP<const Basic> & x)3410 RCP<const Basic> digamma(const RCP<const Basic> &x)
3411 {
3412     return polygamma(zero, x);
3413 }
3414 
trigamma(const RCP<const Basic> & x)3415 RCP<const Basic> trigamma(const RCP<const Basic> &x)
3416 {
3417     return polygamma(one, x);
3418 }
3419 
Abs(const RCP<const Basic> & arg)3420 Abs::Abs(const RCP<const Basic> &arg) : OneArgFunction(arg)
3421 {
3422     SYMENGINE_ASSIGN_TYPEID()
3423     SYMENGINE_ASSERT(is_canonical(arg))
3424 }
3425 
is_canonical(const RCP<const Basic> & arg) const3426 bool Abs::is_canonical(const RCP<const Basic> &arg) const
3427 {
3428     if (is_a<Integer>(*arg) or is_a<Rational>(*arg) or is_a<Complex>(*arg))
3429         return false;
3430     if (is_a_Number(*arg) and not down_cast<const Number &>(*arg).is_exact()) {
3431         return false;
3432     }
3433     if (is_a<Abs>(*arg)) {
3434         return false;
3435     }
3436 
3437     if (could_extract_minus(*arg)) {
3438         return false;
3439     }
3440 
3441     return true;
3442 }
3443 
create(const RCP<const Basic> & arg) const3444 RCP<const Basic> Abs::create(const RCP<const Basic> &arg) const
3445 {
3446     return abs(arg);
3447 }
3448 
abs(const RCP<const Basic> & arg)3449 RCP<const Basic> abs(const RCP<const Basic> &arg)
3450 {
3451     if (is_a<Integer>(*arg)) {
3452         RCP<const Integer> arg_ = rcp_static_cast<const Integer>(arg);
3453         if (arg_->is_negative()) {
3454             return arg_->neg();
3455         } else {
3456             return arg_;
3457         }
3458     } else if (is_a<Rational>(*arg)) {
3459         RCP<const Rational> arg_ = rcp_static_cast<const Rational>(arg);
3460         if (arg_->is_negative()) {
3461             return arg_->neg();
3462         } else {
3463             return arg_;
3464         }
3465     } else if (is_a<Complex>(*arg)) {
3466         RCP<const Complex> arg_ = rcp_static_cast<const Complex>(arg);
3467         return sqrt(Rational::from_mpq(arg_->real_ * arg_->real_
3468                                        + arg_->imaginary_ * arg_->imaginary_));
3469     } else if (is_a_Number(*arg)
3470                and not down_cast<const Number &>(*arg).is_exact()) {
3471         return down_cast<const Number &>(*arg).get_eval().abs(*arg);
3472     }
3473     if (is_a<Abs>(*arg)) {
3474         return arg;
3475     }
3476 
3477     RCP<const Basic> d;
3478     handle_minus(arg, outArg(d));
3479     return make_rcp<const Abs>(d);
3480 }
3481 
Max(const vec_basic && arg)3482 Max::Max(const vec_basic &&arg) : MultiArgFunction(std::move(arg))
3483 {
3484     SYMENGINE_ASSIGN_TYPEID()
3485     SYMENGINE_ASSERT(is_canonical(get_vec()))
3486 }
3487 
is_canonical(const vec_basic & arg) const3488 bool Max::is_canonical(const vec_basic &arg) const
3489 {
3490     if (arg.size() < 2)
3491         return false;
3492 
3493     bool non_number_exists = false;
3494 
3495     for (const auto &p : arg) {
3496         if (is_a<Complex>(*p) or is_a<Max>(*p))
3497             return false;
3498         if (not is_a_Number(*p))
3499             non_number_exists = true;
3500     }
3501     if (not std::is_sorted(arg.begin(), arg.end(), RCPBasicKeyLess()))
3502         return false;
3503 
3504     return non_number_exists; // all arguments cant be numbers
3505 }
3506 
create(const vec_basic & a) const3507 RCP<const Basic> Max::create(const vec_basic &a) const
3508 {
3509     return max(a);
3510 }
3511 
max(const vec_basic & arg)3512 RCP<const Basic> max(const vec_basic &arg)
3513 {
3514     bool number_set = false;
3515     RCP<const Number> max_number, difference;
3516     set_basic new_args;
3517 
3518     for (const auto &p : arg) {
3519         if (is_a<Complex>(*p))
3520             throw SymEngineException("Complex can't be passed to max!");
3521 
3522         if (is_a_Number(*p)) {
3523             if (not number_set) {
3524                 max_number = rcp_static_cast<const Number>(p);
3525 
3526             } else {
3527                 if (eq(*p, *Inf)) {
3528                     return Inf;
3529                 } else if (eq(*p, *NegInf)) {
3530                     continue;
3531                 }
3532                 difference = down_cast<const Number &>(*p).sub(*max_number);
3533 
3534                 if (difference->is_zero() and not difference->is_exact()) {
3535                     if (max_number->is_exact())
3536                         max_number = rcp_static_cast<const Number>(p);
3537                 } else if (difference->is_positive()) {
3538                     max_number = rcp_static_cast<const Number>(p);
3539                 }
3540             }
3541             number_set = true;
3542 
3543         } else if (is_a<Max>(*p)) {
3544             for (const auto &l : down_cast<const Max &>(*p).get_args()) {
3545                 if (is_a_Number(*l)) {
3546                     if (not number_set) {
3547                         max_number = rcp_static_cast<const Number>(l);
3548 
3549                     } else {
3550                         difference = rcp_static_cast<const Number>(l)->sub(
3551                             *max_number);
3552 
3553                         if (difference->is_zero()
3554                             and not difference->is_exact()) {
3555                             if (max_number->is_exact())
3556                                 max_number = rcp_static_cast<const Number>(l);
3557                         } else if (difference->is_positive()) {
3558                             max_number = rcp_static_cast<const Number>(l);
3559                         }
3560                     }
3561                     number_set = true;
3562                 } else {
3563                     new_args.insert(l);
3564                 }
3565             }
3566         } else {
3567             new_args.insert(p);
3568         }
3569     }
3570 
3571     if (number_set)
3572         new_args.insert(max_number);
3573 
3574     vec_basic final_args(new_args.size());
3575     std::copy(new_args.begin(), new_args.end(), final_args.begin());
3576 
3577     if (final_args.size() > 1) {
3578         return make_rcp<const Max>(std::move(final_args));
3579     } else if (final_args.size() == 1) {
3580         return final_args[0];
3581     } else {
3582         throw SymEngineException("Empty vec_basic passed to max!");
3583     }
3584 }
3585 
Min(const vec_basic && arg)3586 Min::Min(const vec_basic &&arg) : MultiArgFunction(std::move(arg))
3587 {
3588     SYMENGINE_ASSIGN_TYPEID()
3589     SYMENGINE_ASSERT(is_canonical(get_vec()))
3590 }
3591 
is_canonical(const vec_basic & arg) const3592 bool Min::is_canonical(const vec_basic &arg) const
3593 {
3594     if (arg.size() < 2)
3595         return false;
3596 
3597     bool non_number_exists = false;
3598 
3599     for (const auto &p : arg) {
3600         if (is_a<Complex>(*p) or is_a<Min>(*p))
3601             return false;
3602         if (not is_a_Number(*p))
3603             non_number_exists = true;
3604     }
3605     if (not std::is_sorted(arg.begin(), arg.end(), RCPBasicKeyLess()))
3606         return false;
3607 
3608     return non_number_exists; // all arguments cant be numbers
3609 }
3610 
create(const vec_basic & a) const3611 RCP<const Basic> Min::create(const vec_basic &a) const
3612 {
3613     return min(a);
3614 }
3615 
min(const vec_basic & arg)3616 RCP<const Basic> min(const vec_basic &arg)
3617 {
3618     bool number_set = false;
3619     RCP<const Number> min_number, difference;
3620     set_basic new_args;
3621 
3622     for (const auto &p : arg) {
3623         if (is_a<Complex>(*p))
3624             throw SymEngineException("Complex can't be passed to min!");
3625 
3626         if (is_a_Number(*p)) {
3627             if (not number_set) {
3628                 min_number = rcp_static_cast<const Number>(p);
3629 
3630             } else {
3631                 if (eq(*p, *Inf)) {
3632                     continue;
3633                 } else if (eq(*p, *NegInf)) {
3634                     return NegInf;
3635                 }
3636                 difference = min_number->sub(*rcp_static_cast<const Number>(p));
3637 
3638                 if (difference->is_zero() and not difference->is_exact()) {
3639                     if (min_number->is_exact())
3640                         min_number = rcp_static_cast<const Number>(p);
3641                 } else if (difference->is_positive()) {
3642                     min_number = rcp_static_cast<const Number>(p);
3643                 }
3644             }
3645             number_set = true;
3646 
3647         } else if (is_a<Min>(*p)) {
3648             for (const auto &l : down_cast<const Min &>(*p).get_args()) {
3649                 if (is_a_Number(*l)) {
3650                     if (not number_set) {
3651                         min_number = rcp_static_cast<const Number>(l);
3652 
3653                     } else {
3654                         difference = min_number->sub(
3655                             *rcp_static_cast<const Number>(l));
3656 
3657                         if (difference->is_zero()
3658                             and not difference->is_exact()) {
3659                             if (min_number->is_exact())
3660                                 min_number = rcp_static_cast<const Number>(l);
3661                         } else if (difference->is_positive()) {
3662                             min_number = rcp_static_cast<const Number>(l);
3663                         }
3664                     }
3665                     number_set = true;
3666                 } else {
3667                     new_args.insert(l);
3668                 }
3669             }
3670         } else {
3671             new_args.insert(p);
3672         }
3673     }
3674 
3675     if (number_set)
3676         new_args.insert(min_number);
3677 
3678     vec_basic final_args(new_args.size());
3679     std::copy(new_args.begin(), new_args.end(), final_args.begin());
3680 
3681     if (final_args.size() > 1) {
3682         return make_rcp<const Min>(std::move(final_args));
3683     } else if (final_args.size() == 1) {
3684         return final_args[0];
3685     } else {
3686         throw SymEngineException("Empty vec_basic passed to min!");
3687     }
3688 }
3689 
UnevaluatedExpr(const RCP<const Basic> & arg)3690 UnevaluatedExpr::UnevaluatedExpr(const RCP<const Basic> &arg)
3691     : OneArgFunction(arg)
3692 {
3693     SYMENGINE_ASSIGN_TYPEID()
3694     SYMENGINE_ASSERT(is_canonical(arg))
3695 }
3696 
is_canonical(const RCP<const Basic> & arg) const3697 bool UnevaluatedExpr::is_canonical(const RCP<const Basic> &arg) const
3698 {
3699     return true;
3700 }
3701 
create(const RCP<const Basic> & arg) const3702 RCP<const Basic> UnevaluatedExpr::create(const RCP<const Basic> &arg) const
3703 {
3704     return make_rcp<const UnevaluatedExpr>(arg);
3705 }
3706 
unevaluated_expr(const RCP<const Basic> & arg)3707 RCP<const Basic> unevaluated_expr(const RCP<const Basic> &arg)
3708 {
3709     return make_rcp<const UnevaluatedExpr>(arg);
3710 }
3711 
3712 } // namespace SymEngine
3713