1import pytest
2
3from diofant import (Eq, Function, I, Lambda, binomial, cos, factorial, gamma,
4                     pi, rf, rsolve, sin, sqrt, symbols)
5from diofant.abc import a, b
6from diofant.solvers.recurr import rsolve_hyper, rsolve_poly, rsolve_ratio
7
8
__all__ = ()

# Undefined function heads that play the role of the unknown sequence.
f, g = symbols('f g', cls=Function)
# Integer-valued variables used as recurrence indices / parameters.
n, k = symbols('n k', integer=True)
# Arbitrary constants appearing in the general solutions returned by rsolve.
C0, C1, C2 = symbols('C0 C1 C2')
14
15
def test_poly():
    # rsolve_poly returns a (solution, constants) pair, or None on failure.
    assert rsolve_poly([-1, -1, 1], 0, n) == (0, [])
    assert rsolve_poly([-1, -1, 1], 1, n) == (-1, [])
    assert rsolve_poly([-n**2, n, -1, 1], 1, n) is None

    # First-order equations solved end to end through rsolve.
    eq = -f(n) + (n + 1)*f(n + 1) - n
    assert rsolve(eq) == [{f: Lambda(n, C0/factorial(n) + 1)}]

    eq = -(n + 1)*f(n) + n*f(n + 1) - 1
    assert rsolve(eq) == [{f: Lambda(n, C0*n - 1)}]

    # NOTE(review): ``1/2`` is expected to be a Rational here; presumably the
    # test harness rewrites integer division -- confirm before touching.
    eq = -(4*n + 2)*f(n) + f(n + 1) - 4*n - 1
    expected = 4**n*C0*gamma(n + 1/2)/sqrt(pi) - 1
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    eq = -f(n) + f(n + 1) - n**5 - n**3
    expected = C0 + n**2*(n**4 - 3*n**3 + 4*n**2 - 3*n + 1)/6
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    # A sqrt(n) right-hand side yields no polynomial solution.
    assert rsolve_poly([1, 1], sqrt(n), n) is None

    rhs = -2*n**4 - (n + 1)**4 + (n + 2)**4
    assert rsolve_poly([-2, -1, 1], rhs, n) == (n**4, [])

    eq = n*f(n) - f(n + 1) - n**3 + (n + 1)**2
    expected = (C0*factorial(n) + n**3)/n
    assert rsolve(eq) == [{f: Lambda(n, expected)}]
37
38
def test_ratio():
    # Homogeneous recurrence with four polynomial coefficients whose
    # solution is rational in n.
    coeffs = [-2*n**3 + n**2 + 2*n - 1,
              2*n**3 + n**2 - 6*n,
              -2*n**3 - 11*n**2 - 18*n - 9,
              2*n**3 + 13*n**2 + 22*n + 8]
    expected = (C2*(2*n - 3)/(n**2 - 1)/2, [C2])
    assert rsolve_ratio(coeffs, 0, n) == expected

    # Inputs for which rsolve_ratio returns None.
    assert rsolve_ratio([1, 1], sqrt(n), n) is None
    assert rsolve_ratio([-n**3, n + 1], n, n) is None
45
46
def test_hyper():
    # Second-order equations whose solutions involve rising factorials.
    eq = (n**2 - 2)*f(n) - (2*n + 1)*f(n + 1) + f(n + 2)
    expected = C0*rf(-sqrt(2), n) + C1*rf(+sqrt(2), n)
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    eq = (n**2 - k)*f(n) - (2*n + 1)*f(n + 1) + f(n + 2)
    expected = C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n)
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    eq = 2*n*(n + 1)*f(n) - (n**2 + 3*n - 2)*f(n + 1) + (n - 1)*f(n + 2)
    assert rsolve(eq) == [{f: Lambda(n, C1*factorial(n) + C0*2**n)}]

    # rsolve_hyper finds only the zero solution for these.
    for coeffs in ([n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1],
                   [-n - 1, -1, 1]):
        assert rsolve_hyper(coeffs, 0, n) == (0, [])

    # First-order equations with polynomial inhomogeneous terms.
    cases = [(-n - f(n) + f(n + 1), C0 + n*(n - 1)/2),
             (-1 - n - f(n) + f(n + 1), C0 + n*(n + 1)/2),
             (-3*(n + n**2) - f(n) + f(n + 1), C0 + n**3 - n)]
    for eq, sol in cases:
        assert rsolve(eq) == [{f: Lambda(n, sol)}]

    # A factorial inhomogeneous term makes rsolve return None.
    assert rsolve(-n - factorial(n) - f(n) + f(n + 1)) is None

    # Constant coefficients involving the symbol a.
    assert rsolve(-a*f(n) + f(n + 1)) == [{f: Lambda(n, C0*a**n)}]
    eq = -a*f(n) + f(n + 2)
    assert rsolve(eq) == [{f: Lambda(n, a**(n/2)*((-1)**n*C1 + C0))}]

    eq = f(n) + f(n + 1) + f(n + 2)
    sol = 2**-n*(C0*(-1 - sqrt(3)*I)**n + C1*(-1 + sqrt(3)*I)**n)
    assert rsolve(eq) == [{f: Lambda(n, sol)}]

    assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == (0, [])

    # rsolve_hyper returns None for these right-hand sides.
    for rhs in (sqrt(n), n + sqrt(n)):
        assert rsolve_hyper([1, 1], rhs, n) is None
85
86
@pytest.mark.slow
def test_bulk():
    """Round-trip the polynomial and hypergeometric solvers.

    For every candidate ``p`` and coefficient list ``c`` the right-hand side
    ``q = sum(c[i]*p(n + i))`` is built. By construction, ``p`` is then a
    particular solution of the recurrence with coefficients ``c``, and the
    solvers must recover it.
    """
    funcs = [n, n + 1, n**2, n**3, n**4, n + n**2,
             27*n + 52*n**2 - 3*n**3 + 12*n**4 - 52*n**5]
    coeffs = [[-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1],
              [-n, 1], [n**2 - n + 12, 1]]
    for p in funcs:
        for c in coeffs:
            q = sum(ci*p.subs({n: n + i}) for i, ci in enumerate(c))
            if p.is_polynomial(n):
                assert rsolve_poly(c, q, n)[0] == p
            # rsolve_hyper is only exercised for up to three coefficients
            # (order <= 2). Zeroing the arbitrary constants leaves just the
            # particular part, which must equal p.
            if p.is_hypergeometric(n) and len(c) <= 3:
                sol = rsolve_hyper(c, q, n)[0]
                sol = sol.subs({C0: 0, C1: 0, C2: 0})
                assert sol.expand() == p
101
102
def test_rsolve():
    # Fibonacci-like recurrence. First check the general solution, then the
    # solution for f(0) = 0, f(1) = 5 with the equation given in forward-shift,
    # backward-shift and Eq forms.
    eq = f(n + 2) - f(n + 1) - f(n)
    res = [{f: Lambda(n, 2**(-n)*(C0*(1 + sqrt(5))**n +
                                  C1*(-sqrt(5) + 1)**n))}]

    assert rsolve(eq) == res

    # Specialize the constants. The comprehension variables are deliberately
    # not named ``k``, so the module-level Symbol ``k`` is not shadowed.
    res = [{func: sol.subs({C0: sqrt(5), C1: -sqrt(5)}).simplify()
            for func, sol in mapping.items()} for mapping in res]

    assert rsolve(eq, init={f(0): 0, f(1): 5}) == res
    assert rsolve(f(n) - f(n - 1) - f(n - 2), init={f(0): 0, f(1): 5}) == res
    assert rsolve(Eq(f(n), f(n - 1) + f(n - 2)), init={f(0): 0, f(1): 5}) == res

    # Second-order equation with polynomial coefficients.
    eq = (n - 1)*f(n + 2) - (n**2 + 3*n - 2)*f(n + 1) + 2*n*(n + 1)*f(n)
    res = [{f: Lambda(n, C1*factorial(n) + C0*2**n)}]

    assert rsolve(eq) == res

    res = [{f: Lambda(n, -3*factorial(n) + 3*2**n)}]

    assert rsolve(eq, init={f(0): 0, f(1): 3}) == res

    # First-order inhomogeneous equation. f(1) = 1 contradicts
    # f(1) = f(0) + 2, so that final call returns None.
    eq = f(n) - f(n - 1) - 2

    assert rsolve(eq, f(n)) == [{f: Lambda(n, C0 + 2*n)}]
    assert rsolve(eq) == [{f: Lambda(n, C0 + 2*n)}]
    assert rsolve(eq, init={f(0): 0}) == [{f: Lambda(n, 2*n)}]
    assert rsolve(eq, init={f(0): 1}) == [{f: Lambda(n, 2*n + 1)}]
    assert rsolve(eq, init={f(0): 0, f(1): 1}) is None

    eq = 3*f(n - 1) - f(n) - 1

    assert rsolve(eq) == [{f: Lambda(n, 3**n*C0 + 1/2)}]
    assert rsolve(eq, init={f(0): 0}) == [{f: Lambda(n, -3**n/2 + 1/2)}]
    assert rsolve(eq, init={f(0): 1}) == [{f: Lambda(n, 3**n/2 + 1/2)}]
    assert rsolve(eq, init={f(0): 2}) == [{f: Lambda(n, 3*3**n/2 + 1/2)}]

    # Homogeneous version solved; the inhomogeneous one returns None.
    assert rsolve(f(n) - 1/n*f(n - 1),
                  f(n)) == [{f: Lambda(n, C0/factorial(n))}]
    assert rsolve(f(n) - 1/n*f(n - 1) - 1, f(n)) is None

    # A one-element list of equations is accepted as well. The same initial
    # value is checked both with and without simplification.
    eq = 2*f(n - 1) + (1 - n)*f(n)/n

    assert rsolve(eq) == [{f: Lambda(n, 2**n*C0*n)}]
    assert rsolve([eq]) == [{f: Lambda(n, 2**n*C0*n)}]
    assert rsolve(eq, init={f(1): 1}) == [{f: Lambda(n, 2**(n - 1)*n)}]
    assert rsolve(eq, init={f(1): 2},
                  simplify=False) == [{f: Lambda(n, 2**(n - 1)*n*2)}]
    assert rsolve(eq, init={f(1): 2}) == [{f: Lambda(n, 2**n*n)}]
    assert rsolve(eq, init={f(1): 3}) == [{f: Lambda(n, 3*2**n*n/2)}]

    # Step-2 equation: the general solution carries a (-1)**n term.
    eq = (n - 1)*(n - 2)*f(n + 2) - (n + 1)*(n + 2)*f(n)

    assert rsolve(eq) == [{f: Lambda(n,
                                     n*(n - 2)*(n - 1)*((-1)**n*C1 + C0))}]
    assert rsolve(eq, init={f(3): 6,
                            f(4): 24}) == [{f: Lambda(n, n*(n - 1)*(n - 2))}]
    assert (rsolve(eq, init={f(3): 6, f(4): -24}) ==
            [{f: Lambda(n, (-1)**(n + 1)*n*(n - 2)*(n - 1))}])

    # Equations and initial values involving the symbols a and b.
    assert rsolve(Eq(f(n + 1), a*f(n)),
                  init={f(1): a}) == [{f: Lambda(n, a**n)}]

    assert (rsolve(f(n) - a*f(n - 2),
                   init={f(1): sqrt(a)*(a + b), f(2): a*(a - b)}) ==
            [{f: Lambda(n, a**(n/2)*(-(-1)**n*b + a))}])

    eq = (-16*n**2 + 32*n - 12)*f(n - 1) + (4*n**2 - 12*n + 9)*f(n)

    assert (rsolve(eq, init={f(1): binomial(2*n + 1, 3)}) ==
            [{f: Lambda(n, 4**n*n*(2*n - 1)*gamma(n + 3/2)/(3*gamma(n - 1/2)))}])

    assert (rsolve(f(n) + a*(f(n + 1) + f(n - 1))/2) ==
            [{f: Lambda(n, a**-n*(C0*(-a*sqrt(-1 + a**-2) - 1)**n +
                                  C1*(a*sqrt(-1 + a**-2) - 1)**n))}])
179
180
def test_rsolve_raises():
    # f appears with unrelated arguments n and k + 1.
    with pytest.raises(ValueError):
        rsolve(f(n) - f(k + 1), f(n))
    # Coefficient sqrt(n) is not rational in n.
    with pytest.raises(ValueError):
        rsolve(f(n) - sqrt(n)*f(n + 1))
    # Initial value refers to g although the unknown is f.
    with pytest.raises(ValueError):
        rsolve(f(n) - f(n + 1), f(n), init={g(0): 0})
    # Inhomogeneous term sqrt(n) is not rational in n.
    with pytest.raises(NotImplementedError):
        rsolve(f(n) - sqrt(n))
187
188
def test_sympyissue_6844():
    # Characteristic polynomial r**2 - r + 1/4 has the double root 1/2.
    eq = f(n + 2) - f(n + 1) + f(n)/4
    general = [{f: Lambda(n, 2**(-n)*(C0 + C1*n))}]
    particular = [{f: Lambda(n, 2**(1 - n)*n)}]

    assert rsolve(eq) == general
    assert rsolve(eq, init={f(0): 0, f(1): 1}) == particular
195
196
def test_diofantissue_294():
    eq = f(n) - f(n - 1) - 2*f(n - 2) - 2*n

    expected = (-1)**n*C0 + 2**n*C1 - n - 5/2
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    # issue sympy/sympy#11261
    expected = -(-1)**n/2 + 2**(n + 1) - n - 5/2
    assert rsolve(eq, init={f(0): -1, f(1): 1}) == [{f: Lambda(n, expected)}]

    # issue sympy/sympy#7055
    eq = -2*f(n) + f(n + 1) + n - 1
    assert rsolve(eq) == [{f: Lambda(n, 2**n*C0 + n)}]
208
209
def test_sympyissue_8697():
    # Third-order homogeneous equations with repeated characteristic roots.
    eq = f(n + 3) - f(n + 2) - f(n + 1) + f(n)
    assert rsolve(eq) == [{f: Lambda(n, (-1)**n*C1 + C0 + C2*n)}]

    eq = f(n + 3) + 3*f(n + 2) + 3*f(n + 1) + f(n)
    expected = (-1)**n*(C0 + C1*n + C2*n**2)
    assert rsolve(eq) == [{f: Lambda(n, expected)}]

    eq = f(n) - 2*f(n - 3) + 5*f(n - 2) - 4*f(n - 1)
    sol = rsolve(eq, init={f(0): 1, f(1): 3, f(2): 8})
    assert sol == [{f: Lambda(n, 3*2**n - n - 2)}]

    # From issue thread (but not related to the problem, fixed before):
    eq = f(n) - 2*f(n - 1) - n
    assert rsolve(eq, init={f(0): 1}) == [{f: Lambda(n, 3*2**n - n - 2)}]

    eq = f(n + 2) - 5*f(n + 1) + 6*f(n) - n
    expected = 2**n*C0 + 3**n*C1 + n/2 + 3/4
    assert rsolve(eq) == [{f: Lambda(n, expected)}]
225
226
def test_diofantissue_451():
    eq = f(n) - 2*f(n - 1) - 3**n
    expected = [{f: Lambda(n, -2**(n + 1) + 3**(n + 1))}]

    assert rsolve(eq, init={f(0): 1}) == expected
231
232
def test_diofantissue_456():
    eq = f(n) - 2*f(n - 1) - 3**n*n
    expected = [{f: Lambda(n, 7*2**n + 3**(n + 1)*(n - 2))}]

    assert rsolve(eq, init={f(0): 1}) == expected
237
238
def test_diofantissue_13629():
    # Partial sums of squares: f(n + 1) = f(n) + (n + 1)**2, f(0) = 0.
    eq = f(n + 1) - (f(n) + (n + 1)**2)
    expected = [{f: Lambda(n, n*(2*n**2 + 3*n + 1)/6)}]

    assert rsolve(eq, init={f(0): 0}) == expected
242
243
def test_sympyissue_15553():
    eq = Eq(f(n + 1), 2*f(n) + n**2 + 1)

    general = 2**n*C0 - n**2 - 2*n - 4
    assert rsolve(eq) == [{f: Lambda(n, general)}]

    particular = 7*2**n/2 - n**2 - 2*n - 4
    assert rsolve(eq, init={f(1): 0}) == [{f: Lambda(n, particular)}]
250
251
def test_diofantissue_922():
    eq = -2*n/3 + f(n) - f(n - 1) + 2*(n - 1)**3/3 + 2*(n - 1)**2/3
    expected = n*(-3*n**3 + 2*n**2 + 9*n + 4)/18

    assert rsolve(eq, init={f(0): 0}) == [{f: Lambda(n, expected)}]
256
257
def test_diofantissue_923():
    # Characteristic polynomial (r + 2)**2: double root -2.
    eq = 4*f(n) + 4*f(n + 1) + f(n + 2)
    expected = [{f: Lambda(n, (-2)**n*(C0 + C1*n))}]

    assert rsolve(eq) == expected
261
262
def test_sympyissue_17982():
    # Characteristic polynomial (r + 2)*(r + 4)**2.
    eq = f(n + 3) + 10*f(n + 2) + 32*f(n + 1) + 32*f(n)
    expected = (-2)**n*C0 + (-4)**n*C1 + (-4)**n*C2*n

    assert rsolve(eq) == [{f: Lambda(n, expected)}]
266
267
def test_sympyissue_18751():
    radius = symbols('r', real=True, positive=True)
    theta = symbols('theta', real=True)

    # Characteristic roots are radius*(cos(theta) -/+ I*|sin(theta)|).
    eq = f(n) - 2*radius*cos(theta)*f(n - 1) + radius**2*f(n - 2)
    w_minus = cos(theta) - I*abs(sin(theta))
    w_plus = cos(theta) + I*abs(sin(theta))
    expected = radius**n*(C0*w_minus**n + C1*w_plus**n)

    assert rsolve(eq) == [{f: Lambda(n, expected)}]
277
278
def test_sympyissue_19630():
    # Characteristic polynomial (r - 1)**2*(r + 2).
    eq = f(n + 3) - 3*f(n + 1) + 2*f(n)

    assert rsolve(eq) == [{f: Lambda(n, (-2)**n*C1 + C0 + C2*n)}]

    # Initial values are given at n = 1, 2, 3 rather than starting at 0.
    sol = rsolve(eq, init={f(1): 0, f(2): 8, f(3): -2})
    assert sol == [{f: Lambda(n, (-2)**n + 2*n)}]
288