1"""
2Algorithms for solving Parametric Risch Differential Equations.
3
4The methods used for solving Parametric Risch Differential Equations parallel
5those for solving Risch Differential Equations.  See the outline in the
6docstring of rde.py for more information.
7
8The Parametric Risch Differential Equation problem is, given f, g1, ..., gm in
9K(t), to determine if there exist y in K(t) and c1, ..., cm in Const(K) such
10that Dy + f*y == Sum(ci*gi, (i, 1, m)), and to find such y and ci if they exist.
11
For the algorithms here G is a list of tuples of fractions of the terms on the
13right hand side of the equation (i.e., gi in k(t)), and Q is a list of terms on
14the right hand side of the equation (i.e., qi in k[t]).  See the docstring of
15each function for more information.
16"""
17
18from functools import reduce
19
20from sympy.core import Dummy, ilcm, Add, Mul, Pow, S
21from sympy.integrals.rde import (order_at, order_at_oo, weak_normalizer,
22    bound_degree)
23from sympy.integrals.risch import (gcdex_diophantine, frac_in, derivation,
24    residue_reduce, splitfactor, residue_reduce_derivation, DecrementLevel,
25    recognize_log_derivative)
26from sympy.polys import Poly, lcm, cancel, sqf_list
27from sympy.polys.polymatrix import PolyMatrix as Matrix
28from sympy.solvers import solve
29
# Short aliases for the PolyMatrix constructors used throughout this module;
# both take the generators as their last argument, e.g. zeros(r, c, DE.t).
zeros, eye = Matrix.zeros, Matrix.eye
32
33
def prde_normal_denom(fa, fd, G, DE):
    """
    Parametric Risch Differential Equation - Normal part of the denominator.

    Explanation
    ===========

    Given a derivation D on k[t] and f, g1, ..., gm in k(t) with f weakly
    normalized with respect to t, return the tuple (a, b, G, h) such that
    a, h in k[t], b in k<t>, G = [g1, ..., gm] in k(t)^m, and for any solution
    c1, ..., cm in Const(k) and y in k(t) of Dy + f*y == Sum(ci*gi, (i, 1, m)),
    q == y*h in k<t> satisfies a*Dq + b*q == Sum(ci*Gi, (i, 1, m)).

    f is passed as its numerator/denominator pair (fa, fd) and G as a list
    of (numerator, denominator) pairs.  b is returned as a
    (numerator, denominator) pair, and the new G as a list of such pairs.
    """
    t = DE.t
    # Only the normal part of each splitting factorization is needed.
    f_normal, _ = splitfactor(fd, DE)

    # Common denominator of all the gi (raises ValueError for an empty G,
    # exactly like unpacking zip(*G) always has).
    _, g_denoms = zip(*G)
    g_lcm = Poly(1, t)
    for den in g_denoms:
        g_lcm = g_lcm.lcm(den)
    g_normal, _ = splitfactor(g_lcm, DE)

    shared = dn_en = f_normal.gcd(g_normal)
    h = g_normal.gcd(g_normal.diff(t)).quo(shared.gcd(dn_en.diff(t)))

    a = f_normal*h
    multiplier = a*h

    b_num = a*fa - f_normal*derivation(h, DE)*fd
    b_num, b_den = b_num.cancel(fd, include=True)

    scaled_G = [(multiplier*num).cancel(den, include=True) for num, den in G]

    return (a, (b_num, b_den), scaled_G, h)
64
def real_imag(ba, bd, gen):
    """
    Helper function, to get the real and imaginary part of a rational function
    evaluated at sqrt(-1) without actually evaluating it at sqrt(-1).

    Explanation
    ===========

    A monomial c*gen**k evaluated at sqrt(-1) is c*i**k, and i**k cycles
    through 1, i, -1, -i as k mod 4 runs through 0, 1, 2, 3.  Using that,
    the numerator ba and denominator bd are each split into real and
    imaginary parts, and the quotient is rewritten with a real denominator:

        (nr + i*ni)/(dr + i*di)
            == ((nr*dr + ni*di) + i*(ni*dr - nr*di))/(dr**2 + di**2)

    Returns the tuple (real numerator, imaginary numerator, denominator),
    each converted with ``as_poly(gen)``.
    """
    def _split(coeffs):
        # Real and imaginary part of Sum(c*i**k) over the monomials.
        real_terms = []
        imag_terms = []
        for monom, coeff in coeffs.items():
            k = monom[0] % 4
            real_terms.append(coeff if k == 0 else -coeff if k == 2 else 0)
            imag_terms.append(coeff if k == 1 else -coeff if k == 3 else 0)
        return sum(real_terms), sum(imag_terms)

    den_re, den_im = _split(bd.as_poly(gen).as_dict())
    num_re, num_im = _split(ba.as_poly(gen).as_dict())

    real_num = (num_re*den_re + num_im*den_im).as_poly(gen)
    imag_num = (num_im*den_re - num_re*den_im).as_poly(gen)
    denom = (den_re*den_re + den_im*den_im).as_poly(gen)
    return (real_num, imag_num, denom)
91
92
def prde_special_denom(a, ba, bd, G, DE, case='auto'):
    """
    Parametric Risch Differential Equation - Special part of the denominator.

    Explanation
    ===========

    Case is one of {'exp', 'tan', 'primitive'} for the hyperexponential,
    hypertangent, and primitive cases, respectively.  For the hyperexponential
    (resp. hypertangent) case, given a derivation D on k[t] and a in k[t],
    b in k<t>, and g1, ..., gm in k(t) with Dt/t in k (resp. Dt/(t**2 + 1) in
    k, sqrt(-1) not in k), a != 0, and gcd(a, t) == 1 (resp.
    gcd(a, t**2 + 1) == 1), return the tuple (A, B, GG, h) such that A, B, h in
    k[t], GG = [gg1, ..., ggm] in k(t)^m, and for any solution c1, ..., cm in
    Const(k) and q in k<t> of a*Dq + b*q == Sum(ci*gi, (i, 1, m)), r == q*h in
    k[t] satisfies A*Dr + B*r == Sum(ci*ggi, (i, 1, m)).

    For case == 'primitive' (and also 'base'), k<t> == k[t], so it returns
    (a, b, G, 1) in this case, with b computed as ``ba.quo(bd)``.

    Parameters
    ==========

    a : Poly in DE.t
    ba, bd : numerator and denominator of b
    G : list of (numerator, denominator) pairs, one per gi
    DE : the differential extension
    case : 'auto' (use DE.case), 'exp', 'tan', 'primitive' or 'base'

    Raises
    ======

    ValueError
        If case is none of the supported values.
    """
    # TODO: Merge this with the very similar special_denom() in rde.py
    if case == 'auto':
        case = DE.case

    # p is the special irreducible polynomial whose powers may appear in
    # the denominator: t (hyperexponential) or t**2 + 1 (hypertangent).
    if case == 'exp':
        p = Poly(DE.t, DE.t)
    elif case == 'tan':
        p = Poly(DE.t**2 + 1, DE.t)
    elif case in ['primitive', 'base']:
        B = ba.quo(bd)
        return (a, B, G, Poly(1, DE.t))
    else:
        raise ValueError("case must be one of {'exp', 'tan', 'primitive', "
            "'base'}, not %s." % case)

    # Orders at p of b and of the gi; n is the candidate (non-positive)
    # order of the solution at p.
    nb = order_at(ba, p, DE.t) - order_at(bd, p, DE.t)
    nc = min([order_at(Ga, p, DE.t) - order_at(Gd, p, DE.t) for Ga, Gd in G])
    n = min(0, nc - min(0, nb))
    if not nb:
        # Possible cancellation.
        if case == 'exp':
            dcoeff = DE.d.quo(Poly(DE.t, DE.t))
            with DecrementLevel(DE):  # We are guaranteed to not have problems,
                                      # because case != 'base'.
                alphaa, alphad = frac_in(-ba.eval(0)/bd.eval(0)/a.eval(0), DE.t)
                etaa, etad = frac_in(dcoeff, DE.t)
                A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE)
                if A is not None:
                    Q, m, z = A
                    if Q == 1:
                        n = min(n, m)

        elif case == 'tan':
            dcoeff = DE.d.quo(Poly(DE.t**2 + 1, DE.t))
            with DecrementLevel(DE):  # We are guaranteed to not have problems,
                                      # because case != 'base'.
                # Split b/a at t = sqrt(-1) into real part (betaa/betad)
                # and imaginary part (alphaa/alphad).
                # NOTE(review): the 'exp' branch above uses -b/a, while this
                # branch splits +b/a without negating; confirm the intended
                # sign (see the TODO below -- this path is untested).
                betaa, alphaa, alphad =  real_imag(ba, bd*a, DE.t)
                betad = alphad
                etaa, etad = frac_in(dcoeff, DE.t)
                if recognize_log_derivative(Poly(2, DE.t)*betaa, betad, DE):
                    A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE)
                    B = parametric_log_deriv(betaa, betad, etaa, etad, DE)
                    if A is not None and B is not None:
                        Q, s, z = A
                        # TODO: Add test
                        if Q == 1:
                            n = min(n, s/2)

    N = max(0, -nb)
    pN = p**N
    pn = p**-n  # This is 1/h

    A = a*pN
    B = ba*pN.quo(bd) + Poly(n, DE.t)*a*derivation(p, DE).quo(p)*pN
    G = [(Ga*pN*pn).cancel(Gd, include=True) for Ga, Gd in G]
    h = pn

    # (a*p**N, (b + n*a*Dp/p)*p**N, g1*p**(N - n), ..., gm*p**(N - n), p**-n)
    return (A, B, G, h)
172
173
def prde_linear_constraints(a, b, G, DE):
    """
    Parametric Risch Differential Equation - Generate linear constraints on the constants.

    Explanation
    ===========

    Given a derivation D on k[t], a, b, in k[t] with gcd(a, b) == 1, and
    G = [g1, ..., gm] in k(t)^m, return Q = [q1, ..., qm] in k[t]^m and a
    matrix M with entries in k(t) such that for any solution c1, ..., cm in
    Const(k) and p in k[t] of a*Dp + b*p == Sum(ci*gi, (i, 1, m)),
    (c1, ..., cm) is a solution of Mx == 0, and p and the ci satisfy
    a*Dp + b*p == Sum(ci*qi, (i, 1, m)).

    Each gi is a (numerator, denominator) pair.  Every gi is brought over the
    common denominator d and divided by it; the quotients form Q (returned
    as a tuple) and the coefficients of the remainders form the columns of
    M.  a and b are accepted for interface compatibility but are not read.

    Because M has entries in k(t), and because Matrix doesn't play well with
    Poly, M will be a Matrix of Basic expressions.
    """
    m = len(G)

    _, denoms = zip(*G)
    d = Poly(reduce(lambda u, v: u.lcm(v), denoms), field=True)

    divided = [(ga*d.quo(gd)).div(d) for ga, gd in G]
    quotients = tuple(quo for quo, _ in divided)
    remainders = [rem for _, rem in divided]

    if any(not rem.is_zero for rem in remainders):
        top = max(rem.degree(DE.t) for rem in remainders)
        M = Matrix(top + 1, m, lambda i, j: remainders[j].nth(i), DE.t)
    else:
        # No constraints, return the empty matrix.
        M = Matrix(0, m, [], DE.t)

    return (quotients, M)
206
def poly_linear_constraints(p, d):
    """
    Given p = [p1, ..., pm] in k[t]^m and d in k[t], return
    q = [q1, ..., qm] in k[t]^m and a matrix M with entries in k such
    that Sum(ci*pi, (i, 1, m)), for c1, ..., cm in k, is divisible
    by d if and only if (c1, ..., cm) is a solution of Mx = 0, in
    which case the quotient is Sum(ci*qi, (i, 1, m)).

    q is returned as a tuple.  Column j of M holds the coefficients of the
    remainder of pj modulo d, lowest degree first.
    """
    m = len(p)
    quotients, remainders = zip(*[pi.div(d) for pi in p])

    if any(not rem.is_zero for rem in remainders):
        top = max(rem.degree() for rem in remainders)
        M = Matrix(top + 1, m, lambda i, j: remainders[j].nth(i), d.gens)
    else:
        # Every pi is already divisible by d: no constraints.
        M = Matrix(0, m, [], d.gens)

    return quotients, M
225
def constant_system(A, u, DE):
    """
    Generate a system for the constant solutions.

    Explanation
    ===========

    Given a differential field (K, D) with constant field C = Const(K), a Matrix
    A, and a vector (Matrix) u with coefficients in K, returns the tuple
    (B, v), where B is a Matrix with coefficients in C and v is a vector
    (Matrix) such that either v has coefficients in C, in which case the
    solutions in C of Ax == u are exactly all the solutions of Bx == v,
    or v has a non-constant coefficient, in which case Ax == u has no
    constant solution.  Only the pair (B, v) is returned; this function does
    not compute a flag telling the two cases apart.

    If ``not A`` holds (NOTE(review): presumably an empty matrix -- confirm
    PolyMatrix truthiness), A and u are returned unchanged.

    Method: [A | u] is put in reduced row echelon form.  Then, for every
    entry A[i, j] that involves one of the variables in DE.T, row i of A
    and u[i] are differentiated and divided by D(A[i, j]); column j is
    eliminated from the whole system using that new row, and the new row
    is appended to the system.

    This algorithm is used both in solving parametric problems and in
    determining if an element a of K is a derivative of an element of K or the
    logarithmic derivative of a K-radical using the structure theorem approach.

    Because Poly does not play well with Matrix yet, this algorithm assumes that
    all matrix entries are Basic expressions.  NOTE(review): the loop below
    reads ``A[i, j].expr``, which suggests the entries are in fact Poly
    objects (PolyMatrix); this sentence may be outdated.
    """
    if not A:
        return A, u
    Au = A.row_join(u)
    Au, _ = Au.rref()
    # Warning: This will NOT return correct results if cancel() cannot reduce
    # an identically zero expression to 0.  The danger is that we might
    # incorrectly prove that an integral is nonelementary (such as
    # risch_integrate(exp((sin(x)**2 + cos(x)**2 - 1)*x**2), x).
    # But this is a limitation in computer algebra in general, and implicit
    # in the correctness of the Risch Algorithm is the computability of the
    # constant field (actually, this same correctness problem exists in any
    # algorithm that uses rref()).
    #
    # We therefore limit ourselves to constant fields that are computable
    # via the cancel() function, in order to prevent a speed bottleneck from
    # calling some more complex simplification function (rational function
    # coefficients will fall into this class).  Furthermore, (I believe) this
    # problem will only crop up if the integral explicitly contains an
    # expression in the constant field that is identically zero, but cannot
    # be reduced to such by cancel().  Therefore, a careful user can avoid this
    # problem entirely by being careful with the sorts of expressions that
    # appear in his integrand in the variables other than the integration
    # variable (the structure theorems should be able to completely decide these
    # problems in the integration variable).

    A, u = Au[:, :-1], Au[:, -1]

    D = lambda x: derivation(x, DE, basic=True)

    for j in range(A.cols):
        # range(A.rows) is taken once per column, so rows appended while
        # handling column j are only revisited for later columns.
        for i in range(A.rows):
            if A[i, j].expr.has(*DE.T):
                # This assumes that const(F(t0, ..., tn)) == const(K) == F
                Ri = A[i, :]
                # Rm1 is the new row R_{m+1}, where m = A.rows.
                DAij = D(A[i, j])
                Rm1 = Ri.applyfunc(lambda x: D(x) / DAij)
                um1 = D(u[i]) / DAij

                # Eliminate column j from every row using the new row.
                Aj = A[:, j]
                A = A - Aj * Rm1
                u = u - Aj * um1

                A = A.col_join(Rm1)
                u = u.col_join(Matrix([um1], u.gens))

    return (A, u)
295
296
def prde_spde(a, b, Q, n, DE):
    """
    Special Polynomial Differential Equation algorithm: Parametric Version.

    Explanation
    ===========

    Given a derivation D on k[t], an integer n, and a, b, q1, ..., qm in k[t]
    with deg(a) > 0 and gcd(a, b) == 1, return (A, B, Qq, R, n1), with
    Qq = [q1', ..., qm'] and R = [r1, ..., rm], such that for any solution
    c1, ..., cm in Const(k) and q in k[t] of degree at most n of
    a*Dq + b*q == Sum(ci*qi, (i, 1, m)), p = (q - Sum(ci*ri, (i, 1, m)))/a has
    degree at most n1 and satisfies A*Dp + B*p == Sum(ci*qi', (i, 1, m)).

    Here A is a itself, B = b + Da, qi' = zi - Dri where (ri, zi) is the
    result of gcdex_diophantine(b, a, qi), and n1 = n - deg(a).
    """
    solutions = [gcdex_diophantine(b, a, qi) for qi in Q]
    R, Z = zip(*solutions)

    B = b + derivation(a, DE)
    Qq = [zi - derivation(ri, DE) for zi, ri in zip(Z, R)]
    n1 = n - a.degree(DE.t)

    return (a, B, Qq, list(R), n1)
320
321
def prde_no_cancel_b_large(b, Q, n, DE):
    """
    Parametric Poly Risch Differential Equation - No cancellation: deg(b) large enough.

    Explanation
    ===========

    Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with
    b != 0 and either D == d/dt or deg(b) > max(0, deg(D) - 1), returns
    h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that
    if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and
    Dq + b*q == Sum(ci*qi, (i, 1, m)), then q = Sum(dj*hj, (j, 1, r)), where
    d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0.

    The list Q passed by the caller is not modified.
    """
    db = b.degree(DE.t)
    m = len(Q)
    # Work on a copy: the reduction below rewrites the qi in place.
    Q = list(Q)
    H = [Poly(0, DE.t)]*m

    for N in range(n, -1, -1):  # [n, ..., 0]
        for i in range(m):
            # Choose si so that b*si*t**N matches the coefficient of
            # t**(N + deg(b)) in qi, then remove D(si*t**N) + b*si*t**N.
            si = Q[i].nth(N + db)/b.LC()
            sitn = Poly(si*DE.t**N, DE.t)
            H[i] = H[i] + sitn
            Q[i] = Q[i] - derivation(sitn, DE) - b*sitn

    if all(qi.is_zero for qi in Q):
        dc = -1
        # No remaining constraints.  The empty matrix must still have m
        # columns (one per ci) so that the joins below line up.
        M = zeros(0, m, DE.t)
    else:
        dc = max(qi.degree(DE.t) for qi in Q)
        M = Matrix(dc + 1, m, lambda i, j: Q[j].nth(i), DE.t)
    A, u = constant_system(M, zeros(dc + 1, 1, DE.t), DE)
    # Columns: c1, ..., cm followed by d1, ..., dm; the extra rows force
    # di == ci, since hi was built from qi alone.
    c = eye(m, DE.t)
    A = A.row_join(zeros(A.rows, m, DE.t)).col_join(c.row_join(-c))

    return (H, A)
358
359
def prde_no_cancel_b_small(b, Q, n, DE):
    """
    Parametric Poly Risch Differential Equation - No cancellation: deg(b) small enough.

    Explanation
    ===========

    Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with
    deg(b) < deg(D) - 1 and either D == d/dt or deg(D) >= 2, returns
    h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that
    if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and
    Dq + b*q == Sum(ci*qi, (i, 1, m)) then q = Sum(dj*hj, (j, 1, r)) where
    d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0.

    The list Q passed by the caller is not modified.
    """
    m = len(Q)
    # Work on a copy: the reduction below rewrites the qi in place.
    Q = list(Q)
    H = [Poly(0, DE.t)]*m

    for N in range(n, 0, -1):  # [n, ..., 1]
        for i in range(m):
            # Choose si so that N*si*LC(Dt) matches the coefficient of
            # t**(N + deg(Dt) - 1) in qi, then remove D(si*t**N) + b*si*t**N.
            si = Q[i].nth(N + DE.d.degree(DE.t) - 1)/(N*DE.d.LC())
            sitn = Poly(si*DE.t**N, DE.t)
            H[i] = H[i] + sitn
            Q[i] = Q[i] - derivation(sitn, DE) - b*sitn

    if b.degree(DE.t) > 0:
        for i in range(m):
            si = Poly(Q[i].nth(b.degree(DE.t))/b.LC(), DE.t)
            H[i] = H[i] + si
            Q[i] = Q[i] - derivation(si, DE) - b*si
        if all(qi.is_zero for qi in Q):
            dc = -1
            # No remaining constraints.  The empty matrix still needs m
            # columns and the DE.t ring so the joins below line up.
            M = zeros(0, m, DE.t)
        else:
            dc = max(qi.degree(DE.t) for qi in Q)
            M = Matrix(dc + 1, m, lambda i, j: Q[j].nth(i), DE.t)
        A, u = constant_system(M, zeros(dc + 1, 1, DE.t), DE)
        c = eye(m, DE.t)
        A = A.row_join(zeros(A.rows, m, DE.t)).col_join(c.row_join(-c))
        return (H, A)

    # else: b is in k, deg(qi) < deg(Dt)

    t = DE.t
    if DE.case != 'base':
        with DecrementLevel(DE):
            t0 = DE.t  # k = k0(t0)
            ba, bd = frac_in(b, t0, field=True)
            Q0 = [frac_in(qi.TC(), t0, field=True) for qi in Q]
            f, B = param_rischDE(ba, bd, Q0, DE)

            # f = [f1, ..., fr] in k^r and B is a matrix with
            # m + r columns and entries in Const(k) = Const(k0)
            # such that Dy0 + b*y0 = Sum(ci*qi, (i, 1, m)) has
            # a solution y0 in k with c1, ..., cm in Const(k)
            # if and only y0 = Sum(dj*fj, (j, 1, r)) where
            # d1, ..., dr ar in Const(k) and
            # B*Matrix([c1, ..., cm, d1, ..., dr]) == 0.

        # Transform fractions (fa, fd) in f into constant
        # polynomials fa/fd in k[t].
        # (Is there a better way?)
        f = [Poly(fa.as_expr()/fd.as_expr(), t, field=True)
             for fa, fd in f]
        B = Matrix.from_Matrix(B.to_Matrix(), t)
    else:
        # Base case. Dy == 0 for all y in k and b == 0.
        # Dy + b*y = Sum(ci*qi) is solvable if and only if
        # Sum(ci*qi) == 0 in which case the solutions are
        # y = d1*f1 for f1 = 1 and any d1 in Const(k) = k.

        f = [Poly(1, t, field=True)]  # r = 1
        B = Matrix([[qi.TC() for qi in Q] + [S.Zero]], DE.t)
        # The condition for solvability is
        # B*Matrix([c1, ..., cm, d1]) == 0
        # There are no constraints on d1.

    # Coefficients of t^j (j > 0) in Sum(ci*qi) must be zero.
    d = max([qi.degree(DE.t) for qi in Q])
    if d > 0:
        M = Matrix(d, m, lambda i, j: Q[j].nth(i + 1), DE.t)
        A, _ = constant_system(M, zeros(d, 1, DE.t), DE)
    else:
        # No constraints on the hj.
        A = Matrix(0, m, [], DE.t)

    # Solutions of the original equation are
    #    y = Sum(dj*fj, (j, 1, r) + Sum(ei*hi, (i, 1, m)),
    # where  ei == ci  (i = 1, ..., m),  when
    # A*Matrix([c1, ..., cm]) == 0 and
    # B*Matrix([c1, ..., cm, d1, ..., dr]) == 0

    # Build combined constraint matrix with m + r + m columns.

    r = len(f)
    I = eye(m, DE.t)
    A = A.row_join(zeros(A.rows, r + m, DE.t))
    B = B.row_join(zeros(B.rows, m, DE.t))
    C = I.row_join(zeros(m, r, DE.t)).row_join(-I)

    return f + H, A.col_join(B).col_join(C)
460
461
def prde_cancel_liouvillian(b, Q, n, DE):
    """
    Parametric Poly Risch Differential Equation - Liouvillian cancellation
    (primitive and hyperexponential cases).  Pg, 237.

    Explanation
    ===========

    For DE.case in {'primitive', 'exp'}, looks for polynomial solutions q of
    degree <= n of Dq + b*q == Sum(ci*qi, (i, 1, m)).  It works coefficient
    by coefficient from t**n down to t**0.  For each power it solves a
    parametric Risch differential equation over k (one level down), then
    moves the found terms over to the right-hand side as new parametric
    terms for the lower powers.

    Returns (H, M): H is the list of all terms hj = fj*t**i found, and M is
    a constraint matrix.  Its columns correspond to c1, ..., cm followed by
    one constant per entry of H, in the same order as H.

    n is expected to be >= 0; for n < 0 the loop body never runs and M is
    never assigned.
    """
    H = []

    # Why use DecrementLevel? Below line answers that:
    # Assuming that we can solve such problems over 'k' (not k[t])
    if DE.case == 'primitive':
        # With Dt and b in k, the coefficient of t**i in
        # D(y*t**i) + b*y*t**i is Dy + b*y for every i.
        with DecrementLevel(DE):
            ba, bd = frac_in(b, DE.t, field=True)
    elif DE.case == 'exp':
        # eta = Dt/t in k.  It has to be computed here: inside
        # DecrementLevel, DE.t and DE.d refer to the subfield k, so
        # derivation(DE.t, DE)/DE.t there would give D(t0)/t0 instead.
        eta = DE.d.quo(Poly(DE.t, DE.t))

    for i in range(n, -1, -1):
        if DE.case == 'exp':
            # D(y*t**i) + b*y*t**i == (Dy + (b + i*Dt/t)*y)*t**i
            with DecrementLevel(DE):
                ba, bd = frac_in(b + i*eta, DE.t, field=True)
        with DecrementLevel(DE):
            Qy = [frac_in(q.nth(i), DE.t, field=True) for q in Q]
            fi, Ai = param_rischDE(ba, bd, Qy, DE)
        fi = [Poly(fa.as_expr()/fd.as_expr(), DE.t, field=True)
                for fa, fd in fi]
        Ai = Ai.set_gens(DE.t)

        ri = len(fi)

        if i == n:
            M = Ai
        else:
            M = Ai.col_join(M.row_join(zeros(M.rows, ri, DE.t)))

        Fi, hi = [None]*ri, [None]*ri

        # from eq. on top of p.238 (unnumbered)
        for j in range(ri):
            hji = fi[j] * (DE.t**i).as_poly(fi[j].gens)
            hi[j] = hji
            # If q = y + Sum(dj*hj), then
            #   Dy + b*y == Sum(ci*qi) - Sum(dj*(D(hj) + b*hj)),
            # so -(D(hj) + b*hj) joins the right-hand side with
            # coefficient dj.
            Fi[j] = -(derivation(hji, DE) + b*hji)

        H += hi
        # in the next loop instead of Q it has
        # to be Q + Fi taking its place
        Q = Q + Fi

    return (H, M)
508
509
def param_poly_rischDE(a, b, q, n, DE):
    """Polynomial solutions of a parametric Risch differential equation.

    Explanation
    ===========

    Given a derivation D in k[t], a, b in k[t] relatively prime, and q
    = [q1, ..., qm] in k[t]^m, return h = [h1, ..., hr] in k[t]^r and
    a matrix A with m + r columns and entries in Const(k) such that
    a*Dp + b*p = Sum(ci*qi, (i, 1, m)) has a solution p of degree <= n
    in k[t] with c1, ..., cm in Const(k) if and only if p = Sum(dj*hj,
    (j, 1, r)) where d1, ..., dr are in Const(k) and (c1, ..., cm,
    d1, ..., dr) is a solution of Ax == 0.

    Dispatch: n < 0 only yields relations among the qi.  When a is a
    constant, the equation is normalized to a == 1 and handed to the
    no-cancellation or Liouvillian-cancellation solvers.  Otherwise SPDE
    reduction is iterated, and the function recurses on the reduced
    equation.

    Raises NotImplementedError for the deg(b) == deg(D) - 1 no-cancellation
    case and for cancellation cases other than 'primitive' and 'exp'.
    """
    m = len(q)
    if n < 0:
        # Only the trivial zero solution is possible.
        # Find relations between the qi.
        if all([qi.is_zero for qi in q]):
            return [], zeros(1, m, DE.t)  # No constraints.

        N = max([qi.degree(DE.t) for qi in q])
        M = Matrix(N + 1, m, lambda i, j: q[j].nth(i), DE.t)
        A, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE)

        return [], A

    if a.is_ground:
        # Normalization: a = 1.
        a = a.LC()
        b, q = b.quo_ground(a), [qi.quo_ground(a) for qi in q]

        if not b.is_zero and (DE.case == 'base' or
                b.degree() > max(0, DE.d.degree() - 1)):
            return prde_no_cancel_b_large(b, q, n, DE)

        elif ((b.is_zero or b.degree() < DE.d.degree() - 1)
                and (DE.case == 'base' or DE.d.degree() >= 2)):
            return prde_no_cancel_b_small(b, q, n, DE)

        elif (DE.d.degree() >= 2 and
              b.degree() == DE.d.degree() - 1 and
              n > -b.as_poly().LC()/DE.d.as_poly().LC()):
            raise NotImplementedError("prde_no_cancel_b_equal() is "
                "not yet implemented.")

        else:
            # Liouvillian cases
            if DE.case == 'primitive' or DE.case == 'exp':
                return prde_cancel_liouvillian(b, q, n, DE)
            else:
                raise NotImplementedError("non-linear and hypertangent "
                        "cases have not yet been implemented")

    # else: deg(a) > 0

    # Iterate SPDE as long as possible cumulating coefficient
    # and terms for the recovery of original solutions.
    alpha, beta = a.one, [a.zero]*m
    while n >= 0:  # and a, b relatively prime
        a, b, q, r, n = prde_spde(a, b, q, n, DE)
        beta = [betai + alpha*ri for betai, ri in zip(beta, r)]
        alpha *= a
        # Solutions p of a*Dp + b*p = Sum(ci*qi) correspond to
        # solutions alpha*p + Sum(ci*betai) of the initial equation.
        d = a.gcd(b)
        if not d.is_ground:
            break

    # a*Dp + b*p = Sum(ci*qi) may have a polynomial solution
    # only if the sum is divisible by d.

    qq, M = poly_linear_constraints(q, d)
    # qq = [qq1, ..., qqm] where qqi = qi.quo(d).
    # M is a matrix with m columns and entries in k.
    # Sum(fi*qi, (i, 1, m)), where f1, ..., fm are elements of k, is
    # divisible by d if and only if M*Matrix([f1, ..., fm]) == 0,
    # in which case the quotient is Sum(fi*qqi).

    A, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE)
    # A is a matrix with m columns and entries in Const(k).
    # Sum(ci*qqi) is Sum(ci*qi).quo(d), and the remainder is zero
    # for c1, ..., cm in Const(k) if and only if
    # A*Matrix([c1, ...,cm]) == 0.

    V = A.nullspace()
    # V = [v1, ..., vu] where each vj is a column matrix with
    # entries aj1, ..., ajm in Const(k).
    # Sum(aji*qi) is divisible by d with exact quotient Sum(aji*qqi).
    # Sum(ci*qi) is divisible by d if and only if ci = Sum(dj*aji)
    # (i = 1, ..., m) for some d1, ..., du in Const(k).
    # In that case, solutions of
    #     a*Dp + b*p = Sum(ci*qi) = Sum(dj*Sum(aji*qi))
    # are the same as those of
    #     (a/d)*Dp + (b/d)*p = Sum(dj*rj)
    # where rj = Sum(aji*qqi).

    if not V:  # No non-trivial solution.
        return [], eye(m, DE.t)  # Could return A, but this has
                                 # the minimum number of rows.

    Mqq = Matrix([qq])  # A single row.
    r = [(Mqq*vj)[0] for vj in V]  # [r1, ..., ru]

    # Solutions of (a/d)*Dp + (b/d)*p = Sum(dj*rj) correspond to
    # solutions alpha*p + Sum(Sum(dj*aji)*betai) of the initial
    # equation. These are equal to alpha*p + Sum(dj*fj) where
    # fj = Sum(aji*betai).
    Mbeta = Matrix([beta])
    f = [(Mbeta*vj)[0] for vj in V]  # [f1, ..., fu]

    #
    # Solve the reduced equation recursively.
    #
    g, B = param_poly_rischDE(a.quo(d), b.quo(d), r, n, DE)

    # g = [g1, ..., gv] in k[t]^v and B is a matrix with u + v
    # columns and entries in Const(k) such that
    # (a/d)*Dp + (b/d)*p = Sum(dj*rj) has a solution p of degree <= n
    # in k[t] if and only if p = Sum(ek*gk) where e1, ..., ev are in
    # Const(k) and B*Matrix([d1, ..., du, e1, ..., ev]) == 0.
    # The solutions of the original equation are then
    # Sum(dj*fj, (j, 1, u)) + alpha*Sum(ek*gk, (k, 1, v)).

    # Collect solution components.
    h = f + [alpha*gk for gk in g]

    # Build combined relation matrix.
    # Columns: c1, ..., cm | d1, ..., du | e1, ..., ev.  The top m rows
    # encode ci == Sum(dj*aji); the bottom rows carry B on (d, e).
    A = -eye(m, DE.t)
    for vj in V:
        A = A.row_join(vj)
    A = A.row_join(zeros(m, len(g), DE.t))
    A = A.col_join(zeros(B.rows, m, DE.t).row_join(B))

    return h, A
645
646
def param_rischDE(fa, fd, G, DE):
    """
    Solve a Parametric Risch Differential Equation: Dy + f*y == Sum(ci*Gi, (i, 1, m)).

    Explanation
    ===========

    Given a derivation D in k(t), f in k(t), and G
    = [G1, ..., Gm] in k(t)^m, return h = [h1, ..., hr] in k(t)^r and
    a matrix A with m + r columns and entries in Const(k) such that
    Dy + f*y = Sum(ci*Gi, (i, 1, m)) has a solution y
    in k(t) with c1, ..., cm in Const(k) if and only if y = Sum(dj*hj,
    (j, 1, r)) where d1, ..., dr are in Const(k) and (c1, ..., cm,
    d1, ..., dr) is a solution of Ax == 0.

    Elements of k(t) are tuples (a, d) with a and d in k[t].

    The returned matrix always has exactly m + r columns.  This also holds
    in the degenerate cases r == 0 and "no linear relations at all".
    """
    m = len(G)
    q, (fa, fd) = weak_normalizer(fa, fd, DE)
    # Solutions of the weakly normalized equation Dz + f*z = q*Sum(ci*Gi)
    # correspond to solutions y = z/q of the original equation.
    gamma = q
    G = [(q*ga).cancel(gd, include=True) for ga, gd in G]

    a, (ba, bd), G, hn = prde_normal_denom(fa, fd, G, DE)
    # Solutions q in k<t> of  a*Dq + b*q = Sum(ci*Gi) correspond
    # to solutions z = q/hn of the weakly normalized equation.
    gamma *= hn

    A, B, G, hs = prde_special_denom(a, ba, bd, G, DE)
    # Solutions p in k[t] of  A*Dp + B*p = Sum(ci*Gi) correspond
    # to solutions q = p/hs of the previous equation.
    gamma *= hs

    g = A.gcd(B)
    a, b, g = A.quo(g), B.quo(g), [gia.cancel(gid*g, include=True) for
        gia, gid in G]

    # a*Dp + b*p = Sum(ci*gi)  may have a polynomial solution
    # only if the sum is in k[t].

    q, M = prde_linear_constraints(a, b, g, DE)

    # q = [q1, ..., qm] where qi in k[t] is the polynomial component
    # of the partial fraction expansion of gi.
    # M is a matrix with m columns and entries in k.
    # Sum(fi*gi, (i, 1, m)), where f1, ..., fm are elements of k,
    # is a polynomial if and only if M*Matrix([f1, ..., fm]) == 0,
    # in which case the sum is equal to Sum(fi*qi).

    M, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE)
    # M is a matrix with m columns and entries in Const(k).
    # Sum(ci*gi) is in k[t] for c1, ..., cm in Const(k)
    # if and only if M*Matrix([c1, ..., cm]) == 0,
    # in which case the sum is Sum(ci*qi).

    ## Reduce number of constants at this point

    V = M.nullspace()
    # V = [v1, ..., vu] where each vj is a column matrix with
    # entries aj1, ..., ajm in Const(k).
    # Sum(aji*gi) is in k[t] and equal to Sum(aji*qi) (j = 1, ..., u).
    # Sum(ci*gi) is in k[t] if and only if ci = Sum(dj*aji)
    # (i = 1, ..., m) for some d1, ..., du in Const(k).
    # In that case,
    #     Sum(ci*gi) = Sum(ci*qi) = Sum(dj*Sum(aji*qi)) = Sum(dj*rj)
    # where rj = Sum(aji*qi) (j = 1, ..., u) in k[t].

    if not V:  # No non-trivial solution
        return [], eye(m, DE.t)

    Mq = Matrix([q])  # A single row.
    r = [(Mq*vj)[0] for vj in V]  # [r1, ..., ru]
    u = len(V)

    # Solutions of a*Dp + b*p = Sum(dj*rj) correspond to solutions
    # y = p/gamma of the initial equation with ci = Sum(dj*aji).

    try:
        # bound_degree() is not implemented for every kind of extension.
        n = bound_degree(a, b, r, DE, parametric=True)
    except NotImplementedError:
        # A temporary bound is set. Eventually, it will be removed.
        # prde_spde() terminates whatever n is, but the currently added
        # test case takes a long time even with n=5, and much longer
        # with larger n.
        n = 5

    h, B = param_poly_rischDE(a, b, r, n, DE)

    # h = [h1, ..., hv] in k[t]^v and B is a matrix with u + v
    # columns and entries in Const(k) such that
    # a*Dp + b*p = Sum(dj*rj) has a solution p of degree <= n
    # in k[t] if and only if p = Sum(ek*hk) where e1, ..., ev are in
    # Const(k) and B*Matrix([d1, ..., du, e1, ..., ev]) == 0.
    # The solutions of the original equation for ci = Sum(dj*aji)
    # (i = 1, ..., m) are then y = Sum(ek*hk, (k, 1, v))/gamma.

    ## Build combined relation matrix with m + u + v columns.

    A = -eye(m, DE.t)
    for vj in V:
        A = A.row_join(vj)
    A = A.row_join(zeros(m, len(h), DE.t))
    A = A.col_join(zeros(B.rows, m, DE.t).row_join(B))

    ## Eliminate d1, ..., du.

    W = A.nullspace()

    # W = [w1, ..., wt] where each wl is a column matrix with
    # entries blk (k = 1, ..., m + u + v) in Const(k).
    # The vectors (bl1, ..., blm) generate the space of those
    # constant families (c1, ..., cm) for which a solution of
    # the equation Dy + f*y == Sum(ci*Gi) exists. They generate
    # the space and form a basis except possibly when Dy + f*y == 0
    # is solvable in k(t). The corresponding solutions are
    # y = Sum(blk'*hk, (k, 1, v))/gamma, where k' = k + m + u.

    v = len(h)
    # Excise the dj's: keep the first m entries (the ci) and the last v
    # entries (the ek) of each wl.  Slicing from m + u instead of using
    # wl[-v:] matters when v == 0, where wl[-0:] is the whole vector.
    # The explicit shape keeps m + v columns even when W is empty.
    M_rows = [wl[:m] + wl[m + u:] for wl in W]
    M = Matrix(len(W), m + v, [e for row in M_rows for e in row], DE.t)
    N = M.nullspace()
    # N = [n1, ..., ns] where the ni in Const(k)^(m + v) are column
    # vectors generating the space of linear relations between
    # c1, ..., cm, e1, ..., ev.

    # Rows n1, ..., ns.  The explicit shape keeps m + v columns
    # even when there are no relations (N empty).
    C = Matrix(len(N), m + v, [e for ni in N for e in ni[:]], DE.t)

    return [hk.cancel(gamma, include=True) for hk in h], C
775
776
def limited_integrate_reduce(fa, fd, G, DE):
    """
    Simpler version of step 1 & 2 for the limited integration problem.

    Explanation
    ===========

    Given a derivation D on k(t), f == fa/fd in k(t) and
    G == [(ga1, gd1), ..., (gam, gdm)] describing w1, ..., wm in k(t), return
    (a, b, h, N, g, V) such that a, b, h in k[t], N is a non-negative
    integer, g in k(t), V == [v1, ..., vm] in k(t)^m, and for any solution
    v in k(t), c1, ..., cm in C of f == Dv + Sum(ci*wi, (i, 1, m)), p = v*h
    is in k<t>, and p and the ci satisfy
    a*Dp + b*p == g + Sum(ci*vi, (i, 1, m)).  Furthermore, if S1irr == Sirr,
    then p is in k[t], and if t is nonlinear or Liouvillian over k, then
    deg(p) <= N.

    The returned h is always the same polynomial as a.

    Only the 'base', 'primitive', 'exp' and 'tan' cases (where S1irr == Sirr
    is known to hold) are handled; for any other DE.case this raises
    NotImplementedError.  G must be non-empty.
    """
    t = DE.t

    # Split every denominator into its normal and special parts.
    f_normal, f_special = splitfactor(fd, DE)
    g_normals, g_specials = list(zip(*[splitfactor(gd, DE) for _, gd in G]))

    # hn = gcd(c, dc/dt) with c = lcm(dn, en1, ..., enm).
    c = reduce(lambda acc, poly: acc.lcm(poly), (f_normal,) + g_normals)
    hn = c.gcd(c.diff(t))
    b = -derivation(hn, DE)

    # These are the cases where we know that S1irr = Sirr, but there could be
    # others, and this algorithm will need to be extended to handle them.
    # So far, all of them are also nonlinear or Liouvillian; a linear,
    # non-Liouvillian case ('other_linear') would need bound_degree().
    if DE.case not in ('base', 'primitive', 'exp', 'tan'):
        # TODO: implement this
        raise NotImplementedError

    # hs = lcm(ds, es1, ..., esm) of the special parts.
    hs = reduce(lambda acc, poly: acc.lcm(poly), (f_special,) + g_specials)
    a = hn*hs
    b -= (hn*derivation(hs, DE)).quo(hs)

    # Smallest order at infinity among f and all the wi.
    mu = min(order_at_oo(fa, fd, t),
             min(order_at_oo(ga, gd, t) for ga, gd in G))
    N = hn.degree(t) + hs.degree(t) + max(0, 1 - DE.d.degree(t) - mu)

    V = [(-a*hn*ga).cancel(gd, include=True) for ga, gd in G]
    g = (a*hn*fa).cancel(fd, include=True)
    return (a, b, a, N, g, V)
825
826
def limited_integrate(fa, fd, G, DE):
    """
    Solves the limited integration problem:  f = Dv + Sum(ci*wi, (i, 1, n))

    Explanation
    ===========

    f == fa/fd and G == [(wa1, wd1), ..., (wan, wdn)] are given as
    numerator/denominator Polys in DE.t.  The problem is interpreted as the
    parametric Risch DE Dy + 0*y == Sum(ci*gi) whose right hand side terms
    are [f, w1, ..., wn], and is solved with param_rischDE().

    Returns ((Ya, Yd), [c1, ..., cn]) where v == Ya/Yd with Yd monic, or
    None if no solution exists.
    """
    fa, fd = fa*Poly(1/fd.LC(), DE.t), fd.monic()
    # interpreting limited integration problem as a
    # parametric Risch DE problem
    Fa = Poly(0, DE.t)
    Fd = Poly(1, DE.t)
    G = [(fa, fd)] + G
    h, A = param_rischDE(Fa, Fd, G, DE)

    # Only relations in which f itself takes part (nonzero first entry)
    # solve the limited integration problem.
    V = [v for v in A.nullspace() if v[0] != 0]
    if not V:
        return None

    # we can take any vector from V, we take V[0]
    c0 = V[0][0]
    # v = [-1, c1, ..., cm, d1, ..., dr]
    v = V[0]/(-c0)
    r = len(h)
    m = len(v) - r - 1
    C = list(v[1: m + 1])
    # Start the sum at S.Zero: with an empty h (r == 0) the builtin sum()
    # would return the Python int 0, which has no as_numer_denom().
    y = -sum([v[m + 1 + i]*h[i][0].as_expr()/h[i][1].as_expr()
              for i in range(r)], S.Zero)
    y_num, y_den = y.as_numer_denom()
    Ya, Yd = Poly(y_num, DE.t), Poly(y_den, DE.t)
    Y = Ya*Poly(1/Yd.LC(), DE.t), Yd.monic()
    return Y, C
856
857
def parametric_log_deriv_heu(fa, fd, wa, wd, DE, c1=None):
    """
    Parametric logarithmic derivative heuristic.

    Explanation
    ===========

    Given a derivation D on k[t], f == fa/fd in k(t), and a hyperexponential
    monomial theta over k(t) with w == wa/wd == Dtheta/theta, raises either
    NotImplementedError, in which case the heuristic failed, or returns None,
    in which case it has proven that no solution exists, or returns a solution
    (n, m, v) of the equation n*f == Dv/v + m*Dtheta/theta, with v in k(t)*
    and n, m in ZZ with n != 0.

    If this heuristic fails, the structure theorem approach will need to be
    used.

    The unknown ratio is solved for as the symbol c1; a fresh Dummy is used
    when c1 is not given.
    """
    # TODO: finish writing this and write tests
    c1 = c1 or Dummy('c1')
    t = DE.t

    def _rational_c1(eqs):
        # Solve the (linear in c1) equations; only a rational c1 is useful.
        sol = solve(eqs, c1)
        if not sol or not sol[c1].is_Rational:
            return None
        return sol[c1]

    def _scaled_radical(ratio, gens):
        # Write c1 == M/N and test whether N*f - M*w is the logarithmic
        # derivative of a k(t)-radical, i.e. Q*(N*f - M*w) == Dv/v; then
        # (Q*N)*f == Dv/v + (Q*M)*w.
        M, N = ratio.as_numer_denom()
        M_poly = M.as_poly(gens)
        N_poly = N.as_poly(gens)
        Qv = is_log_deriv_k_t_radical_in_field(N_poly*fa*wd - M_poly*wa*fd,
                                               fd*wd, DE, 'auto')
        if Qv is None:
            # (N*f - M*w) is not the logarithmic derivative of a k(t)-radical.
            return None
        Q, v = Qv
        if Q.is_zero or v.is_zero:
            return None
        return (Q*N, Q*M, v)

    p = fa.div(fd)[0]  # polynomial part of f
    q = wa.div(wd)[0]  # polynomial part of w

    bound = max(0, derivation(t, DE).degree(t) - 1)
    top = max(p.degree(t), q.degree(t))

    if q.degree(t) > bound:
        # Match the coefficients of p and c1*q above the bound.
        ratio = _rational_c1([p.nth(i) - c1*q.nth(i)
                              for i in range(bound + 1, top + 1)])
        if ratio is None:
            # deg(q) > B, no solution for c.
            return None
        return _scaled_radical(ratio, q.gens)

    if p.degree(t) > bound:
        return None

    c = lcm(fd.as_poly(t).LC(), wd.as_poly(t).LC())
    lden = fd.monic().lcm(wd.monic())*Poly(c, t)
    ln, ls = splitfactor(lden, DE)
    z = ls*ln.gcd(ln.diff(t))

    if not z.has(t):
        # TODO: We treat this as 'no solution', until the structure
        # theorem version of parametric_log_deriv is implemented.
        return None

    r1 = (fa*lden.quo(fd)).div(z)[1]  # remainder of (l*f).div(z)
    r2 = (wa*lden.quo(wd)).div(z)[1]  # remainder of (l*w).div(z)

    ratio = _rational_c1([r1.nth(i) - c1*r2.nth(i)
                          for i in range(z.degree(t))])
    if ratio is None:
        # deg(q) <= B, no solution for c.
        return None
    return _scaled_radical(ratio, t)
947
948
def parametric_log_deriv(fa, fd, wa, wd, DE):
    """
    Parametric logarithmic derivative problem.

    Currently this only delegates to parametric_log_deriv_heu() and returns
    its result unchanged (None, or a tuple (n, m, v)).  Any
    NotImplementedError raised by the heuristic propagates to the caller.
    """
    # TODO: Write the full algorithm using the structure theorems and fall
    # back to it when the heuristic raises NotImplementedError.
    return parametric_log_deriv_heu(fa, fd, wa, wd, DE)
959
960
def is_deriv_k(fa, fd, DE):
    r"""
    Checks if Df/f is the derivative of an element of k(t).

    Explanation
    ===========

    a in k(t) is the derivative of an element of k(t) if there exists b in k(t)
    such that a = Db.  Given f == fa/fd, either returns (ans, u, const), such
    that Df/f == Du, or None, which means that Df/f is not the derivative of
    an element of k(t).  ans is a list of tuples such that
    Add(*[i*j for i, j in ans]) == u.  This is useful for seeing exactly
    which elements of k(t) produce u.

    This function uses the structure theorem approach, which says that for any
    f in K, Df/f is the derivative of an element of K if and only if there are
    ri in QQ such that::

            ---               ---       Dt
            \    r  * Dt   +  \    r  *   i      Df
            /     i     i     /     i   ---   =  --.
            ---               ---        t        f
         i in L            i in E         i
               K/C(x)            K/C(x)


    Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is
    transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i
    in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic
    monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i
    is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some
    a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of
    hyperexponential monomials of K over C(x)).  If K is an elementary extension
    over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the
    transcendence degree of K over C(x).  Furthermore, because Const_D(K) ==
    Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and
    deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x)
    and L_K/C(x) are disjoint.

    The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed
    recursively using this same function.  They are read from DE as
    DE.indices('log') and DE.indices('exp'), and the arguments of the
    corresponding monomials are taken from DE.extargs (i.e., if i is in E_K,
    then DE.T[i] == exp(DE.extargs[i]), and if i is in L_K, then
    DE.T[i] == log(DE.extargs[i])).  This is needed to compute the final
    answer u such that Df/f == Du.

    log(f) will be the same as u up to an additive constant.  This is because
    they will both behave the same as monomials. For example, both log(x) and
    log(2*x) == log(x) + log(2) satisfy Dt == 1/x, because log(2) is constant.
    Therefore, the term const is returned.  const is such that
    log(const) + u == log(f), i.e. f == const*exp(u).  It is calculated by
    dividing f by exp(u), rebuilt as a product of powers of the arguments of
    the logarithms and of the exponential monomials.

    Raises NotImplementedError if DE contains extensions the structure
    theorems do not support, or if the coefficients ri are not all rational.

    To handle the case where we are given Df/f, not f, use is_deriv_k_in_field().

    See also
    ========
    is_log_deriv_k_t_radical_in_field, is_log_deriv_k_t_radical

    """
    # Compute Df/f
    dfa, dfd = (fd*derivation(fa, DE) - fa*derivation(fd, DE)), fd*fa
    dfa, dfd = dfa.cancel(dfd, include=True)

    # Our assumption here is that each monomial is recursively transcendental
    if len(DE.exts) != len(DE.D):
        # Compare the level indices of the primitive monomials (not the case
        # names) with the indices of the logarithms.
        primitive_indices = {i for i, case in enumerate(DE.cases)
                             if case == 'primitive'}
        if 'tan' in DE.cases or (primitive_indices - set(DE.indices('log'))):
            raise NotImplementedError("Real version of the structure "
                "theorems with hypertangent support is not yet implemented.")

        # TODO: What should really be done in this case?
        raise NotImplementedError("Nonelementary extensions not supported "
            "in the structure theorems.")

    # Dt_i/t_i for the exponentials and Dt_i for the logarithms.
    E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.indices('exp')]
    L_part = [DE.D[i].as_expr() for i in DE.indices('log')]

    # The expression dfa/dfd might not be polynomial in any of its symbols so we
    # use a Dummy as the generator for PolyMatrix.
    dum = Dummy()
    lhs = Matrix([E_part + L_part], dum)
    rhs = Matrix([dfa.as_expr()/dfd.as_expr()], dum)

    A, u = constant_system(lhs, rhs, DE)

    u = u.to_Matrix()  # Poly to Expr

    if not all(derivation(i, DE, basic=True).is_zero for i in u) or not A:
        # If the elements of u are not all constant
        # Note: See comment in constant_system

        # Also note: derivation(basic=True) calls cancel()
        return None

    if not all(i.is_Rational for i in u):
        raise NotImplementedError("Cannot work with non-rational "
            "coefficients in this case.")

    terms = ([DE.extargs[i] for i in DE.indices('exp')] +
            [DE.T[i] for i in DE.indices('log')])
    ans = list(zip(terms, u))
    result = Add(*[Mul(i, j) for i, j in ans])

    # exp(result) == Mul(*[i**j for i, j in zip(argterms, u)]); build it from
    # the arguments so that const == f/exp(result) can be cancelled.
    argterms = ([DE.T[i] for i in DE.indices('exp')] +
            [DE.extargs[i] for i in DE.indices('log')])
    l = []
    ld = []
    for i, j in zip(argterms, u):
        # We need to get around things like sqrt(x**2) != x
        # and also sqrt(x**2 + 2*x + 1) != x + 1
        # Issue 10798: i need not be a polynomial
        i, d = i.as_numer_denom()
        icoeff, iterms = sqf_list(i)
        l.append(Mul(*([Pow(icoeff, j)] + [Pow(b, e*j) for b, e in iterms])))
        dcoeff, dterms = sqf_list(d)
        ld.append(Mul(*([Pow(dcoeff, j)] + [Pow(b, e*j) for b, e in dterms])))
    const = cancel(fa.as_expr()/fd.as_expr()/Mul(*l)*Mul(*ld))

    return (ans, result, const)
1081
1082
def is_log_deriv_k_t_radical(fa, fd, DE, Df=True):
    r"""
    Checks if Df is the logarithmic derivative of a k(t)-radical.

    Explanation
    ===========

    b in k(t) can be written as the logarithmic derivative of a k(t) radical if
    there exist n in ZZ and u in k(t) with n, u != 0 such that n*b == Du/u.
    Either returns (ans, u, n, const) or None, which means that Df cannot be
    written as the logarithmic derivative of a k(t)-radical.  ans is a list of
    tuples such that Mul(*[i**j for i, j in ans]) == u.  This is useful for
    seeing exactly what elements of k(t) produce u.

    If Df is True (the default), f == fa/fd and its derivative Df is computed
    here; if Df is False, fa/fd is taken to be Df itself.

    This function uses the structure theorem approach, which says that for any
    f in K, Df is the logarithmic derivative of a K-radical if and only if there
    are ri in QQ such that::

            ---               ---       Dt
            \    r  * Dt   +  \    r  *   i
            /     i     i     /     i   ---   =  Df.
            ---               ---        t
         i in L            i in E         i
               K/C(x)            K/C(x)


    Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is
    transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i
    in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic
    monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i
    is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some
    a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of
    hyperexponential monomials of K over C(x)).  If K is an elementary extension
    over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the
    transcendence degree of K over C(x).  Furthermore, because Const_D(K) ==
    Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and
    deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x)
    and L_K/C(x) are disjoint.

    The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed
    recursively using this same function.  They are read from DE as
    DE.indices('log') and DE.indices('exp'), and the arguments of the
    corresponding monomials are taken from DE.extargs (i.e., if i is in L_K,
    then DE.T[i] == log(DE.extargs[i]), and if i is in E_K, then
    DE.T[i] == exp(DE.extargs[i])).  This is needed to compute the final
    answer u such that n*Df == Du/u.

    exp(f) will be the same as u**(1/n) up to a multiplicative constant.  This
    is because they will both behave the same as monomials.  For example, both
    exp(x) and exp(x + 1) == E*exp(x) satisfy Dt == t. Therefore, the term
    const is returned.  const is such that exp(const)*u**(1/n) == exp(f),
    i.e. const == f - log(u)/n, computed by subtracting the arguments of the
    exponentials (and the logarithms) weighted by the ri.  const is computed
    from fa/fd, so it is only meaningful when Df is True.

    Raises NotImplementedError if DE contains extensions the structure
    theorems do not support, or if the coefficients ri are not all rational.

    To handle the case where we are given Df, not f, use
    is_log_deriv_k_t_radical_in_field().

    See also
    ========

    is_log_deriv_k_t_radical_in_field, is_deriv_k

    """
    if Df:
        dfa, dfd = (fd*derivation(fa, DE) - fa*derivation(fd, DE)).cancel(fd**2,
            include=True)
    else:
        dfa, dfd = fa, fd

    # Our assumption here is that each monomial is recursively transcendental
    if len(DE.exts) != len(DE.D):
        # Compare the level indices of the primitive monomials (not the case
        # names) with the indices of the logarithms.
        primitive_indices = {i for i, case in enumerate(DE.cases)
                             if case == 'primitive'}
        if 'tan' in DE.cases or (primitive_indices - set(DE.indices('log'))):
            raise NotImplementedError("Real version of the structure "
                "theorems with hypertangent support is not yet implemented.")

        # TODO: What should really be done in this case?
        raise NotImplementedError("Nonelementary extensions not supported "
            "in the structure theorems.")

    # Dt_i/t_i for the exponentials and Dt_i for the logarithms.
    E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.indices('exp')]
    L_part = [DE.D[i].as_expr() for i in DE.indices('log')]

    # The expression dfa/dfd might not be polynomial in any of its symbols so we
    # use a Dummy as the generator for PolyMatrix.
    dum = Dummy()
    lhs = Matrix([E_part + L_part], dum)
    rhs = Matrix([dfa.as_expr()/dfd.as_expr()], dum)

    A, u = constant_system(lhs, rhs, DE)

    u = u.to_Matrix()  # Poly to Expr

    if not all(derivation(i, DE, basic=True).is_zero for i in u) or not A:
        # If the elements of u are not all constant
        # Note: See comment in constant_system

        # Also note: derivation(basic=True) calls cancel()
        return None

    if not all(i.is_Rational for i in u):
        # TODO: But maybe we can tell if they're not rational, like
        # log(2)/log(3). Also, there should be an option to continue
        # anyway, even if the result might potentially be wrong.
        raise NotImplementedError("Cannot work with non-rational "
            "coefficients in this case.")

    # n is the common denominator of the ri; scaling makes the exponents
    # integers.  The initial value 1 keeps an empty u from breaking reduce(),
    # and S.One* keeps n a SymPy Integer.
    n = S.One*reduce(ilcm, [i.as_numer_denom()[1] for i in u], 1)
    u *= n
    terms = ([DE.T[i] for i in DE.indices('exp')] +
            [DE.extargs[i] for i in DE.indices('log')])
    ans = list(zip(terms, u))
    result = Mul(*[Pow(i, j) for i, j in ans])

    # exp(f) will be the same as result up to a multiplicative
    # constant.  We now find the log of that constant.
    argterms = ([DE.extargs[i] for i in DE.indices('exp')] +
            [DE.T[i] for i in DE.indices('log')])
    const = cancel(fa.as_expr()/fd.as_expr() -
        Add(*[Mul(i, j/n) for i, j in zip(argterms, u)]))

    return (ans, result, n, const)
1204
1205
def is_log_deriv_k_t_radical_in_field(fa, fd, DE, case='auto', z=None):
    """
    Checks if f can be written as the logarithmic derivative of a k(t)-radical.

    Explanation
    ===========

    It differs from is_log_deriv_k_t_radical(fa, fd, DE, Df=False)
    for any given fa, fd, DE in that it finds the solution in the
    given field, not in some (possibly unspecified) extension, and
    "in_field" in the function name is used to indicate that.

    f in k(t) can be written as the logarithmic derivative of a k(t) radical if
    there exist n in ZZ and u in k(t) with n, u != 0 such that n*f == Du/u.
    Either returns (n, u) or None, which means that f cannot be written as the
    logarithmic derivative of a k(t)-radical.  f is given as fa/fd; the
    fraction is cancelled before use.

    case is one of {'primitive', 'exp', 'tan', 'base', 'auto'} for the
    primitive, hyperexponential, hypertangent, and base cases, respectively.
    If case is 'auto', DE.case is used.  The 'tan' case raises
    NotImplementedError; 'other_linear' and 'other_nonlinear' raise
    ValueError.

    z is the symbol used by residue_reduce() for the residues; a fresh Dummy
    is created if it is not given.

    See also
    ========
    is_log_deriv_k_t_radical, is_deriv_k

    """
    fa, fd = fa.cancel(fd, include=True)

    # f must be simple
    # NOTE(review): this check is currently a no-op: the result of
    # splitfactor() is not used, and n is reassigned in every case branch
    # below before it is read.
    n, s = splitfactor(fd, DE)
    if not s.is_one:
        pass

    z = z or Dummy('z')
    H, b = residue_reduce(fa, fd, DE, z=z)
    if not b:
        # I will have to verify, but I believe that the answer should be
        # None in this case. This should never happen for the
        # functions given when solving the parametric logarithmic
        # derivative problem when integrating elementary functions (see
        # Bronstein's book, page 255), so most likely this indicates a bug.
        return None

    roots = [(i, i.real_roots()) for i, _ in H]
    if not all(len(j) == i.degree() and all(k.is_Rational for k in j) for
               i, j in roots):
        # If f is the logarithmic derivative of a k(t)-radical, then all the
        # roots of the resultant must be rational numbers.
        return None

    # [(a, i), ...], where i*log(a) is a term in the log-part of the integral
    # of f (respolys itself is not used below)
    respolys, residues = list(zip(*roots)) or [[], []]
    # Note: this might be empty, but everything below should work fine in that
    # case (it should be the same as if it were [[1, 1]])
    residueterms = [(H[j][1].subs(z, i), i) for j in range(len(H)) for
        i in residues[j]]

    # TODO: finish writing this and write tests

    # Subtract the derivative of the log-part described by H from f.
    p = cancel(fa.as_expr()/fd.as_expr() - residue_reduce_derivation(H, DE, z))

    p = p.as_poly(DE.t)
    if p is None:
        # f - Dg will be in k[t] if f is the logarithmic derivative of a k(t)-radical
        return None

    if p.degree(DE.t) >= max(1, DE.d.degree(DE.t)):
        return None

    if case == 'auto':
        case = DE.case

    if case == 'exp':
        # Solve n*p == Dv/v + e*Dt/t one level down (w == Dt/t), then fold
        # t**e into u.
        wa, wd = derivation(DE.t, DE).cancel(Poly(DE.t, DE.t), include=True)
        with DecrementLevel(DE):
            pa, pd = frac_in(p, DE.t, cancel=True)
            wa, wd = frac_in((wa, wd), DE.t)
            A = parametric_log_deriv(pa, pd, wa, wd, DE)
        if A is None:
            return None
        n, e, u = A
        u *= DE.t**e

    elif case == 'primitive':
        # p must itself be the logarithmic derivative of a radical one
        # level down.
        with DecrementLevel(DE):
            pa, pd = frac_in(p, DE.t)
            A = is_log_deriv_k_t_radical_in_field(pa, pd, DE, case='auto')
        if A is None:
            return None
        n, u = A

    elif case == 'base':
        # TODO: we can use more efficient residue reduction from ratint()
        if not fd.is_sqf or fa.degree() >= fd.degree():
            # f is the logarithmic derivative in the base case if and only if
            # f = fa/fd, fd is square-free, deg(fa) < deg(fd), and
            # gcd(fa, fd) == 1.  The last condition is handled by cancel() above.
            return None
        # Note: if residueterms = [], returns (1, 1)
        # f had better be 0 in that case.
        n = reduce(ilcm, [i.as_numer_denom()[1] for _, i in residueterms], S.One)
        u = Mul(*[Pow(i, j*n) for i, j in residueterms])
        return (n, u)

    elif case == 'tan':
        raise NotImplementedError("The hypertangent case is "
        "not yet implemented for is_log_deriv_k_t_radical_in_field()")

    elif case in ['other_linear', 'other_nonlinear']:
        # XXX: If these are supported by the structure theorems, change to NotImplementedError.
        raise ValueError("The %s case is not supported in this function." % case)

    else:
        raise ValueError("case must be one of {'primitive', 'exp', 'tan', "
        "'base', 'auto'}, not %s" % case)

    # Bring n and all the residues to a common denominator so that every
    # exponent in the final product is an integer.
    common_denom = reduce(ilcm, [i.as_numer_denom()[1] for i in [j for _, j in
        residueterms]] + [n], S.One)
    residueterms = [(i, j*common_denom) for i, j in residueterms]
    m = common_denom//n
    if common_denom != n*m:  # Verify exact division
        raise ValueError("Inexact division")
    u = cancel(u**m*Mul(*[Pow(i, j) for i, j in residueterms]))

    return (common_denom, u)
1332