1"""
2Algorithms for solving Parametric Risch Differential Equations.
3
4The methods used for solving Parametric Risch Differential Equations parallel
5those for solving Risch Differential Equations.  See the outline in the
6docstring of rde.py for more information.
7
8The Parametric Risch Differential Equation problem is, given f, g1, ..., gm in
9K(t), to determine if there exist y in K(t) and c1, ..., cm in Const(K) such
10that Dy + f*y == Sum(ci*gi, (i, 1, m)), and to find such y and ci if they exist.
11
For the algorithms here G is a list of tuples of fractions of the terms on the
13right hand side of the equation (i.e., gi in k(t)), and Q is a list of terms on
14the right hand side of the equation (i.e., qi in k[t]).  See the docstring of
15each function for more information.
16"""
17
18import functools
19import math
20
21from ..core import Add, Dummy, Integer, Mul, Pow
22from ..matrices import Matrix, eye, zeros
23from ..polys import Poly, cancel, lcm, sqf_list
24from ..solvers import solve
25from .rde import order_at, order_at_oo, solve_poly_rde, spde
26from .risch import (DecrementLevel, NonElementaryIntegralException, derivation,
27                    frac_in, gcdex_diophantine, recognize_log_derivative,
28                    residue_reduce, residue_reduce_derivation, splitfactor)
29
30
def prde_normal_denom(fa, fd, G, DE):
    """
    Parametric Risch Differential Equation - Normal part of the denominator.

    Given a derivation D on k[t] and f, g1, ..., gm in k(t) with f weakly
    normalized with respect to t, return the tuple (a, b, G, h) such that
    a, h in k[t], b in k<t>, G = [g1, ..., gm] in k(t)^m, and for any solution
    c1, ..., cm in Const(k) and y in k(t) of Dy + f*y == Sum(ci*gi, (i, 1, m)),
    q == y*h in k<t> satisfies a*Dq + b*q == Sum(ci*Gi, (i, 1, m)).

    f is passed as the pair (fa, fd), G as a list of (numerator, denominator)
    pairs, and b is returned as a (numerator, denominator) pair as well.
    """
    t = DE.t
    f_normal, _ = splitfactor(fd, DE)

    # Least common multiple of every denominator on the right hand side.
    _, rhs_denoms = zip(*G)
    common_denom = Poly(1, t)
    for denom in rhs_denoms:
        common_denom = common_denom.lcm(denom)
    g_normal, _ = splitfactor(common_denom, DE)

    shared = f_normal.gcd(g_normal)
    h = g_normal.gcd(g_normal.diff(t)).quo(shared.gcd(shared.diff(t)))

    a = f_normal*h
    scale = a*h

    b_num = a*fa - f_normal*derivation(h, DE)*fd
    b_num, b_den = b_num.cancel(fd, include=True)

    scaled = [(scale*num).cancel(den, include=True) for num, den in G]

    return a, (b_num, b_den), scaled, h
58
59
def real_imag(ba, bd, gen):
    """
    Split the rational function ba/bd, evaluated at sqrt(-1), into its real
    and imaginary parts without ever substituting sqrt(-1).

    The powers of sqrt(-1) cycle through 1, i, -1, -i, so every term is
    assigned to the real or the imaginary part (with the right sign) from its
    degree in ``gen`` modulo 4.  Returns the tuple (re, im, den) where re and
    im are the real and imaginary parts of the numerator and den is the common
    (real) denominator, all as polynomials in ``gen``.
    """
    den_terms = bd.as_poly(gen).as_dict()
    num_terms = ba.as_poly(gen).as_dict()

    def at_sqrt_minus_one(terms):
        # Collect the contribution of every monomial to the real and to the
        # imaginary part; a term contributes 0 to the part it is not in.
        real_terms, imag_terms = [], []
        for monom, coeff in terms.items():
            residue = monom[0] % 4
            if residue == 0:
                real_terms.append(coeff)
                imag_terms.append(0)
            elif residue == 1:
                real_terms.append(0)
                imag_terms.append(coeff)
            elif residue == 2:
                real_terms.append(-coeff)
                imag_terms.append(0)
            else:
                real_terms.append(0)
                imag_terms.append(-coeff)
        return sum(real_terms), sum(imag_terms)

    den_re, den_im = at_sqrt_minus_one(den_terms)
    num_re, num_im = at_sqrt_minus_one(num_terms)

    # (num_re + i*num_im)/(den_re + i*den_im), multiplied through by the
    # conjugate of the denominator.
    real_part = (num_re*den_re + num_im*den_im).as_poly(gen)
    imag_part = (num_im*den_re - num_re*den_im).as_poly(gen)
    norm = (den_re*den_re + den_im*den_im).as_poly(gen)
    return real_part, imag_part, norm
83
84
def prde_special_denom(a, ba, bd, G, DE, case='auto'):
    """
    Parametric Risch Differential Equation - Special part of the denominator.

    case is one of {'exp', 'tan', 'primitive', 'base'} for the
    hyperexponential, hypertangent, and primitive/base cases, respectively
    ('auto' uses DE.case).  For the hyperexponential (resp. hypertangent)
    case, given a derivation D on k[t] and a in k[t], b in k<t>, and
    g1, ..., gm in k(t) with Dt/t in k (resp. Dt/(t**2 + 1) in k, sqrt(-1) not
    in k), a != 0, and gcd(a, t) == 1 (resp. gcd(a, t**2 + 1) == 1), return
    the tuple (A, B, GG, h) such that A, B, h in k[t], GG = [gg1, ..., ggm] in
    k(t)^m, and for any solution c1, ..., cm in Const(k) and q in k<t> of
    a*Dq + b*q == Sum(ci*gi, (i, 1, m)), r == q*h in k[t] satisfies
    A*Dr + B*r == Sum(ci*ggi, (i, 1, m)).

    b is passed as the pair (ba, bd) and G as a list of
    (numerator, denominator) pairs.

    For case == 'primitive' (or 'base'), k<t> == k[t], so it returns
    (a, b, G, 1) in this case.

    Raises ValueError if case is not one of the recognised values.
    """
    # TODO: Merge this with the very similar special_denom() in rde.py
    if case == 'auto':
        case = DE.case

    if case == 'exp':
        p = Poly(DE.t, DE.t)
    elif case == 'tan':
        p = Poly(DE.t**2 + 1, DE.t)
    elif case in ['primitive', 'base']:
        B = ba.quo(bd)
        return a, B, G, Poly(1, DE.t)
    else:
        raise ValueError("case must be one of {'exp', 'tan', 'primitive', "
                         f"'base'}}, not {case}.")

    nb = order_at(ba, p, DE.t) - order_at(bd, p, DE.t)
    nc = min(order_at(Ga, p, DE.t) - order_at(Gd, p, DE.t) for Ga, Gd in G)
    n = min(0, nc - min(0, nb))
    if not nb:
        # Possible cancellation.
        #
        # NOTE: the results of parametric_log_deriv() are unpacked into
        # fresh names.  They must not be stored in ``a``: the parameter ``a``
        # is still needed below to build A and B.
        if case == 'exp':
            dcoeff = DE.d.quo(Poly(DE.t, DE.t))
            # We are guaranteed to not have problems,
            # because case != 'base'.
            with DecrementLevel(DE):
                alphaa, alphad = frac_in(-ba.eval(0)/bd.eval(0)/a.eval(0), DE.t)
                etaa, etad = frac_in(dcoeff, DE.t)
                sol = parametric_log_deriv(alphaa, alphad, etaa, etad, DE)
                if sol is not None:
                    Q, m, _ = sol
                    if Q == 1:
                        n = min(n, m)

        elif case == 'tan':
            dcoeff = DE.d.quo(Poly(DE.t**2 + 1, DE.t))
            # We are guaranteed to not have problems,
            # because case != 'base'.
            with DecrementLevel(DE):
                betaa, alphaa, alphad = real_imag(ba, bd*a, DE.t)
                betad = alphad
                etaa, etad = frac_in(dcoeff, DE.t)
                if recognize_log_derivative(2*betaa, betad, DE):
                    sol_alpha = parametric_log_deriv(alphaa, alphad, etaa, etad, DE)
                    sol_beta = parametric_log_deriv(betaa, betad, etaa, etad, DE)
                    if sol_alpha is not None and sol_beta is not None:
                        Q, s, _ = sol_alpha
                        if Q == 1:
                            n = min(n, s/2)

    N = max(0, -nb)
    pN = p**N
    pn = p**-n  # This is 1/h

    A = a*pN
    B = ba*pN.quo(bd) + Poly(n, DE.t)*a*derivation(p, DE).quo(p)*pN
    G = [(Ga*pN*pn).cancel(Gd, include=True) for Ga, Gd in G]
    h = pn

    # (a*p**N, (b + n*a*Dp/p)*p**N, g1*p**(N - n), ..., gm*p**(N - n), p**-n)
    return A, B, G, h
162
163
def prde_linear_constraints(a, b, G, DE):
    """
    Parametric Risch Differential Equation - Generate linear constraints on the constants.

    Given a derivation D on k[t], a, b, in k[t] with gcd(a, b) == 1, and
    G = [g1, ..., gm] in k(t)^m, return Q = [q1, ..., qm] in k[t]^m and a
    matrix M with entries in k(t) such that for any solution c1, ..., cm in
    Const(k) and p in k[t] of a*Dp + b*p == Sum(ci*gi, (i, 1, m)),
    (c1, ..., cm) is a solution of Mx == 0, and p and the ci satisfy
    a*Dp + b*p == Sum(ci*qi, (i, 1, m)).

    Each gi is split over the common denominator d of G into a polynomial
    part qi and a remainder ri; row i of M holds the coefficients of t**i of
    the remainders.  When every remainder vanishes the empty Matrix is
    returned.

    Because M has entries in k(t), and because Matrix doesn't play well with
    Poly, M will be a Matrix of Basic expressions.
    """
    _, denoms = zip(*G)

    # Common denominator of all the gi, over a field so division is exact.
    common = denoms[0]
    for den in denoms[1:]:
        common = common.lcm(den)
    common = Poly(common, field=True)

    divisions = [(num*common.quo(den)).div(common) for num, den in G]
    quotients, remainders = zip(*divisions)

    if all(rem.is_zero for rem in remainders):
        # No constraints, return the empty matrix.
        M = Matrix()
    else:
        top = max(rem.degree(DE.t) for rem in remainders)
        M = Matrix(top + 1, len(G),
                   lambda i, j: remainders[j].coeff_monomial((i,)))

    return quotients, M
193
194
def constant_system(A, u, DE):
    """
    Generate a system for the constant solutions.

    Given a differential field (K, D) with constant field C = Const(K), a Matrix
    A, and a vector (Matrix) u with coefficients in K, returns the tuple
    (B, v), where B is a Matrix and v is a vector (Matrix) obtained from the
    augmented system [A | u] by row reduction followed by elimination of the
    entries of A that still depend on the monomials in DE.T.  The intent is
    that either v has coefficients in C, in which case the solutions in C of
    Ax == u are exactly all the solutions of Bx == v, or v has a non-constant
    coefficient, in which case Ax == u has no constant solution.

    Note that no flag telling the two situations apart is returned; callers
    (e.g., is_deriv_k()) check for themselves whether the entries of v are
    constant.

    If A is empty (falsy), A and u are returned unchanged.

    This algorithm is used both in solving parametric problems and in
    determining if an element a of K is a derivative of an element of K or the
    logarithmic derivative of a K-radical using the structure theorem approach.

    Because Poly does not play well with Matrix yet, this algorithm assumes that
    all matrix entries are Basic expressions.
    """
    if not A:
        # Nothing to reduce.
        return A, u
    # Row-reduce the augmented matrix [A | u].
    Au = A.row_join(u)
    Au = Au.rref(simplify=cancel)[0]
    # Warning: This will NOT return correct results if cancel() cannot reduce
    # an identically zero expression to 0.  The danger is that we might
    # incorrectly prove that an integral is nonelementary (such as
    # risch_integrate(exp((sin(x)**2 + cos(x)**2 - 1)*x**2), x).
    # But this is a limitation in computer algebra in general, and implicit
    # in the correctness of the Risch Algorithm is the computability of the
    # constant field (actually, this same correctness problem exists in any
    # algorithm that uses rref()).
    #
    # We therefore limit ourselves to constant fields that are computable
    # via the cancel() function, in order to prevent a speed bottleneck from
    # calling some more complex simplification function (rational function
    # coefficients will fall into this class).  Furthermore, (I believe) this
    # problem will only crop up if the integral explicitly contains an
    # expression in the constant field that is identically zero, but cannot
    # be reduced to such by cancel().  Therefore, a careful user can avoid this
    # problem entirely by being careful with the sorts of expressions that
    # appear in his integrand in the variables other than the integration
    # variable (the structure theorems should be able to completely decide these
    # problems in the integration variable).

    Au = Au.applyfunc(cancel)
    # Split the reduced augmented matrix back into A and u.
    A, u = Au[:, :-1], Au[:, -1]

    for j in range(A.cols):
        # range(A.rows) is evaluated when the inner loop starts, so rows
        # appended below are only visited for later columns.
        for i in range(A.rows):
            if A[i, j].has(*DE.T):
                # The entry still depends on one of the monomials in DE.T.
                # This assumes that const(F(t0, ..., tn) == const(K) == F
                Ri = A[i, :]
                # Rm+1; m = A.rows
                # New row: D(row i) divided by D(A[i, j]).
                Rm1 = Ri.applyfunc(lambda x: derivation(x, DE, basic=True) /
                                   derivation(A[i, j], DE, basic=True))
                Rm1 = Rm1.applyfunc(cancel)
                # Matching new entry of the right hand side.
                um1 = cancel(derivation(u[i], DE, basic=True) /
                             derivation(A[i, j], DE, basic=True))

                for s in range(A.rows):
                    # A[s, :] = A[s, :] - A[s, j]*Rm1
                    # Asj is read before row_op() changes the row; the lambdas
                    # are consumed immediately by row_op().
                    Asj = A[s, j]
                    A.row_op(s, lambda r, jj: cancel(r - Asj*Rm1[jj]))
                    # u[s] = u[s] - A[s, j]*um1
                    u.row_op(s, lambda r, jj: cancel(r - Asj*um1))

                # Append the new row to the system.
                A = A.col_join(Rm1)
                u = u.col_join(Matrix([um1]))

    return A, u
265
266
def prde_spde(a, b, Q, n, DE):
    """
    Special Polynomial Differential Equation algorithm: Parametric Version.

    Given a derivation D on k[t], an integer n, and a, b, q1, ..., qm in k[t]
    with deg(a) > 0 and gcd(a, b) == 1, return (A, B, Qq, R, n1), with
    Qq = [qq1, ..., qqm] and R = [r1, ..., rm], such that for any solution
    c1, ..., cm in Const(k) and q in k[t] of degree at most n of
    a*Dq + b*q == Sum(ci*qi, (i, 1, m)), p = (q - Sum(ci*ri, (i, 1, m)))/a has
    degree at most n1 and satisfies A*Dp + B*p == Sum(ci*qqi, (i, 1, m)).

    Here (ri, zi) is the result of gcdex_diophantine(b, a, qi), A == a,
    B == b + Da, qqi == zi - Dri and n1 == n - deg(a).
    """
    solutions = [gcdex_diophantine(b, a, qi) for qi in Q]
    remainders, cofactors = zip(*solutions)

    new_b = b + derivation(a, DE)
    new_Q = [z - derivation(r, DE) for r, z in zip(remainders, cofactors)]
    new_bound = n - a.degree(DE.t)

    return a, new_b, new_Q, list(remainders), new_bound
287
288
def prde_no_cancel_b_large(b, Q, n, DE):
    """
    Parametric Poly Risch Differential Equation - No cancellation: deg(b) large enough.

    Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with
    b != 0 and either D == d/dt or deg(b) > max(0, deg(D) - 1), returns
    h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that
    if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and
    Dq + b*q == Sum(ci*qi, (i, 1, m)), then q = Sum(dj*hj, (j, 1, r)), where
    d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0.

    Note that the list Q is updated in place: on return it holds what is left
    of each qi after the terms found here have been subtracted.
    """
    t = DE.t
    shift = b.degree(t)
    lead = b.LC()
    count = len(Q)
    H = [Poly(0, t)]*count

    # Peel off the terms of each solution from the top degree downwards.
    for power in range(n, -1, -1):  # [n, ..., 0]
        for idx in range(count):
            coeff = Q[idx].coeff_monomial((power + shift,))/lead
            term = Poly(coeff*t**power, t)
            H[idx] = H[idx] + term
            Q[idx] = Q[idx] - derivation(term, DE) - b*term

    # Whatever is left of the qi must cancel for constant ci.
    if any(not rest.is_zero for rest in Q):
        top = max(rest.degree(t) for rest in Q)
        M = Matrix(top + 1, count, lambda i, j: Q[j].coeff_monomial((i,)))
    else:
        top = -1
        M = zeros(0, 2)

    A, _ = constant_system(M, zeros(top + 1, 1), DE)
    identity = eye(count)
    # Append the block [I | -I] below [A | 0].
    A = A.row_join(zeros(A.rows, count)).col_join(identity.row_join(-identity))

    return H, A
322
323
def prde_no_cancel_b_small(b, Q, n, DE):
    """
    Parametric Poly Risch Differential Equation - No cancellation: deg(b) small enough.

    Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with
    deg(b) < deg(D) - 1 and either D == d/dt or deg(D) >= 2, returns
    h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that
    if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and
    Dq + b*q == Sum(ci*qi, (i, 1, m)) then q = Sum(dj*hj, (j, 1, r)) where
    d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0.

    Only the case deg(b) > 0 is implemented; otherwise NotImplementedError is
    raised.  Note that the list Q is updated in place (also when
    NotImplementedError is raised).
    """
    t = DE.t
    count = len(Q)
    d_degree = DE.d.degree(t)
    d_lead = DE.d.LC()
    H = [Poly(0, t)]*count

    # Peel off the terms of degree n, ..., 1 of each solution.
    for power in range(n, 0, -1):  # [n, ..., 1]
        for idx in range(count):
            coeff = Q[idx].coeff_monomial((power + d_degree - 1,))/(power*d_lead)
            term = Poly(coeff*t**power, t)
            H[idx] = H[idx] + term
            Q[idx] = Q[idx] - derivation(term, DE) - b*term

    b_degree = b.degree(t)
    if b_degree <= 0:
        # TODO: implement this (requires recursive param_rischDE() call)
        raise NotImplementedError

    # Constant terms of the solutions.
    b_lead = b.LC()
    for idx in range(count):
        const = Poly(Q[idx].coeff_monomial((b_degree,))/b_lead, t)
        H[idx] = H[idx] + const
        Q[idx] = Q[idx] - derivation(const, DE) - b*const

    # Whatever is left of the qi must cancel for constant ci.
    if any(not rest.is_zero for rest in Q):
        top = max(rest.degree(t) for rest in Q)
        M = Matrix(top + 1, count, lambda i, j: Q[j].coeff_monomial((i,)))
    else:
        top = -1
        M = Matrix()

    A, _ = constant_system(M, zeros(top + 1, 1), DE)
    identity = eye(count)
    # Append the block [I | -I] below [A | 0].
    A = A.row_join(zeros(A.rows, count)).col_join(identity.row_join(-identity))
    return H, A
363
364
def limited_integrate_reduce(fa, fd, G, DE):
    """
    Simpler version of step 1 & 2 for the limited integration problem.

    Given a derivation D on k(t) and f, g1, ..., gn in k(t), return
    (a, b, h, N, g, V) such that a, b, h in k[t], N is a non-negative integer,
    g in k(t), V == [v1, ..., vm] in k(t)^m, and for any solution v in k(t),
    c1, ..., cm in C of f == Dv + Sum(ci*wi, (i, 1, m)), p = v*h is in k<t>, and
    p and the ci satisfy a*Dp + b*p == g + Sum(ci*vi, (i, 1, m)).  Furthermore,
    if S1irr == Sirr, then p is in k[t], and if t is nonlinear or Liouvillian
    over k, then deg(p) <= N.

    The returned h is the same polynomial as the returned a.

    So that the special part is always computed, this function calls the more
    general prde_special_denom() automatically if it cannot determine that
    S1irr == Sirr.  Furthermore, it will automatically call bound_degree() when
    t is linear and non-Liouvillian, which for the transcendental case, implies
    that Dt == a*t + b with for some a, b in k*.

    Currently only DE.case in {'base', 'primitive', 'exp', 'tan'} is handled;
    any other case raises NotImplementedError.
    """
    t = DE.t

    def lcm_all(polys):
        # Least common multiple of a non-empty sequence of polynomials.
        return functools.reduce(lambda acc, poly: acc.lcm(poly), polys)

    f_normal, f_special = splitfactor(fd, DE)
    g_normal, g_special = zip(*[splitfactor(gd, DE) for _, gd in G])

    c = lcm_all((f_normal,) + g_normal)  # lcm(dn, en1, ..., enm)
    hn = c.gcd(c.diff(t))
    dhn = derivation(hn, DE)

    # These are the cases where we know that S1irr = Sirr, but there could be
    # others, and this algorithm will need to be extended to handle them.
    if DE.case not in ['base', 'primitive', 'exp', 'tan']:
        # TODO: implement this
        raise NotImplementedError

    hs = lcm_all((f_special,) + g_special)  # lcm(ds, es1, ..., esm)
    a = hn*hs
    b = -dhn - (hn*derivation(hs, DE)).quo(hs)
    mu = min([order_at_oo(fa, fd, t)] +
             [order_at_oo(ga, gd, t) for ga, gd in G])
    # So far, all the above are also nonlinear or Liouvillian, but if this
    # changes, then this will need to be updated to call bound_degree()
    # as per the docstring of this function (DE.case == 'other_linear').
    N = hn.degree(t) + hs.degree(t) + max(0, 1 - DE.d.degree(t) - mu)

    scale = a*hn
    V = [(-scale*ga).cancel(gd, include=True) for ga, gd in G]
    g = (scale*fa).cancel(fd, include=True)
    return a, b, a, N, g, V
410
411
def limited_integrate(fa, fd, G, DE):
    """
    Solves the limited integration problem:  f = Dv + Sum(ci*wi, (i, 1, n))

    f is given as the pair (fa, fd) and the wi as the list G of
    (numerator, denominator) pairs.  Returns ((va, vd), [c1, ..., cn]).

    Raises NonElementaryIntegralException when the constraints on the
    constants rule out a solution, and NotImplementedError when the general
    param_rischDE() algorithm would be needed.
    """
    t = DE.t
    # Normalise f so that its denominator is monic.
    fa, fd = fa*Poly(1/fd.LC(), t), fd.monic()

    A, B, h, N, g, V = limited_integrate_reduce(fa, fd, G, DE)
    rhs = [g] + V

    # Make A and B coprime, moving the common factor into the right hand side.
    common = A.gcd(B)
    A = A.quo(common)
    B = B.quo(common)
    rhs = [num.cancel(den*common, include=True) for num, den in rhs]

    Q, M = prde_linear_constraints(A, B, rhs, DE)
    M, _ = constant_system(M, zeros(M.rows, 1), DE)
    kernel = M.nullspace()

    if M == Matrix() or len(kernel) > 1:
        # Continue with param_rischDE()
        raise NotImplementedError('param_rischDE() is required to solve this '
                                  'integral.')
    if not kernel:
        raise NonElementaryIntegralException

    # Exactly one basis vector.  Then c1 == 1 (after scaling), so we can
    # assume a normal Risch DE.
    vec = kernel[0]
    if vec[0].is_zero:
        raise NonElementaryIntegralException

    vec *= 1/vec[0]
    C = sum(Poly(coeff, t)*q for coeff, q in zip(vec, Q))
    # Custom version of rischDE() that uses the already computed
    # denominator and degree bound from above.
    B, C, m, alpha, beta = spde(A, B, C, N, DE)
    y = solve_poly_rde(B, C, m, DE)

    return (alpha*y + beta, h), list(vec[1:])
446
447
def parametric_log_deriv_heu(fa, fd, wa, wd, DE, c1=None):
    """
    Parametric logarithmic derivative heuristic.

    Given a derivation D on k[t], f in k(t), and a hyperexponential monomial
    theta over k(t), raises either NotImplementedError, in which case the
    heuristic failed, or returns None, in which case it has proven that no
    solution exists, or returns a solution (n, m, v) of the equation
    n*f == Dv/v + m*Dtheta/theta, with v in k(t)* and n, m in ZZ with n != 0.

    If this heuristic fails, the structure theorem approach will need to be
    used.

    f is given as the pair (fa, fd) and the argument w == Dtheta/theta as the
    pair (wa, wd).  c1 is the symbol used for the unknown ratio m/n; a fresh
    Dummy is used if it is not given.
    """
    # TODO: finish writing this and write tests
    c1 = c1 or Dummy('c1')

    # Polynomial parts of f and w.
    p, _ = fa.div(fd)
    q, _ = wa.div(wd)

    B = max(0, derivation(DE.t, DE).degree(DE.t) - 1)
    C = max(p.degree(DE.t), q.degree(DE.t))

    if q.degree(DE.t) > B:
        # The terms of degree > B of n*p - m*q must vanish, i.e., p == c1*q
        # there with c1 == m/n.
        eqs = [p.coeff_monomial((i,)) - c1*q.coeff_monomial((i,)) for i in range(B + 1, C + 1)]
        s = solve(eqs, c1)
        if not s or not s[0][c1].is_Rational:
            # deg(q) > B, no solution for c.
            return

        # c1 == M/N, so M is the numerator and N the denominator (exactly as
        # in the deg(q) <= B branch below).
        M, N = s[0][c1].as_numer_denom()

        nfmwa = N.as_poly(DE.t)*fa*wd - M.as_poly(DE.t)*wa*fd
        nfmwd = fd*wd
        Qv = is_log_deriv_k_t_radical_in_field(nfmwa, nfmwd, DE, 'auto')
        if Qv is None:
            # (N*f - M*w) is not the logarithmic derivative of a k(t)-radical.
            return

        # NOTE(review): is_log_deriv_k_t_radical_in_field() is assumed to
        # return a pair (Q, v) with Q*(N*f - M*w) == Dv/v, as it is unpacked
        # in the branch below - confirm against its definition.
        Q, v = Qv

        if Q.is_zero or v.is_zero:
            return

        return Q*N, Q*M, v

    if p.degree(DE.t) > B:
        return

    c = lcm(fd.as_poly(DE.t).LC(), wd.as_poly(DE.t).LC())
    l = fd.monic().lcm(wd.monic())*Poly(c, DE.t)
    ln, ls = splitfactor(l, DE)
    z = ls*ln.gcd(ln.diff(DE.t))

    if not z.has(DE.t):
        raise NotImplementedError('parametric_log_deriv_heu() '
                                  'heuristic failed: z in k.')

    _, r1 = (fa*l.quo(fd)).div(z)  # (l*f).div(z)
    _, r2 = (wa*l.quo(wd)).div(z)  # (l*w).div(z)

    eqs = [r1.coeff_monomial((i,)) - c1*r2.coeff_monomial((i,)) for i in range(z.degree(DE.t))]
    s = solve(eqs, c1)
    if not s or not s[0][c1].is_Rational:
        # deg(q) <= B, no solution for c.
        return

    M, N = s[0][c1].as_numer_denom()

    nfmwa = N.as_poly(DE.t)*fa*wd - M.as_poly(DE.t)*wa*fd
    nfmwd = fd*wd
    Qv = is_log_deriv_k_t_radical_in_field(nfmwa, nfmwd, DE)
    if Qv is None:
        # (N*f - M*w) is not the logarithmic derivative of a k(t)-radical.
        return

    Q, v = Qv

    if Q.is_zero or v.is_zero:
        return

    return Q*N, Q*M, v
535
536
def parametric_log_deriv(fa, fd, wa, wd, DE):
    """
    Solve the parametric logarithmic derivative problem.

    Currently this only delegates to parametric_log_deriv_heu() and returns
    its result unchanged; a NotImplementedError raised by the heuristic is
    propagated to the caller.
    """
    # TODO: Write the full algorithm using the structure theorems, to be used
    # when the heuristic fails with NotImplementedError.  It could be
    # implemented more efficiently, but that isn't too worrisome, because the
    # heuristic handles most difficult cases.
    return parametric_log_deriv_heu(fa, fd, wa, wd, DE)
546
547
def is_deriv_k(fa, fd, DE):
    r"""
    Checks if Df/f is the derivative of an element of k(t).

    a in k(t) is the derivative of an element of k(t) if there exists b in k(t)
    such that a = Db.  Either returns (ans, u, const), such that Df/f == Du, or
    None, which means that Df/f is not the derivative of an element of k(t).
    ans is a list of tuples such that Add(*[i*j for i, j in ans]) == u.  This is
    useful for seeing exactly which elements of k(t) produce u.  const is
    described below.

    This function uses the structure theorem approach, which says that for any
    f in K, Df/f is the derivative of a element of K if and only if there are ri
    in QQ such that::

            ---               ---       Dt
            \    r  * Dt   +  \    r  *   i      Df
            /     i     i     /     i   ---   =  --.
            ---               ---        t        f
         i in L            i in E         i
               K/C(x)            K/C(x)


    Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is
    transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i
    in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic
    monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i
    is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some
    a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of
    hyperexponential monomials of K over C(x)).  If K is an elementary extension
    over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the
    transcendence degree of K over C(x).  Furthermore, because Const_D(K) ==
    Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and
    deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x)
    and L_K/C(x) are disjoint.

    The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed
    recursively using this same function.  They are read from DE.L_K and
    DE.E_K, as indices to DE.D (or DE.T).  DE.E_args are the arguments of the
    hyperexponentials indexed by E_K (i.e., if i is in E_K, then T[i] ==
    exp(E_args[i])).  This is needed to compute the final answer u such that
    Df/f == Du.

    log(f) will be the same as u up to a additive constant.  This is because
    they will both behave the same as monomials. For example, both log(x) and
    log(2*x) == log(x) + log(2) satisfy Dt == 1/x, because log(2) is constant.
    Therefore, the term const is returned.  const is f divided by the product
    of the arguments of the terms of u raised to their coefficients, so that
    log(f) and log(const) + u agree.  This is calculated by dividing the
    arguments of one logarithm from the other.  Therefore, the arguments of
    the logarithmic terms are needed; they are read from DE.L_args.

    Raises NotImplementedError if the extension is not built only from the
    monomials in L_K and E_K, or if the coefficients found are not rational.

    To handle the case where we are given Df/f, not f, use is_deriv_k_in_field().
    """
    # Compute Df/f
    dfa, dfd = fd*(fd*derivation(fa, DE) - fa*derivation(fd, DE)), fd**2*fa
    dfa, dfd = dfa.cancel(dfd, include=True)

    # Our assumption here is that each monomial is recursively transcendental
    if len(DE.L_K) + len(DE.E_K) != len(DE.D) - 1:
        if [i for i in DE.cases if i == 'tan'] or \
                {i for i in DE.cases if i == 'primitive'} - set(DE.L_K):
            raise NotImplementedError('Real version of the structure '
                                      'theorems with hypertangent support is not yet implemented.')

        # TODO: What should really be done in this case?
        raise NotImplementedError('Nonelementary extensions not supported '
                                  'in the structure theorems.')

    # Left hand side of the structure theorem equation: Dt_i/t_i for the
    # hyperexponential monomials followed by Dt_i for the logarithmic ones.
    E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.E_K]
    L_part = [DE.D[i].as_expr() for i in DE.L_K]

    lhs = Matrix([E_part + L_part])
    rhs = Matrix([dfa.as_expr()/dfd.as_expr()])

    A, u = constant_system(lhs, rhs, DE)

    # If the condition below does not hold, the function falls off the end
    # and returns None.
    if all(derivation(i, DE, basic=True).is_zero for i in u) and A:
        # If the elements of u are all constant
        # Note: See comment in constant_system

        # Also note: derivation(basic=True) calls cancel()
        if not all(i.is_Rational for i in u):
            raise NotImplementedError('Cannot work with non-rational '
                                      'coefficients in this case.')
        else:
            # Terms of the answer, in the same order (E_K first, then L_K)
            # as the columns of lhs above.
            terms = DE.E_args + [DE.T[i] for i in DE.L_K]
            ans = list(zip(terms, u))
            result = Add(*[Mul(i, j) for i, j in ans])
            # Arguments whose powers multiply to f up to a constant factor,
            # again in the order E_K first, then L_K.
            argterms = [DE.T[i] for i in DE.E_K] + DE.L_args
            # l collects the powers of the numerators, ld those of the
            # denominators.
            l, ld = [], []
            for i, j in zip(argterms, u):
                # We need to get around things like sqrt(x**2) != x
                # and also sqrt(x**2 + 2*x + 1) != x + 1
                i, d = i.as_numer_denom()
                icoeff, iterms = sqf_list(i)
                l.append(Mul(*([Pow(icoeff, j)] + [Pow(b, e*j) for b, e in iterms])))
                dcoeff, dterms = sqf_list(d)
                ld.append(Mul(*([Pow(dcoeff, j)] + [Pow(b, e*j) for b, e in dterms])))
            const = cancel(fa.as_expr()/fd.as_expr()/Mul(*l)*Mul(*ld))

            return ans, result, const
648
649
650def is_log_deriv_k_t_radical(fa, fd, DE, Df=True):
651    r"""
652    Checks if Df is the logarithmic derivative of a k(t)-radical.
653
654    b in k(t) can be written as the logarithmic derivative of a k(t) radical if
655    there exist n in ZZ and u in k(t) with n, u != 0 such that n*b == Du/u.
656    Either returns (ans, u, n, const) or None, which means that Df cannot be
657    written as the logarithmic derivative of a k(t)-radical.  ans is a list of
658    tuples such that Mul(*[i**j for i, j in ans]) == u.  This is useful for
659    seeing exactly what elements of k(t) produce u.
660
661    This function uses the structure theorem approach, which says that for any
662    f in K, Df is the logarithmic derivative of a K-radical if and only if there
663    are ri in QQ such that::
664
665            ---               ---       Dt
666            \    r  * Dt   +  \    r  *   i
667            /     i     i     /     i   ---   =  Df.
668            ---               ---        t
669         i in L            i in E         i
670               K/C(x)            K/C(x)
671
672
673    Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is
674    transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i
675    in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic
676    monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i
677    is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some
678    a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of
679    hyperexponential monomials of K over C(x)).  If K is an elementary extension
680    over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the
681    transcendence degree of K over C(x).  Furthermore, because Const_D(K) ==
682    Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and
683    deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x)
684    and L_K/C(x) are disjoint.
685
686    The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed
687    recursively using this same function.  Therefore, it is required to pass
688    them as indices to D (or T).  L_args are the arguments of the logarithms
689    indexed by L_K (i.e., if i is in L_K, then T[i] == log(L_args[i])).  This is
690    needed to compute the final answer u such that n*f == Du/u.
691
692    exp(f) will be the same as u up to a multiplicative constant.  This is
693    because they will both behave the same as monomials.  For example, both
694    exp(x) and exp(x + 1) == E*exp(x) satisfy Dt == t. Therefore, the term const
695    is returned.  const is such that exp(const)*f == u.  This is calculated by
696    subtracting the arguments of one exponential from the other.  Therefore, it
697    is necessary to pass the arguments of the exponential terms in E_args.
698
699    To handle the case where we are given Df, not f, use
700    is_log_deriv_k_t_radical_in_field().
701    """
702    if Df:
703        dfa, dfd = (fd*derivation(fa, DE) - fa*derivation(fd, DE)).cancel(fd**2,
704                                                                          include=True)
705    else:
706        dfa, dfd = fa, fd
707
708    # Our assumption here is that each monomial is recursively transcendental
709    if len(DE.L_K) + len(DE.E_K) != len(DE.D) - 1:
710        if [i for i in DE.cases if i == 'tan'] or \
711                {i for i in DE.cases if i == 'primitive'} - set(DE.L_K):
712            raise NotImplementedError('Real version of the structure '
713                                      'theorems with hypertangent support is not yet implemented.')
714
715        # TODO: What should really be done in this case?
716        raise NotImplementedError('Nonelementary extensions not supported '
717                                  'in the structure theorems.')
718
719    E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.E_K]
720    L_part = [DE.D[i].as_expr() for i in DE.L_K]
721
722    lhs = Matrix([E_part + L_part])
723    rhs = Matrix([dfa.as_expr()/dfd.as_expr()])
724
725    A, u = constant_system(lhs, rhs, DE)
726    if all(derivation(i, DE, basic=True).is_zero for i in u) and A:
727        # If the elements of u are all constant
728        # Note: See comment in constant_system
729
730        # Also note: derivation(basic=True) calls cancel()
731        if not all(i.is_Rational for i in u):
732            # TODO: But maybe we can tell if they're not rational, like
733            # log(2)/log(3). Also, there should be an option to continue
734            # anyway, even if the result might potentially be wrong.
735            raise NotImplementedError('Cannot work with non-rational '
736                                      'coefficients in this case.')
737        else:
738            n = functools.reduce(math.lcm, [i.as_numer_denom()[1] for i in u])
739            u *= Integer(n)
740            terms = [DE.T[i] for i in DE.E_K] + DE.L_args
741            ans = list(zip(terms, u))
742            result = Mul(*[Pow(i, j) for i, j in ans])
743
744            # exp(f) will be the same as result up to a multiplicative
745            # constant.  We now find the log of that constant.
746            argterms = DE.E_args + [DE.T[i] for i in DE.L_K]
747            const = cancel(fa.as_expr()/fd.as_expr() -
748                           Add(*[Mul(i, j/n) for i, j in zip(argterms, u)]))
749
750            return ans, result, n, const
751
752
def is_log_deriv_k_t_radical_in_field(fa, fd, DE, case='auto', z=None):
    """
    Checks if f can be written as the logarithmic derivative of a k(t)-radical.

    f in k(t) can be written as the logarithmic derivative of a k(t) radical if
    there exist n in ZZ and u in k(t) with n, u != 0 such that n*f == Du/u.
    Either returns (n, u) or None, which means that f cannot be written as the
    logarithmic derivative of a k(t)-radical.

    fa and fd are the numerator and denominator of f; they are cancelled
    against each other before anything else is done.  z is the symbol handed
    to residue_reduce(); a fresh Dummy is used when it is not given.

    case is one of {'primitive', 'exp', 'tan', 'base', 'auto'} for the
    primitive, hyperexponential, hypertangent, and base cases, respectively.
    If case is 'auto', DE.case is used, i.e., the type of the derivation is
    determined automatically.

    Raises NotImplementedError for the 'tan' case and ValueError for the
    'other_linear' and 'other_nonlinear' cases or for an unrecognized case.
    Note that case is only inspected after the case-independent checks, so
    those checks may already have returned None for an unsupported case.
    """
    fa, fd = fa.cancel(fd, include=True)

    # NOTE: f must be simple.  This precondition is not verified here.

    z = z or Dummy('z')
    H, b = residue_reduce(fa, fd, DE, z=z)
    if not b:
        # I will have to verify, but I believe that the answer should be
        # None in this case. This should never happen for the
        # functions given when solving the parametric logarithmic
        # derivative problem when integrating elementary functions (see
        # Bronstein's book, page 255), so most likely this indicates a bug.
        return

    # Pair every resultant polynomial from H with its real roots.
    roots = [(i, i.real_roots()) for i, _ in H]
    if not all(len(j) == i.degree() and all(k.is_Rational for k in j) for
               i, j in roots):
        # If f is the logarithmic derivative of a k(t)-radical, then all the
        # roots of the resultant must be rational numbers.
        return

    # [(a, i), ...], where i*log(a) is a term in the log-part of the integral
    # of f.
    # Note: this might be empty, but everything below should work fine in that
    # case (it should be the same as if it were [[1, 1]])
    residueterms = [(h.subs({z: r}), r)
                    for (_, h), (_, rs) in zip(H, roots) for r in rs]

    # TODO: finish writing this and write tests

    p = cancel(fa.as_expr()/fd.as_expr() - residue_reduce_derivation(H, DE, z))

    p = p.as_poly(DE.t)
    if p is None:
        # f - Dg will be in k[t] if f is the logarithmic derivative of a k(t)-radical
        return

    if p.degree(DE.t) >= max(1, DE.d.degree(DE.t)):
        return

    if case == 'auto':
        case = DE.case

    if case == 'exp':
        # w = Dt/t, in cancelled (numerator, denominator) form.
        wa, wd = derivation(DE.t, DE).cancel(Poly(DE.t, DE.t), include=True)
        with DecrementLevel(DE):
            pa, pd = frac_in(p, DE.t, cancel=True)
            wa, wd = frac_in((wa, wd), DE.t)
            A = parametric_log_deriv(pa, pd, wa, wd, DE)
        if A is None:
            return
        n, e, u = A
        # Fold the power of the monomial t into u.
        u *= DE.t**e

    elif case == 'primitive':
        # Recurse on p with the extension lowered by one level.
        with DecrementLevel(DE):
            pa, pd = frac_in(p, DE.t)
            A = is_log_deriv_k_t_radical_in_field(pa, pd, DE, case='auto')
        if A is None:
            return
        n, u = A

    elif case == 'base':
        # TODO: we can use more efficient residue reduction from ratint()
        if not fd.is_squarefree or fa.degree() >= fd.degree():
            # f is the logarithmic derivative in the base case if and only if
            # f = fa/fd, fd is square-free, deg(fa) < deg(fd), and
            # gcd(fa, fd) == 1.  The last condition is handled by cancel() above.
            return
        # Note: if residueterms = [], returns (1, 1)
        # f had better be 0 in that case.
        n = functools.reduce(math.lcm, [i.as_numer_denom()[1] for _, i in residueterms], Integer(1))
        u = Mul(*[Pow(i, j*n) for i, j in residueterms])
        return Integer(n), u

    elif case == 'tan':
        raise NotImplementedError('The hypertangent case is '
                                  'not yet implemented for is_log_deriv_k_t_radical_in_field()')

    elif case in ['other_linear', 'other_nonlinear']:
        # XXX: If these are supported by the structure theorems, change to NotImplementedError.
        raise ValueError(f'The {case} case is not supported in this function.')

    else:
        raise ValueError("case must be one of {'primitive', 'exp', 'tan', "
                         f"'base', 'auto'}}, not {case}")

    # Only the 'exp' and 'primitive' cases reach this point, with (n, u) coming
    # from the lower level.  Combine that answer with the residue terms over
    # the least common multiple of n and all the residue denominators.
    common_denom = functools.reduce(math.lcm,
                                    [j.as_numer_denom()[1] for _, j in residueterms] + [n],
                                    Integer(1))
    residueterms = [(i, j*common_denom) for i, j in residueterms]
    m = common_denom//n
    if common_denom != n*m:  # Verify exact division
        raise ValueError('Inexact division')
    u = cancel(u**m*Mul(*[Pow(i, j) for i, j in residueterms]))

    return Integer(common_denom), u
866