"""Most of these tests come from the examples in Bronstein's book
"Symbolic Integration I: Transcendental Functions"."""
2from sympy import Poly, symbols, oo, I, Rational
3from sympy.integrals.risch import (DifferentialExtension,
4    NonElementaryIntegralException)
5from sympy.integrals.rde import (order_at, order_at_oo, weak_normalizer,
6    normal_denom, special_denom, bound_degree, spde, solve_poly_rde,
7    no_cancel_equal, cancel_primitive, cancel_exp, rischDE)
8
9from sympy.testing.pytest import raises
10from sympy.abc import x, t, z, n
11
# Extra symbols for the tests below: t0 and t1 are used as additional
# extension generators and k as a symbolic parameter.
t0, t1, t2, k = symbols('t0 t1 t2 k')
13
14
def test_order_at():
    """Check ``order_at`` / ``order_at_oo`` on products of powers of t and t**2 + 1."""
    a = Poly(t**4, t)
    b = Poly((t**2 + 1)**3*t, t)
    c = Poly((t**2 + 1)**6*t, t)
    d = Poly((t**2 + 1)**10*t**10, t)
    e = Poly((t**2 + 1)**100*t**37, t)
    p1 = Poly(t, t)
    p2 = Poly(1 + t**2, t)
    # Multiplicity of the irreducible factor t in each polynomial.
    assert order_at(a, p1, t) == 4
    assert order_at(b, p1, t) == 1
    assert order_at(c, p1, t) == 1
    assert order_at(d, p1, t) == 10
    assert order_at(e, p1, t) == 37
    # Multiplicity of the irreducible factor t**2 + 1 in each polynomial.
    assert order_at(a, p2, t) == 0
    assert order_at(b, p2, t) == 3
    assert order_at(c, p2, t) == 6
    # d contains (t**2 + 1)**10, so its order at p2 is 10.
    assert order_at(d, p2, t) == 10
    assert order_at(e, p2, t) == 100
    # The zero polynomial is expected to have infinite order.
    assert order_at(Poly(0, t), Poly(t, t), t) is oo
    # Both rational functions below reduce to t - 1, whose order at
    # infinity is -1.
    assert order_at_oo(Poly(t**2 - 1, t), Poly(t + 1, t), t) == \
        order_at_oo(Poly(t - 1, t), Poly(1, t), t) == -1
    assert order_at_oo(Poly(0, t), Poly(1, t), t) is oo
37
def test_weak_normalizer():
    """Check ``weak_normalizer`` and that its output is a fixed point of it."""
    a = Poly((1 + x)*t**5 + 4*t**4 + (-1 - 3*x)*t**3 - 4*t**2 + (-2 + 2*x)*t, t)
    d = Poly(t**4 - 3*t**2 + 2, t)
    # Extension with Dx = 1 and Dt = t.
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)]})
    r = weak_normalizer(a, d, DE, z)
    assert r == (Poly(t**5 - t**4 - 4*t**3 + 4*t**2 + 4*t - 4, t, domain='ZZ[x]'),
        (Poly((1 + x)*t**2 + x*t, t, domain='ZZ[x]'),
         Poly(t + 1, t, domain='ZZ[x]')))
    # Feeding the normalized fraction back in must yield the factor 1 and
    # leave the fraction unchanged.  Pass z like every other call here.
    assert weak_normalizer(r[1][0], r[1][1], DE, z) == (Poly(1, t), r[1])
    r = weak_normalizer(Poly(1 + t**2, t), Poly(t**2 - 1, t), DE, z)
    assert r == (Poly(t**4 - 2*t**2 + 1, t), (Poly(-3*t**2 + 1, t), Poly(t**2 - 1, t)))
    assert weak_normalizer(r[1][0], r[1][1], DE, z) == (Poly(1, t), r[1])
    # Extension with Dx = 1 and Dt = 1 + t**2.
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1 + t**2, t)]})
    r = weak_normalizer(Poly(1 + t**2, t), Poly(t, t), DE, z)
    assert r == (Poly(t, t), (Poly(0, t), Poly(1, t)))
    assert weak_normalizer(r[1][0], r[1][1], DE, z) == (Poly(1, t), r[1])
54
55
def test_normal_denom():
    """Check ``normal_denom`` in the base field and over Dt = t**2 + 1."""
    # Base field only (Dx = 1): the call is expected to raise.
    base_DE = DifferentialExtension(extension={'D': [Poly(1, x)]})
    raises(NonElementaryIntegralException,
           lambda: normal_denom(Poly(1, x), Poly(1, x), Poly(1, x), Poly(x, x),
                                base_DE))

    # Extension with Dt = t**2 + 1.
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t**2 + 1, t)]})
    numer_f, denom_f = Poly(t**2 + 1, t), Poly(1, t)
    numer_g, denom_g = Poly(1, t), Poly(t**2, t)
    result = normal_denom(numer_f, denom_f, numer_g, denom_g, DE)
    assert result == (
        Poly(t, t),
        (Poly(t**3 - t**2 + t - 1, t), Poly(1, t)),
        (Poly(1, t), Poly(1, t)),
        Poly(t, t),
    )
66
67
def test_special_denom():
    """Check ``special_denom`` on the exp and tan cases and on a bad ``case``."""
    # TODO: add more tests here
    # The same five polynomial arguments are reused with different extensions.
    polys = (Poly(1, t), Poly(t**2, t), Poly(1, t), Poly(t**2 - 1, t), Poly(t, t))

    # Extension with Dx = 1 and Dt = t.
    exp_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)]})
    assert special_denom(*polys, exp_DE) == \
        (Poly(1, t), Poly(t**2 - 1, t), Poly(t**2 - 1, t), Poly(t, t))
    # NOTE(review): disabled check kept for reference; it passes fewer
    # arguments than the special_denom calls in this test and would need
    # updating before it could be enabled.
    # assert special_denom(Poly(1, t), Poly(2*x, t), Poly((1 + 2*x)*t, t), DE) == 1

    # issue 3940
    # This isn't a very good test, because the denominator is just 1,
    # but at least it tests the exp cancellation case.
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(-2*x*t0, t0),
        Poly(I*k*t1, t1)]})
    DE.decrement_level()
    result = special_denom(Poly(1, t0), Poly(I*k, t0), Poly(1, t0),
                           Poly(t0, t0), Poly(1, t0), DE)
    assert result == (Poly(1, t0, domain='ZZ'),
                      Poly(I*k, t0, domain='ZZ_I[k,x]'),
                      Poly(t0, t0, domain='ZZ'),
                      Poly(1, t0, domain='ZZ'))

    # Force the 'tan' case on the same (decremented) extension.
    assert special_denom(*polys, DE, case='tan') == \
        (Poly(1, t, t0, domain='ZZ'), Poly(t**2, t0, t, domain='ZZ[x]'),
         Poly(t, t, t0, domain='ZZ'), Poly(1, t0, domain='ZZ'))

    # An unrecognized case name raises ValueError.
    raises(ValueError,
           lambda: special_denom(*polys, DE, case='unrecognized_case'))
95
96
def test_bound_degree_fail():
    """Degree bound for the primitive-case RDE a*Dq + b*q = c.

    This is the same equation that the primitive example in ``test_spde``
    solves, with solution q = (t0/2 + x**2)*t**2 - x**2*t.
    """
    # Primitive case: Dx = 1, Dt0 = t0/x**2, Dt = 1/x.
    DE = DifferentialExtension(extension={'D': [Poly(1, x),
        Poly(t0/x**2, t0), Poly(1/x, t)]})
    a = Poly(t**2, t)
    b = Poly(-(1/x**2*t**2 + 1/x), t)
    # The t**2 coefficient is -(t0 + 4*x**2)/(2*x). The previous spelling
    # "/2*x" parsed as multiplication by x/2 and did not match the RHS used
    # in test_spde.
    c = Poly((2*x - 1)*t**4 + (t0 + x)/x*t**3 - (t0 + 4*x**2)/(2*x)*t**2 + x*t,
        t)
    assert bound_degree(a, b, c, DE) == 3
104
105
def test_bound_degree():
    """Check ``bound_degree`` in the base and nonlinear cases."""
    # Base case: Dx = 1 only.
    base_DE = DifferentialExtension(extension={'D': [Poly(1, x)]})
    assert bound_degree(Poly(1, x), Poly(-2*x, x), Poly(1, x), base_DE) == 0

    # Primitive: see test_bound_degree_fail.
    # TODO: Add test for when the degree bound becomes larger after limited_integrate
    # TODO: Add test for db == da - 1 case

    # Exp
    # TODO: Add tests
    # TODO: Add test for when the degree becomes larger after parametric_log_deriv()

    # Nonlinear case: Dt = t**2 + 1.
    tan_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t**2 + 1, t)]})
    lhs_a = Poly(t, t)
    lhs_b = Poly((t - 1)*(t**2 + 1), t)
    rhs_c = Poly(1, t)
    assert bound_degree(lhs_a, lhs_b, rhs_c, tan_DE) == 0
122
123
def test_spde():
    """Check ``spde`` over several differential extensions."""
    # Dt = t**2 + 1: the reduction is expected to raise.
    tan_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t**2 + 1, t)]})
    raises(NonElementaryIntegralException,
           lambda: spde(Poly(t, t), Poly((t - 1)*(t**2 + 1), t), Poly(1, t),
                        0, tan_DE))

    # Dt = t.
    exp_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)]})
    result = spde(Poly(t**2 + x*t*2 + x**2, t),
                  Poly(t**2/x**2 + (2/x - 1)*t, t),
                  Poly(t**2/x**2 + (2/x - 1)*t, t), 0, exp_DE)
    assert result == (Poly(0, t), Poly(0, t), 0, Poly(0, t),
                      Poly(1, t, domain='ZZ(x)'))

    # Primitive tower: Dt0 = t0/x**2, Dt = 1/x.
    prim_DE = DifferentialExtension(
        extension={'D': [Poly(1, x), Poly(t0/x**2, t0), Poly(1/x, t)]})
    rhs = Poly((2*x - 1)*t**4 + (t0 + x)/x*t**3
               - (t0 + 4*x**2)/(2*x)*t**2 + x*t, t)
    result = spde(Poly(t**2, t), Poly(-t**2/x**2 - 1/x, t), rhs, 3, prim_DE)
    assert result == (Poly(0, t), Poly(0, t), 0, Poly(0, t),
                      Poly(t0*t**2/2 + x**2*t**2 - x**2*t, t, domain='ZZ(x,t0)'))

    # Base field, once with a numeric and once with a symbolic degree bound.
    base_DE = DifferentialExtension(extension={'D': [Poly(1, x)]})
    a = Poly(x**2 + x + 1, x)
    b = Poly(-2*x - 1, x)
    c = Poly(x**5/2 + 3*x**4/4 + x**3 - x**2 + 1, x)
    assert spde(a, b, c, 4, base_DE) == \
        (Poly(0, x, domain='QQ'), Poly(x/2 - Rational(1, 4), x), 2,
         Poly(x**2 + x + 1, x), Poly(x*Rational(5, 4), x))
    assert spde(a, b, c, n, base_DE) == \
        (Poly(0, x, domain='QQ'), Poly(x/2 - Rational(1, 4), x), -2 + n,
         Poly(x**2 + x + 1, x), Poly(x*Rational(5, 4), x))

    # Dt = 1: the reduction is expected to raise.
    unit_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1, t)]})
    raises(NonElementaryIntegralException,
           lambda: spde(Poly((t - 1)*(t**2 + 1)**2, t),
                        Poly((t - 1)*(t**2 + 1), t), Poly(1, t), 0, unit_DE))

    # Fresh base-field extension for the last two reductions.
    base_DE = DifferentialExtension(extension={'D': [Poly(1, x)]})
    a = Poly(x**2 - x, x)
    assert spde(a, Poly(1, x), Poly(9*x**4 - 10*x**3 + 2*x**2, x), 4, base_DE) == \
        (Poly(0, x, domain='ZZ'), Poly(0, x), 0, Poly(0, x),
         Poly(3*x**3 - 2*x**2, x, domain='QQ'))
    assert spde(a, Poly(x**2 - 5*x + 3, x),
                Poly(x**7 - x**6 - 2*x**4 + 3*x**3 - x**2, x), 5, base_DE) == \
        (Poly(1, x, domain='QQ'), Poly(x + 1, x, domain='QQ'), 1,
         Poly(x**4 - x**3, x), Poly(x**3 - x**2, x, domain='QQ'))
150
def test_solve_poly_rde_no_cancel():
    """Check the no-cancellation branches of ``solve_poly_rde``."""
    # deg(b) large
    tan_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1 + t**2, t)]})
    coeff_b = Poly(t**2 + 1, t)
    rhs_c = Poly(t**3 + (x + 1)*t**2 + t + x + 2, t)
    assert solve_poly_rde(coeff_b, rhs_c, oo, tan_DE) == Poly(t + x, t)

    # deg(b) small
    base_DE = DifferentialExtension(extension={'D': [Poly(1, x)]})
    result = solve_poly_rde(Poly(0, x), Poly(x/2 - Rational(1, 4), x), oo, base_DE)
    assert result == Poly(x**2/4 - x/4, x)
    tan_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t**2 + 1, t)]})
    result = solve_poly_rde(Poly(2, t), Poly(t**2 + 2*t + 3, t), 1, tan_DE)
    assert result == Poly(t + 1, t, x)

    # deg(b) == deg(D) - 1
    tan_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t**2 + 1, t)]})
    result = no_cancel_equal(Poly(1 - t, t),
                             Poly(t**3 + t**2 - 2*x*t - 2*x, t), oo, tan_DE)
    assert result == (Poly(t**2, t), 1, Poly((-2 - 2*x)*t - 2*x, t))
168
169
def test_solve_poly_rde_cancel():
    """Check the cancellation solvers ``cancel_exp`` and ``cancel_primitive``."""
    # exp: Dt = t.
    exp_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)]})
    result = cancel_exp(Poly(2*x, t), Poly(2*x, t), 0, exp_DE)
    assert result == Poly(1, t)
    result = cancel_exp(Poly(2*x, t), Poly((1 + 2*x)*t, t), 1, exp_DE)
    assert result == Poly(t, t)
    # TODO: Add more exp tests, including tests that require is_deriv_in_field()

    # primitive: Dt = 1/x.
    prim_DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(1/x, t)]})

    # If the DecrementLevel context manager is working correctly, this
    # shouldn't cause any problems with the later checks that reuse prim_DE.
    raises(NonElementaryIntegralException,
           lambda: cancel_primitive(Poly(1, t), Poly(t, t), oo, prim_DE))

    result = cancel_primitive(Poly(1, t), Poly(t + 1/x, t), 2, prim_DE)
    assert result == Poly(t, t)
    result = cancel_primitive(Poly(4*x, t), Poly(4*x*t**2 + 2*t/x, t), 3, prim_DE)
    assert result == Poly(t**2, t)

    # TODO: Add more primitive tests, including tests that require is_deriv_in_field()
192
193
def test_rischDE():
    """Check ``rischDE`` on an equation posed below the t level."""
    # TODO: Add more tests for rischDE, including ones from the text
    DE = DifferentialExtension(extension={'D': [Poly(1, x), Poly(t, t)]})
    # Step DE down one level; all polynomials below are in x only.
    DE.decrement_level()
    f_num, f_den = Poly(-2*x, x), Poly(1, x)
    g_num, g_den = Poly(1 - 2*x - 2*x**2, x), Poly(1, x)
    # y = x + 1 satisfies Dy - 2*x*y = 1 - 2*x - 2*x**2.
    assert rischDE(f_num, f_den, g_num, g_den, DE) == (Poly(x + 1, x), Poly(1, x))
201