1 /** @file normal.cpp
2  *
3  *  This file implements several functions that work on univariate and
4  *  multivariate polynomials and rational functions.
5  *  These functions include polynomial quotient and remainder, GCD and LCM
6  *  computation, square-free factorization and rational function normalization. */
7 
8 /*
9  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
10  *
11  *  This program is free software; you can redistribute it and/or modify
12  *  it under the terms of the GNU General Public License as published by
13  *  the Free Software Foundation; either version 2 of the License, or
14  *  (at your option) any later version.
15  *
16  *  This program is distributed in the hope that it will be useful,
17  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19  *  GNU General Public License for more details.
20  *
21  *  You should have received a copy of the GNU General Public License
22  *  along with this program; if not, write to the Free Software
23  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
24  */
25 
26 #ifdef HAVE_CONFIG_H
27 #include "pynac-config.h"
28 #endif
29 
30 #include "normal.h"
31 #include "basic.h"
32 #include "ex.h"
33 #include "ex_utils.h"
34 #include "add.h"
35 #include "constant.h"
36 #include "expairseq.h"
37 #include "inifcns.h"
38 #include "lst.h"
39 #include "mul.h"
40 #include "numeric.h"
41 #include "power.h"
42 #include "relational.h"
43 #include "operators.h"
44 #include "matrix.h"
45 #include "pseries.h"
46 #include "symbol.h"
47 #include "utils.h"
48 #include "upoly.h"
49 #include "mpoly.h"
50 
51 #include <algorithm>
52 #include <map>
53 
54 namespace GiNaC {
55 
// If comparing expressions (ex::compare()) is fast, you can set this to 1.
// Some routines like quo(), rem() and gcd() will then return a quick answer
// when they are called with two identical arguments.
#define FAST_COMPARE 1

// Set this if you want divide_in_z() to use remembering
#define USE_REMEMBER 0

// Set this if you want divide_in_z() to use trial division followed by
// polynomial interpolation (always slower except for completely dense
// polynomials)
#define USE_TRIAL_DIVISION 0

// Set this to enable some statistical output for the GCD routines
#define STATISTICS 0

// Placeholder symbol standing in for exp(1). function::normal() rewrites
// exp(x) as power(symbol_E, x) so that exponentials take part in rational
// normalization; ex::normal(), ex::numer(), ex::denom() and
// ex::numer_denom() substitute symbol_E back to exp(1) before returning.
static symbol symbol_E;
73 
74 /** Compute the integer content (= GCD of all numeric coefficients) of an
75  *  expanded polynomial. For a polynomial with rational coefficients, this
76  *  returns g/l where g is the GCD of the coefficients' numerators and l
77  *  is the LCM of the coefficients' denominators.
78  *
79  *  @return integer content */
integer_content() const80 numeric ex::integer_content() const
81 {
82 	return bp->integer_content();
83 }
84 
integer_content() const85 numeric basic::integer_content() const
86 {
87 	return *_num1_p;
88 }
89 
integer_content() const90 numeric numeric::integer_content() const
91 {
92 	return abs();
93 }
94 
integer_content() const95 numeric add::integer_content() const
96 {
97 	auto it = seq.begin();
98 	auto itend = seq.end();
99 	numeric c = *_num0_p, l = *_num1_p;
100 	while (it != itend) {
101 		GINAC_ASSERT(!is_exactly_a<numeric>(it->rest));
102 		GINAC_ASSERT(is_exactly_a<numeric>(it->coeff));
103 		c = gcd(ex_to<numeric>(it->coeff).numer(), c);
104 		l = lcm(ex_to<numeric>(it->coeff).denom(), l);
105 		it++;
106 	}
107 	c = gcd(overall_coeff.numer(), c);
108 	l = lcm(overall_coeff.denom(), l);
109 	return (c/l).abs();
110 }
111 
integer_content() const112 numeric mul::integer_content() const
113 {
114 #ifdef DO_GINAC_ASSERT
115 	epvector::const_iterator it = seq.begin();
116 	epvector::const_iterator itend = seq.end();
117 	while (it != itend) {
118 		GINAC_ASSERT(!is_exactly_a<numeric>(recombine_pair_to_ex(*it)));
119 		++it;
120 	}
121 #endif // def DO_GINAC_ASSERT
122 	return overall_coeff.abs();
123 }
124 
125 
#if USE_REMEMBER
/*
 *  Remembering
 */

// Only compiled when USE_REMEMBER is non-zero (it is defined as 0 above).
// Presumably consumed by divide_in_z() remembering code elsewhere; the
// consumer is not visible in this section.

typedef std::pair<ex, ex> ex2;
typedef std::pair<ex, bool> exbool;

/** Strict weak ordering on pairs of expressions: compare the first
 *  elements with ex::compare(), and only if they are equal compare the
 *  second elements. */
struct ex2_less {
	bool operator() (const ex2 &p, const ex2 &q) const
	{
		int cmp = p.first.compare(q.first);
		return ((cmp<0) || (!(cmp>0) && p.second.compare(q.second)<0));
	}
};

// Map from an (ex, ex) key to an (ex, bool) value, ordered by ex2_less.
typedef std::map<ex2, exbool, ex2_less> ex2_exbool_remember;
#endif
144 
145 
146 /** Return maximum (absolute value) coefficient of a polynomial.
147  *  This function was used internally by heur_gcd().
148  *
149  *  @return maximum coefficient
150  */
max_coefficient() const151 numeric ex::max_coefficient() const
152 {
153 	return bp->max_coefficient();
154 }
155 
156 /** Implementation ex::max_coefficient().
157  */
max_coefficient() const158 numeric basic::max_coefficient() const
159 {
160 	return *_num1_p;
161 }
162 
max_coefficient() const163 numeric numeric::max_coefficient() const
164 {
165 	return abs();
166 }
167 
max_coefficient() const168 numeric add::max_coefficient() const
169 {
170 	auto it = seq.begin();
171 	auto itend = seq.end();
172 	numeric cur_max = abs(overall_coeff);
173 	while (it != itend) {
174 		numeric a;
175 		GINAC_ASSERT(!is_exactly_a<numeric>(it->rest));
176 		a = abs(ex_to<numeric>(it->coeff));
177 		if (a > cur_max)
178 			cur_max = a;
179 		it++;
180 	}
181 	return cur_max;
182 }
183 
max_coefficient() const184 numeric mul::max_coefficient() const
185 {
186 #ifdef DO_GINAC_ASSERT
187 	epvector::const_iterator it = seq.begin();
188 	epvector::const_iterator itend = seq.end();
189 	while (it != itend) {
190 		GINAC_ASSERT(!is_exactly_a<numeric>(recombine_pair_to_ex(*it)));
191 		it++;
192 	}
193 #endif // def DO_GINAC_ASSERT
194 	return abs(overall_coeff);
195 }
196 
is_linear(const symbol & x,ex & a,ex & b) const197 bool ex::is_linear(const symbol& x, ex& a, ex& b) const
198 {
199         expand();
200         if (not has_symbol(*this, x)) {
201                 a = *this;
202                 b = _ex0;
203                 return true;
204         }
205         if (this->is_equal(x)) {
206                 a = _ex0;
207                 b = _ex1;
208                 return true;
209         }
210         if (is_exactly_a<mul>(*this)) {
211                 if (has_symbol(*this/x, x))
212                         return false;
213                 a = _ex0;
214                 b = *this/x;
215                 return true;
216         }
217         if (not is_exactly_a<add>(*this))
218                 return false;
219         const add& A = ex_to<add>(*this);
220         exvector cterms, xterms;
221         for (unsigned i=0; i<A.nops(); ++i)
222                 if (has_symbol(A.op(i), x))
223                         xterms.push_back(A.op(i));
224                 else
225                         cterms.push_back(A.op(i));
226         ex xt = (add(xterms) / x).normal();
227         if (has_symbol(xt, x))
228                 return false;
229         a = add(cterms);
230         b = xt;
231         return true;
232 }
233 
is_quadratic(const symbol & x,ex & a,ex & b,ex & c) const234 bool ex::is_quadratic(const symbol& x, ex& a, ex& b, ex& c) const
235 {
236         expand();
237         expairvec coeffs;
238         coefficients(x, coeffs);
239         b = c = _ex0;
240         for (const auto& p : coeffs) {
241                 const ex& d = p.second;
242                 if (d.is_equal(_ex2)) {
243                         c = p.first;
244                         if (has_symbol(c,x))
245                                 return false;
246                 }
247                 else if (d.is_equal(_ex1)) {
248                         b = p.first;
249                         if (has_symbol(b,x))
250                                 return false;
251                 }
252                 else if (not d.is_equal(_ex0))
253                         return false;
254         }
255         a = ((*this) - c*power(x,2) - b*x).expand();
256         if (has_symbol(a,x))
257                 return false;
258         return true;
259 }
260 
is_binomial(const symbol & x,ex & a,ex & j,ex & b,ex & n) const261 bool ex::is_binomial(const symbol& x, ex& a, ex& j, ex& b, ex& n) const
262 {
263         expand();
264         if (is_linear(x, a, b)) {
265                 j = _ex0;
266                 if (b.is_zero())
267                         n = _ex0;
268                 else
269                         n = _ex1;
270                 return true;
271         }
272         if (is_exactly_a<power>(*this)) {
273                 const power& p = ex_to<power>(*this);
274                 if (has_symbol(p.op(1), x)
275                     or not p.op(0).is_equal(x))
276                         return false;
277                 a = _ex1;
278                 j = p.op(1);
279                 b = _ex0;
280                 n = _ex0;
281                 return true;
282         }
283         if (is_exactly_a<mul>(*this)) {
284                 const mul& m = ex_to<mul>(*this);
285                 ex cprod = _ex1;
286                 j = _ex0;
287                 b = _ex0;
288                 n = _ex0;
289                 for (unsigned i=0; i<m.nops(); ++i) {
290                         const ex& factor = m.op(i);
291                         if (not has_symbol(factor, x))
292                                 cprod *= factor;
293                         else if (is_exactly_a<power>(factor)) {
294                                 const power& pow = ex_to<power>(factor);
295                                 if (has_symbol(pow.op(1), x)
296                                     or not pow.op(0).is_equal(x))
297                                         return false;
298                                 j = pow.op(1);
299                         }
300                         else if (factor.is_equal(x))
301                                 j = _ex1;
302                         else
303                                 return false;
304                 }
305                 a = cprod;
306                 return true;
307         }
308         if (not is_exactly_a<add>(*this))
309                 return false;
310         const add& A = ex_to<add>(*this);
311         exvector cterms, xterms;
312         for (unsigned i=0; i<A.nops(); ++i)
313                 if (has_symbol(A.op(i), x))
314                         xterms.push_back(A.op(i));
315                 else
316                         cterms.push_back(A.op(i));
317         if (xterms.size() > 2
318             or (xterms.size() == 2
319                 and cterms.size() > 0))
320                 return false;
321 
322         ex ta, tj, tb, tn;
323         bool r = xterms[0].is_binomial(x, ta, tj, tb, tn);
324         if (not r)
325                 return false;
326         a = ta;
327         j = tj;
328         if (xterms.size() < 2) {
329                 b = add(cterms);
330                 n = _ex0;
331                 return true;
332         }
333         r = xterms[1].is_binomial(x, ta, tj, tb, tn);
334         if (not r)
335                 return false;
336         b = ta;
337         n = tj;
338         return true;
339 }
340 
341 /** Apply symmetric modular homomorphism to an expanded multivariate
342  *  polynomial.  This function was usually used internally by heur_gcd().
343  *
344  *  @param xi  modulus
345  *  @return mapped polynomial
346  */
smod(const numeric & xi) const347 ex basic::smod(const numeric &xi) const
348 {
349 	return *this;
350 }
351 
smod(const numeric & xi) const352 ex numeric::smod(const numeric &xi) const
353 {
354 	return GiNaC::smod(*this, xi);
355 }
356 
smod(const numeric & xi) const357 ex add::smod(const numeric &xi) const
358 {
359 	epvector newseq;
360 	newseq.reserve(seq.size()+1);
361 	auto it = seq.begin();
362 	auto itend = seq.end();
363 	while (it != itend) {
364 		GINAC_ASSERT(!is_exactly_a<numeric>(it->rest));
365 		numeric num_coeff = GiNaC::smod(ex_to<numeric>(it->coeff), xi);
366 		if (!num_coeff.is_zero())
367 			newseq.emplace_back(it->rest, num_coeff);
368 		it++;
369 	}
370 	numeric num_coeff = GiNaC::smod(overall_coeff, xi);
371 	return (new add(newseq, num_coeff))->setflag(status_flags::dynallocated);
372 }
373 
smod(const numeric & xi) const374 ex mul::smod(const numeric &xi) const
375 {
376 #ifdef DO_GINAC_ASSERT
377 	epvector::const_iterator it = seq.begin();
378 	epvector::const_iterator itend = seq.end();
379 	while (it != itend) {
380 		GINAC_ASSERT(!is_exactly_a<numeric>(recombine_pair_to_ex(*it)));
381 		it++;
382 	}
383 #endif // def DO_GINAC_ASSERT
384 	auto  mulcopyp = new mul(*this);
385 	mulcopyp->overall_coeff = GiNaC::smod(overall_coeff,xi);
386 	mulcopyp->clearflag(status_flags::evaluated);
387 	mulcopyp->clearflag(status_flags::hash_calculated);
388 	return mulcopyp->setflag(status_flags::dynallocated);
389 }
390 
391 
392 /*
393  *  Normal form of rational functions
394  */
395 
396 /*
397  *  Note: The internal normal() functions (= basic::normal() and overloaded
398  *  functions) all return lists of the form {numerator, denominator}. This
399  *  is to get around mul::eval()'s automatic expansion of numeric coefficients.
400  *  E.g. (a+b)/3 is automatically converted to a/3+b/3 but we want to keep
401  *  the information that (a+b) is the numerator and 3 is the denominator.
402  */
403 
404 
405 /** Create a symbol for replacing the expression "e" (or return a previously
406  *  assigned symbol). The symbol and expression are appended to repl, for
407  *  a later application of subs().
408  *  @see ex::normal */
replace_with_symbol(const ex & e,exmap & repl,exmap & rev_lookup)409 static ex replace_with_symbol(const ex & e, exmap & repl, exmap & rev_lookup)
410 {
411 	// Since the repl contains replaced expressions we should search for them
412 	ex e_replaced = e.subs(repl, subs_options::no_pattern);
413 
414 	// Expression already replaced? Then return the assigned symbol
415 	auto it = rev_lookup.find(e_replaced);
416 	if (it != rev_lookup.end())
417 		return it->second;
418 
419 	// Otherwise create new symbol and add to list, taking care that the
420 	// replacement expression doesn't itself contain symbols from repl,
421 	// because subs() is not recursive
422 	symbol* sp = new symbol;
423         sp->set_domain_from_ex(e_replaced);
424         ex es = sp->setflag(status_flags::dynallocated);
425 	repl.insert(std::make_pair(es, e_replaced));
426 	rev_lookup.insert(std::make_pair(e_replaced, es));
427 	return es;
428 }
429 
430 /** Create a symbol for replacing the expression "e" (or return a previously
431  *  assigned symbol). The symbol and expression are appended to repl, and the
432  *  symbol is returned.
433  *  @see basic::to_rational
434  *  @see basic::to_polynomial */
replace_with_symbol(const ex & e,exmap & repl)435 static ex replace_with_symbol(const ex & e, exmap & repl)
436 {
437 	// Since the repl contains replaced expressions we should search for them
438 	ex e_replaced = e.subs(repl, subs_options::no_pattern);
439 
440 	// Expression already replaced? Then return the assigned symbol
441 	for (const auto& elem : repl)
442 		if (elem.second.is_equal(e_replaced))
443 			return elem.first;
444 
445 	// Otherwise create new symbol and add to list, taking care that the
446 	// replacement expression doesn't itself contain symbols from repl,
447 	// because subs() is not recursive
448 	symbol* sp = new symbol;
449         sp->set_domain_from_ex(e_replaced);
450         ex es = sp->setflag(status_flags::dynallocated);
451 	repl.insert(std::make_pair(es, e_replaced));
452 	return es;
453 }
454 
455 
456 /** Function object to be applied by basic::normal(). */
457 struct normal_map_function : public map_function {
458 	int level;
normal_map_functionGiNaC::normal_map_function459 	normal_map_function(int l) : level(l) {}
operator ()GiNaC::normal_map_function460 	ex operator()(const ex & e) override { return normal(e, level); }
461 };
462 
463 /** Default implementation of ex::normal(). It normalizes the children and
464  *  replaces the object with a temporary symbol.
465  *  @see ex::normal */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const466 ex basic::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
467 {
468 	if (nops() == 0)
469 		return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
470         if (level == 1)
471                 return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
472         if (level == -max_recursion_level)
473                 throw(std::runtime_error("max recursion level reached"));
474         else {
475                 normal_map_function map_normal(level - 1);
476                 return (new lst(replace_with_symbol(map(map_normal), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
477         }
478 }
479 
480 
481 /** Implementation of ex::normal() for symbols. This returns the unmodified symbol.
482  *  @see ex::normal */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const483 ex symbol::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
484 {
485 	return (new lst(*this, _ex1))->setflag(status_flags::dynallocated);
486 }
487 
488 
489 /** Implementation of ex::normal() for a numeric. It splits complex numbers
490  *  into re+I*im and replaces I and non-rational real numbers with a temporary
491  *  symbol.
492  *  @see ex::normal */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const493 ex numeric::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
494 {
495 	numeric num = numer();
496 	ex numex = num;
497 
498 	if (num.is_real()) {
499 		if (!num.is_integer())
500 			numex = replace_with_symbol(numex, repl, rev_lookup);
501 	} else { // complex
502 		numeric re = num.real(), im = num.imag();
503 		ex re_ex = re.is_rational() ? re : replace_with_symbol(re, repl, rev_lookup);
504 		ex im_ex = im.is_rational() ? im : replace_with_symbol(im, repl, rev_lookup);
505 		numex = re_ex + im_ex * replace_with_symbol(I, repl, rev_lookup);
506 	}
507 
508 	// Denominator is always a real integer (see numeric::denom())
509 	return (new lst(numex, denom()))->setflag(status_flags::dynallocated);
510 }
511 
512 
513 /** Fraction cancellation.
514  *  @param n  numerator
515  *  @param d  denominator
516  *  @return cancelled fraction {n, d} as a list */
frac_cancel(const ex & n,const ex & d)517 static ex frac_cancel(const ex &n, const ex &d)
518 {
519 	ex num = n;
520 	ex den = d;
521 	numeric pre_factor = *_num1_p;
522 
523 //std::clog << "frac_cancel num = " << num << ", den = " << den << std::endl;
524 
525 	// Handle trivial case where denominator is 1
526 	if (den.is_one())
527 		return (new lst(num, den))->setflag(status_flags::dynallocated);
528 
529 	// Handle special cases where numerator or denominator is 0
530 	if (num.is_zero())
531 		return (new lst(num, _ex1))->setflag(status_flags::dynallocated);
532 	if (den.is_zero())
533 		throw(std::overflow_error("frac_cancel: division by zero in frac_cancel"));
534 
535 	// Bring numerator and denominator to Z[X] by multiplying with
536 	// LCM of all coefficients' denominators
537 	numeric num_lcm = lcm_of_coefficients_denominators(num);
538 	numeric den_lcm = lcm_of_coefficients_denominators(den);
539 	num = multiply_lcm(num, num_lcm);
540 	den = multiply_lcm(den, den_lcm);
541 	pre_factor = den_lcm / num_lcm;
542 
543 	// Cancel GCD from numerator and denominator
544 	ex cnum, cden;
545 	if (not gcdpoly(num, den, &cnum, &cden, false).is_one()) {
546 		num = cnum;
547 		den = cden;
548 	}
549 
550 	// Make denominator unit normal (i.e. coefficient of first symbol
551 	// as defined by get_first_symbol() is made positive)
552 	if (is_exactly_a<numeric>(den)) {
553 		if (ex_to<numeric>(den).is_negative()) {
554 			num *= _ex_1;
555 			den *= _ex_1;
556 		}
557 	} else {
558 		ex x;
559 		if (den.get_first_symbol(x)) {
560 			GINAC_ASSERT(is_exactly_a<numeric>(den.unit(x)));
561 			if (ex_to<numeric>(den.unit(x)).is_negative()) {
562 				num *= _ex_1;
563 				den *= _ex_1;
564 			}
565 		}
566 	}
567 
568 	// Return result as list
569 //std::clog << " returns num = " << num << ", den = " << den << ", pre_factor = " << pre_factor << std::endl;
570 	return (new lst(num * pre_factor.numer(), den * pre_factor.denom()))->setflag(status_flags::dynallocated);
571 }
572 
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const573 ex function::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
574 {
575         if (get_serial() == exp_SERIAL::serial) {
576                 GiNaC::power p(symbol_E, op(0));
577                 return p.normal(repl, rev_lookup, level, options);
578         }
579         if (level == 1)
580                 return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
581         if (level == -max_recursion_level)
582                 throw(std::runtime_error("max recursion level reached"));
583         else {
584                 normal_map_function map_normal(level - 1);
585                 return (new lst(replace_with_symbol(map(map_normal), repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
586         }
587 }
588 
/** Implementation of ex::normal() for a sum. It performs fractional
 *  addition: every term (and the overall numeric coefficient) is normalized
 *  into a {numerator, denominator} pair, the fractions are added one after
 *  another over a growing common denominator, and the result is cancelled
 *  by frac_cancel(). The combined numerator is expanded after each addition
 *  unless normal_options::no_expand_combined_numer is set in options.
 *
 *  NOTE(review): options is only consulted at this level; it is not
 *  forwarded to the recursive normal() calls on the terms below. Confirm
 *  whether that is intended.
 *  @see ex::normal */
ex add::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
{
	// At the last permitted recursion level the whole sum becomes a symbol.
	if (level == 1)
		return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
	if (level == -max_recursion_level)
		throw(std::runtime_error("max recursion level reached"));

	// Normalize children and split each one into numerator and denominator
	exvector nums, dens;
	nums.reserve(seq.size()+1);
	dens.reserve(seq.size()+1);
        for (const auto& pair : seq) {
		const ex& term = recombine_pair_to_ex(pair);
		ex n = ex_to<basic>(term).normal(repl, rev_lookup, level-1);
		nums.push_back(n.op(0));
		dens.push_back(n.op(1));
	}
	// The overall coefficient is always appended, so nums/dens are never
	// empty and the first element can be dereferenced below.
	ex n = overall_coeff.normal(repl, rev_lookup, level-1);
	nums.push_back(n.op(0));
	dens.push_back(n.op(1));
	GINAC_ASSERT(nums.size() == dens.size());

	// Now, nums is a vector of all numerators and dens is a vector of
	// all denominators
//std::clog << "add::normal uses " << nums.size() << " summands:\n";

	// Add fractions sequentially
	auto num_it = nums.begin(), num_itend = nums.end();
	auto den_it = dens.begin(), den_itend = dens.end();
//std::clog << " num = " << *num_it << ", den = " << *den_it << std::endl;
	ex num = *num_it++, den = *den_it++;
	while (num_it != num_itend) {
//std::clog << " num = " << *num_it << ", den = " << *den_it << std::endl;
		ex next_num = *num_it++, next_den = *den_it++;

		// Trivially add sequences of fractions with identical denominators
		while ((den_it != den_itend) && next_den.is_equal(*den_it)) {
			next_num += *num_it;
			num_it++; den_it++;
		}

		// Addition of two fractions, taking advantage of the fact that
		// the heuristic GCD algorithm computes the cofactors at no extra cost
		// (g itself is not used; only the cofactors are needed).
		ex co_den1, co_den2;
		ex g = gcdpoly(den, next_den, &co_den1, &co_den2, false);
		num = ((num * co_den2) + (next_num * co_den1));
                if ((options & normal_options::no_expand_combined_numer) == 0u)
                        num = num.expand();
		den *= co_den2;		// this is the lcm(den, next_den)
	}
//std::clog << " common denominator = " << den << std::endl;

	// Cancel common factors from num/den
	return frac_cancel(num, den);
}
647 
648 
649 /** Implementation of ex::normal() for a product. It cancels common factors
650  *  from fractions.
651  *  @see ex::normal() */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const652 ex mul::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
653 {
654 	if (level == 1)
655 		return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
656 	if (level == -max_recursion_level)
657 		throw(std::runtime_error("max recursion level reached"));
658 
659 	// Normalize children, separate into numerator and denominator
660 	exvector num; num.reserve(seq.size());
661 	exvector den; den.reserve(seq.size());
662 	ex n;
663         for (const auto& pair : seq) {
664 		const ex& term = recombine_pair_to_ex(pair);
665 		n = ex_to<basic>(term).normal(repl, rev_lookup, level-1);
666 		num.push_back(n.op(0));
667 		den.push_back(n.op(1));
668 	}
669 	n = overall_coeff.normal(repl, rev_lookup, level-1);
670 	num.push_back(n.op(0));
671 	den.push_back(n.op(1));
672 
673 	// Perform fraction cancellation
674 	return frac_cancel((new mul(num))->setflag(status_flags::dynallocated),
675 	                   (new mul(den))->setflag(status_flags::dynallocated));
676 }
677 
678 
679 /** Implementation of ex::normal([B) for powers. It normalizes the basis,
680  *  distributes integer exponents to numerator and denominator, and replaces
681  *  non-integer powers by temporary symbols.
682  *  @see ex::normal */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const683 ex power::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
684 {
685 	if (level == 1)
686 		return (new lst(replace_with_symbol(*this, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
687 	if (level == -max_recursion_level)
688 		throw(std::runtime_error("max recursion level reached"));
689 
690 	// Normalize basis and exponent (exponent gets reassembled)
691 	ex n_basis = ex_to<basic>(basis).normal(repl, rev_lookup, level-1);
692 	ex n_exponent = ex_to<basic>(exponent).normal(repl, rev_lookup, level-1);
693 	n_exponent = n_exponent.op(0) / n_exponent.op(1);
694 
695 	if (n_exponent.is_integer()) {
696 
697 		if (n_exponent.is_positive()) {
698 			// (a/b)^n -> {a^n, b^n}
699 			return (new lst(power(n_basis.op(0), n_exponent),
700                                               power(n_basis.op(1), n_exponent)))
701                                 ->setflag(status_flags::dynallocated);
702 
703 		}
704                 if (n_exponent.info(info_flags::negative)) {
705 			// (a/b)^-n -> {b^n, a^n}
706 			return (new lst(power(n_basis.op(1), -n_exponent),
707                                               power(n_basis.op(0), -n_exponent)))
708                                 ->setflag(status_flags::dynallocated);
709 		}
710 	} else {
711 
712 		if (n_exponent.is_positive()) {
713 			// (a/b)^x -> {sym((a/b)^x), 1}
714 			return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1),
715                                                                 n_exponent),
716                                                             repl, rev_lookup),
717                                                 _ex1))
718                                 ->setflag(status_flags::dynallocated);
719 		}
720                 if (n_exponent.info(info_flags::negative)) {
721 
722 			if (n_basis.op(1).is_one()) {
723 
724 				// a^-x -> {1, sym(a^x)}
725 				return (new lst(_ex1,
726                                                 replace_with_symbol(power(n_basis.op(0), -n_exponent),
727                                                         repl, rev_lookup)))
728                                         ->setflag(status_flags::dynallocated);
729 			}
730 
731                 // (a/b)^-x -> {sym((b/a)^x), 1}
732                 return (new lst(replace_with_symbol(power(n_basis.op(1) / n_basis.op(0), -n_exponent),
733                                                 repl, rev_lookup), _ex1))
734                         ->setflag(status_flags::dynallocated);
735 
736 		}
737 	}
738 
739 	// (a/b)^x -> {sym((a/b)^x, 1}
740 	return (new lst(replace_with_symbol(power(n_basis.op(0) / n_basis.op(1),
741                                                 n_exponent),
742                                         repl, rev_lookup), _ex1))
743                 ->setflag(status_flags::dynallocated);
744 }
745 
746 
747 /** Implementation of ex::normal() for pseries. It normalizes each coefficient
748  *  and replaces the series by a temporary symbol.
749  *  @see ex::normal */
normal(exmap & repl,exmap & rev_lookup,int level,unsigned options) const750 ex pseries::normal(exmap & repl, exmap & rev_lookup, int level, unsigned options) const
751 {
752 	epvector newseq;
753 	auto i = seq.begin(), end = seq.end();
754 	while (i != end) {
755 		ex restexp = i->rest.normal();
756 		if (!restexp.is_zero())
757 			newseq.emplace_back(restexp, i->coeff);
758 		++i;
759 	}
760 	ex n = pseries(relational(var,point), newseq);
761 	return (new lst(replace_with_symbol(n, repl, rev_lookup), _ex1))->setflag(status_flags::dynallocated);
762 }
763 
764 
765 /** Normalization of rational functions.
766  *  This function converts an expression to its normal form
767  *  "numerator/denominator", where numerator and denominator are (relatively
768  *  prime) polynomials. Any subexpressions which are not rational functions
769  *  (like non-rational numbers, non-integer powers or functions like sin(),
770  *  cos() etc.) are replaced by temporary symbols which are re-substituted by
771  *  the (normalized) subexpressions before normal() returns (this way, any
772  *  expression can be treated as a rational function). normal() is applied
773  *  recursively to arguments of functions etc.
774  *
775  *  @param level maximum depth of recursion
776  *  @return normalized expression */
normal(int level,bool noexpand_combined,bool noexpand_numer) const777 ex ex::normal(int level, bool noexpand_combined, bool noexpand_numer) const
778 {
779 	exmap repl, rev_lookup;
780 
781         unsigned options = 0;
782         if (noexpand_combined)
783                 options |= normal_options::no_expand_combined_numer;
784         if (noexpand_numer)
785                 options |= normal_options::no_expand_fraction_numer;
786 
787 	ex e = bp->normal(repl, rev_lookup, level, options);
788 	GINAC_ASSERT(is_a<lst>(e));
789 
790 	// Re-insert replaced symbols and exp functions
791         e = e.subs(repl, subs_options::no_pattern);
792 	e = e.subs(symbol_E == exp(1));
793 
794         // Convert {numerator, denominator} form back to fraction
795         if ((options & normal_options::no_expand_fraction_numer) == 0u)
796                 return e.op(0).expand() / e.op(1);
797 
798         return e.op(0) / e.op(1);
799 }
800 
801 /** Get numerator of an expression. If the expression is not of the normal
802  *  form "numerator/denominator", it is first converted to this form and
803  *  then the numerator is returned.
804  *
805  *  @see ex::normal
806  *  @return numerator */
numer() const807 ex ex::numer() const
808 {
809 	exmap repl, rev_lookup;
810 
811 	ex e = bp->normal(repl, rev_lookup, 0);
812 	GINAC_ASSERT(is_a<lst>(e));
813 
814 	// Re-insert replaced symbols
815 	if (repl.empty())
816 		e = e.op(0);
817 	else
818 		e = e.op(0).subs(repl, subs_options::no_pattern);
819 	e = e.subs(symbol_E == exp(1));
820         return e;
821 }
822 
823 /** Get denominator of an expression. If the expression is not of the normal
824  *  form "numerator/denominator", it is first converted to this form and
825  *  then the denominator is returned.
826  *
827  *  @see ex::normal
828  *  @return denominator */
denom() const829 ex ex::denom() const
830 {
831 	exmap repl, rev_lookup;
832 
833 	ex e = bp->normal(repl, rev_lookup, 0);
834 	GINAC_ASSERT(is_a<lst>(e));
835 
836         // Re-insert replaced symbols
837 	if (repl.empty())
838 		e = e.op(1);
839 	else
840 		e = e.op(1).subs(repl, subs_options::no_pattern);
841 	e = e.subs(symbol_E == exp(1));
842         return e;
843 }
844 
845 /** Get numerator and denominator of an expression. If the expresison is not
846  *  of the normal form "numerator/denominator", it is first converted to this
847  *  form and then a list [numerator, denominator] is returned.
848  *
849  *  @see ex::normal
850  *  @return a list [numerator, denominator] */
numer_denom() const851 ex ex::numer_denom() const
852 {
853 	exmap repl, rev_lookup;
854 
855 	ex e = bp->normal(repl, rev_lookup, 0);
856 	GINAC_ASSERT(is_a<lst>(e));
857 
858 	// Re-insert replaced symbols
859 	if (not repl.empty())
860 		e = e.subs(repl, subs_options::no_pattern);
861 	e = e.subs(symbol_E == exp(1));
862         return e;
863 }
864 
865 
866 /** Rationalization of non-rational functions.
867  *  This function converts a general expression to a rational function
868  *  by replacing all non-rational subexpressions (like non-rational numbers,
869  *  non-integer powers or functions like sin(), cos() etc.) to temporary
870  *  symbols. This makes it possible to use functions like gcd() and divide()
871  *  on non-rational functions by applying to_rational() on the arguments,
872  *  calling the desired function and re-substituting the temporary symbols
873  *  in the result. To make the last step possible, all temporary symbols and
874  *  their associated expressions are collected in the map specified by the
875  *  repl parameter, ready to be passed as an argument to ex::subs().
876  *
877  *  @param repl collects all temporary symbols and their replacements
878  *  @return rationalized expression */
to_rational(exmap & repl) const879 ex ex::to_rational(exmap & repl) const
880 {
881 	return bp->to_rational(repl);
882 }
883 
884 // GiNaC 1.1 compatibility function
to_rational(lst & repl_lst) const885 ex ex::to_rational(lst & repl_lst) const
886 {
887 	// Convert lst to exmap
888 	exmap m;
889 	for (const auto & elem : repl_lst)
890 		m.insert(std::make_pair(elem.op(0), elem.op(1)));
891 
892 	ex ret = bp->to_rational(m);
893 
894 	// Convert exmap back to lst
895 	repl_lst.remove_all();
896 	for (const auto& elem : m)
897 		repl_lst.append(elem.first == elem.second);
898 
899 	return ret;
900 }
901 
to_polynomial(exmap & repl) const902 ex ex::to_polynomial(exmap & repl) const
903 {
904 	return bp->to_polynomial(repl);
905 }
906 
907 // GiNaC 1.1 compatibility function
to_polynomial(lst & repl_lst) const908 ex ex::to_polynomial(lst & repl_lst) const
909 {
910 	// Convert lst to exmap
911 	exmap m;
912 	for (const auto & elem : repl_lst)
913 		m.insert(std::make_pair(elem.op(0), elem.op(1)));
914 
915 	ex ret = bp->to_polynomial(m);
916 
917 	// Convert exmap back to lst
918 	repl_lst.remove_all();
919 	for (const auto& elem : m)
920 		repl_lst.append(elem.first == elem.second);
921 
922 	return ret;
923 }
924 
925 /** Default implementation of ex::to_rational(). This replaces the object with
926  *  a temporary symbol. */
to_rational(exmap & repl) const927 ex basic::to_rational(exmap & repl) const
928 {
929 	return replace_with_symbol(*this, repl);
930 }
931 
to_polynomial(exmap & repl) const932 ex basic::to_polynomial(exmap & repl) const
933 {
934 	return replace_with_symbol(*this, repl);
935 }
936 
937 
938 /** Implementation of ex::to_rational() for symbols. This returns the
939  *  unmodified symbol. */
to_rational(exmap & repl) const940 ex symbol::to_rational(exmap & repl) const
941 {
942 	return *this;
943 }
944 
945 /** Implementation of ex::to_polynomial() for symbols. This returns the
946  *  unmodified symbol. */
to_polynomial(exmap & repl) const947 ex symbol::to_polynomial(exmap & repl) const
948 {
949 	return *this;
950 }
951 
952 
953 /** Implementation of ex::to_rational() for a numeric. It splits complex
954  *  numbers into re+I*im and replaces I and non-rational real numbers with a
955  *  temporary symbol. */
to_rational(exmap & repl) const956 ex numeric::to_rational(exmap & repl) const
957 {
958 	if (is_real()) {
959 		if (!is_rational())
960 			return replace_with_symbol(*this, repl);
961 	} else { // complex
962 		numeric re = real();
963 		numeric im = imag();
964 		ex re_ex = re.is_rational() ? re : replace_with_symbol(re, repl);
965 		ex im_ex = im.is_rational() ? im : replace_with_symbol(im, repl);
966 		return re_ex + im_ex * replace_with_symbol(I, repl);
967 	}
968 	return *this;
969 }
970 
971 /** Implementation of ex::to_polynomial() for a numeric. It splits complex
972  *  numbers into re+I*im and replaces I and non-integer real numbers with a
973  *  temporary symbol. */
to_polynomial(exmap & repl) const974 ex numeric::to_polynomial(exmap & repl) const
975 {
976 	if (is_real()) {
977 		if (!is_integer())
978 			return replace_with_symbol(*this, repl);
979 	} else { // complex
980 		numeric re = real();
981 		numeric im = imag();
982 		ex re_ex = re.is_integer() ? re : replace_with_symbol(re, repl);
983 		ex im_ex = im.is_integer() ? im : replace_with_symbol(im, repl);
984 		return re_ex + im_ex * replace_with_symbol(I, repl);
985 	}
986 	return *this;
987 }
988 
989 
990 /** Implementation of ex::to_rational() for powers. It replaces non-integer
991  *  powers by temporary symbols. */
to_rational(exmap & repl) const992 ex power::to_rational(exmap & repl) const
993 {
994 	if (exponent.is_integer())
995 		return power(basis.to_rational(repl), exponent);
996 
997 	return replace_with_symbol(*this, repl);
998 }
999 
1000 /** Implementation of ex::to_polynomial() for powers. It replaces non-posint
1001  *  powers by temporary symbols. */
to_polynomial(exmap & repl) const1002 ex power::to_polynomial(exmap & repl) const
1003 {
1004 	if (exponent.info(info_flags::posint))
1005 		return power(basis.to_rational(repl), exponent);
1006 	if (exponent.info(info_flags::negint))
1007 	{
1008 		ex basis_pref = collect_common_factors(basis);
1009 		if (is_exactly_a<mul>(basis_pref)
1010                     or is_exactly_a<power>(basis_pref)) {
1011 			// (A*B)^n will be automagically transformed to A^n*B^n
1012 			ex t = power(basis_pref, exponent);
1013 			return t.to_polynomial(repl);
1014 		}
1015 		return power(replace_with_symbol(power(basis, _ex_1), repl),
1016                                 -exponent);
1017 	}
1018 	return replace_with_symbol(*this, repl);
1019 }
1020 
1021 
1022 /** Implementation of ex::to_rational() for expairseqs. */
to_rational(exmap & repl) const1023 ex expairseq::to_rational(exmap & repl) const
1024 {
1025 	epvector s;
1026 	s.reserve(seq.size());
1027 	auto i = seq.begin(), end = seq.end();
1028 	while (i != end) {
1029 		s.push_back(split_ex_to_pair(recombine_pair_to_ex(*i).to_rational(repl)));
1030 		++i;
1031 	}
1032 	ex oc = overall_coeff.to_rational(repl);
1033 	if (oc.info(info_flags::numeric))
1034 		return thisexpairseq(s, overall_coeff);
1035 
1036 	s.emplace_back(oc, _ex1);
1037 	return thisexpairseq(s, default_overall_coeff());
1038 }
1039 
1040 /** Implementation of ex::to_polynomial() for expairseqs. */
to_polynomial(exmap & repl) const1041 ex expairseq::to_polynomial(exmap & repl) const
1042 {
1043 	epvector s;
1044 	s.reserve(seq.size());
1045 	auto i = seq.begin(), end = seq.end();
1046 	while (i != end) {
1047 		s.push_back(split_ex_to_pair(recombine_pair_to_ex(*i).to_polynomial(repl)));
1048 		++i;
1049 	}
1050 	ex oc = overall_coeff.to_polynomial(repl);
1051 	if (oc.info(info_flags::numeric))
1052 		return thisexpairseq(s, overall_coeff);
1053 
1054 		s.emplace_back(oc, _ex1);
1055 	return thisexpairseq(s, default_overall_coeff());
1056 }
1057 
1058 
1059 /** Remove the common factor in the terms of a sum 'e' by calculating the GCD,
1060  *  and multiply it into the expression 'factor' (which needs to be initialized
1061  *  to 1, unless you're accumulating factors). */
find_common_factor(const ex & e,ex & factor,exmap & repl)1062 static ex find_common_factor(const ex & e, ex & factor, exmap & repl)
1063 {
1064 	if (is_exactly_a<add>(e)) {
1065 
1066 		size_t num = e.nops();
1067 		exvector terms; terms.reserve(num);
1068 		ex gc;
1069 
1070 		// Find the common GCD
1071 		for (size_t i=0; i<num; i++) {
1072 			ex x = e.op(i).to_polynomial(repl);
1073 
1074 			if (is_exactly_a<add>(x) || is_exactly_a<mul>(x) || is_exactly_a<power>(x)) {
1075 				ex f = 1;
1076 				x = find_common_factor(x, f, repl);
1077 				x *= f;
1078 			}
1079 
1080 			if (i == 0)
1081 				gc = x;
1082 			else
1083 				gc = gcdpoly(gc, x);
1084 
1085 			terms.push_back(x);
1086 		}
1087 
1088 		if (gc.is_one())
1089 			return e;
1090 #ifdef PYNAC_HAVE_LIBGIAC
1091                 else {
1092                         ex f = 1;
1093                         gc = find_common_factor(gc, f, repl);
1094                         gc *= f;
1095                 }
1096 #endif
1097 
1098 		// The GCD is the factor we pull out
1099 		factor *= gc;
1100 
1101 		// Now divide all terms by the GCD
1102 		for (size_t i=0; i<num; i++) {
1103 			ex x;
1104 
1105 			// Try to avoid divide() because it expands the polynomial
1106 			ex &t = terms[i];
1107 			if (is_exactly_a<mul>(t)) {
1108 				for (size_t j=0; j<t.nops(); j++) {
1109 					if (t.op(j).is_equal(gc)) {
1110 						exvector v; v.reserve(t.nops());
1111 						for (size_t k=0; k<t.nops(); k++) {
1112 							if (k == j)
1113 								v.push_back(_ex1);
1114 							else
1115 								v.push_back(t.op(k));
1116 						}
1117 						t = (new mul(v))->setflag(status_flags::dynallocated);
1118 						goto term_done;
1119 					}
1120 				}
1121 			}
1122 
1123 			divide(t, gc, x);
1124 			t = x;
1125 term_done:	;
1126 		}
1127 		return (new add(terms))->setflag(status_flags::dynallocated);
1128 	}
1129         if (is_exactly_a<mul>(e)) {
1130 
1131 		size_t num = e.nops();
1132 		exvector v; v.reserve(num);
1133 
1134 		for (size_t i=0; i<num; i++)
1135 			v.push_back(find_common_factor(e.op(i), factor, repl));
1136 
1137 		return (new mul(v))->setflag(status_flags::dynallocated);
1138 	}
1139         if (is_exactly_a<power>(e)) {
1140 		const ex e_exp(e.op(1));
1141 		if (e_exp.is_integer()) {
1142 			ex eb = e.op(0).to_polynomial(repl);
1143 			ex factor_local(_ex1);
1144 			ex pre_res = find_common_factor(eb, factor_local, repl);
1145 			factor *= power(factor_local, e_exp);
1146 			return power(pre_res, e_exp);
1147 		}
1148 		return e.to_polynomial(repl);
1149         }
1150         return e;
1151 }
1152 
1153 
1154 /** Collect common factors in sums. This converts expressions like
1155  *  'a*(b*x+b*y)' to 'a*b*(x+y)'. */
collect_common_factors(const ex & e)1156 ex collect_common_factors(const ex & e)
1157 {
1158 	if (is_exactly_a<add>(e) || is_exactly_a<mul>(e) || is_exactly_a<power>(e)) {
1159 
1160 		exmap repl;
1161 		ex factor = 1;
1162 		ex r = find_common_factor(e, factor, repl);
1163 		return factor.subs(repl, subs_options::no_pattern) * r.subs(repl, subs_options::no_pattern);
1164 
1165 	}
1166 		return e;
1167 }
1168 
/** GCD of two expressions: numeric GCD when both are numerics, polynomial
 *  GCD via gcdpoly() otherwise. */
ex gcd(const ex &a, const ex &b)
{
	const bool both_numeric = is_exactly_a<numeric>(a) && is_exactly_a<numeric>(b);
	if (!both_numeric)
		return gcdpoly(a, b);
	return gcd(ex_to<numeric>(a), ex_to<numeric>(b));
}
1175 
factor(const ex & the_ex,ex & res_ex)1176 bool factor(const ex& the_ex, ex& res_ex)
1177 {
1178         if (is_exactly_a<numeric>(the_ex)
1179             or is_exactly_a<symbol>(the_ex)
1180             or is_exactly_a<function>(the_ex)
1181             or is_exactly_a<constant>(the_ex)) {
1182                 return false;
1183         }
1184         if (is_exactly_a<mul>(the_ex)) {
1185                 const mul& m = ex_to<mul>(the_ex);
1186                 exvector ev;
1187                 bool mchanged = false;
1188                 for (size_t i=0; i<m.nops(); ++i) {
1189                         ex r;
1190                         const ex& e = m.op(i);
1191                         bool res = factor(e, r);
1192                         if (res) {
1193                                 ev.push_back(r);
1194                                 mchanged = true;
1195                         }
1196                         else
1197                                 ev.push_back(e);
1198                 }
1199                 if (mchanged)
1200                         res_ex = mul(ev);
1201                 return mchanged;
1202         }
1203         if (is_exactly_a<power>(the_ex)) {
1204                 const power& p = ex_to<power>(the_ex);
1205                 ex r;
1206                 bool pchanged = factor(p.op(0), r);
1207                 if (pchanged)
1208                         res_ex = power(r, p.op(1));
1209                 return pchanged;
1210         }
1211         ex num, den;
1212         ex normalized = the_ex.numer_denom();
1213         num = normalized.op(0);
1214         bool nres = factorpoly(num, res_ex);
1215         den = normalized.op(1);
1216         ex res_den;
1217         bool dres = factorpoly(den, res_den);
1218         if (not nres and not dres)
1219                 return false;
1220         if (not nres)
1221                 res_ex = num;
1222         if (not dres)
1223                 res_den = den;
1224         res_ex = res_ex / res_den;
1225         return true;
1226 }
1227 
1228 } // namespace GiNaC
1229