r"""
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
a function which gives all rational particular solutions to first order
Riccati ODEs. A general first order Riccati ODE is given by -

.. math:: w' = b_0(x) + b_1(x)w + b_2(x)w^2

where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
anymore and becomes a linear ODE. Similarly, when `b_0 = 0`, the equation
is a Bernoulli ODE. The algorithm presented below can find rational
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
or prove that no rational solution exists for the equation.

Background
==========

A Riccati equation can be transformed to its normal form

.. math:: y' + y^2 = a(x)

using the transformation

.. math:: y = -b_2(x) w - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}

where `a(x)` is given by

.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2

Thus, it is enough to develop an algorithm that solves the Riccati equation
in its normal form: applying the inverse transformation to its solutions
gives the solutions of the original Riccati equation.
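
As a quick sanity check of the transformation and of the formula for `a(x)`, the sketch below uses a hypothetical example: `b_2 = x`, `b_1 = 0` and `b_0 = -1/x^2 - 1/x`, for which `w = 1/x` is a solution. To keep the sketch free of SymPy, all derivatives are approximated with central finite differences. It then checks numerically that the transformed `y` satisfies `y' + y^2 = a(x)`. In this example, `y = -1 - 1/(2x)` and `a(x) = 1 + 1/x + 3/(4x^2)`.

```python
# Illustrative sketch only: the helper names are hypothetical, not module API.
# Derivatives are approximated numerically with central differences.

def deriv(f, x, h=1e-4):
    # Central difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2 * h)

def normal_form_a(b0, b1, b2):
    # a(x) = 1/4 (b2'/b2 + b1)^2 - 1/2 (b2'/b2 + b1)' - b0*b2
    def g(x):
        return deriv(b2, x) / b2(x) + b1(x)
    return lambda x: 0.25 * g(x) ** 2 - 0.5 * deriv(g, x) - b0(x) * b2(x)

def to_normal(w, b1, b2):
    # y = -b2*w - b2'/(2*b2) - b1/2
    return lambda x: -b2(x) * w(x) - deriv(b2, x) / (2 * b2(x)) - b1(x) / 2

# Hypothetical example: w = 1/x solves w' = b0 + b1*w + b2*w**2
b2 = lambda x: x
b1 = lambda x: 0.0
b0 = lambda x: -1 / x ** 2 - 1 / x
w = lambda x: 1 / x

a = normal_form_a(b0, b1, b2)
y = to_normal(w, b1, b2)

x0 = 2.0
residual = deriv(y, x0) + y(x0) ** 2 - a(x0)
print(round(a(x0), 6), round(y(x0), 6), abs(residual) < 1e-6)
```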

Algorithm
=========

The algorithm implemented here is presented in the Ph.D. thesis
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
by N. Thieu Vo. The entire thesis can be found here -
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf

We have only implemented the Rational Riccati solver (Algorithm 11,
Pg 78-82 in Thesis). Before we proceed towards the implementation
of the algorithm, a few definitions are needed -

1. Valuation of a Rational Function at `\infty`:
    The valuation of a rational function `p(x)` at `\infty` is equal
    to the degree of the denominator of `p(x)` minus the degree of
    its numerator.

    NOTE: A general definition of valuation of a rational function
    at any value of `x` can be found on Pg 63 of the thesis, but
    it is not needed for this algorithm.

2. Zeros and Poles of a Rational Function:
    Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
    of `x`, where `S` and `T` are coprime. Then -

    a. The Zeros of `a(x)` are the roots of `S(x)`.
    b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
       can also be a pole of `a(x)`. We say that `a(x)` has a pole at
       `\infty` if `a(\frac{1}{x})` has a pole at 0.

Every pole is associated with an order that is equal to the multiplicity
of its appearance as a root of `T(x)`. A pole is called a simple pole if
it has order 1. Similarly, a pole is called a multiple pole if it has
an order `\ge` 2.
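
The sketch below illustrates these two definitions using only the standard library. For illustration only, it assumes that polynomials are given as lists of integer coefficients, highest degree first (the same order as ``Poly.all_coeffs``). It also only looks for rational poles, using the rational root theorem. The module itself calls ``roots`` on a ``Poly``, which also finds irrational and complex poles. All helper names below are hypothetical.

```python
from fractions import Fraction

# Illustrative sketch: hypothetical helpers, not part of this module's API.

def strip(p):
    # Drop leading zero coefficients (highest degree first)
    i = 0
    while i < len(p) - 1 and p[i] == 0:
        i += 1
    return p[i:]

def degree(p):
    return len(strip(p)) - 1

def valuation_at_inf(num, den):
    # Valuation at infinity = deg(denominator) - deg(numerator)
    return degree(den) - degree(num)

def synthetic_division(p, r):
    # Divide p by (x - r); return (quotient, remainder)
    acc, carry = [], Fraction(0)
    for c in p:
        carry = carry * r + c
        acc.append(carry)
    return acc[:-1], acc[-1]

def divisors(n):
    n = abs(int(n))
    return [k for k in range(1, n + 1) if n % k == 0]

def rational_pole_orders(den):
    # Orders of the rational roots of an integer-coefficient denominator,
    # assumed to be already cancelled against the numerator
    p = [Fraction(c) for c in strip(den)]
    orders = {}
    while len(p) > 1 and p[-1] == 0:   # zero constant term: x = 0 is a root
        p.pop()
        orders[Fraction(0)] = orders.get(Fraction(0), 0) + 1
    if len(p) > 1:
        candidates = {s * Fraction(u, v) for u in divisors(p[-1])
                      for v in divisors(p[0]) for s in (1, -1)}
        for r in sorted(candidates):
            while len(p) > 1:
                q, rem = synthetic_division(p, r)
                if rem != 0:
                    break
                p = q
                orders[r] = orders.get(r, 0) + 1
    return orders

# a(x) = 1/(x**3 - x**2): double pole at 0, simple pole at 1, valuation 3
print(rational_pole_orders([1, -1, 0, 0]), valuation_at_inf([1], [1, -1, 0, 0]))
# a(x) = x**2 + 1 has a pole of order 2 at infinity (valuation -2)
print(valuation_at_inf([1, 0, 1], [1]))
```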

Necessary Conditions
====================

For a Riccati equation in its normal form,

.. math:: y' + y^2 = a(x)

we can define

a. A pole is called a movable pole if it is a pole of `y(x)` and is not
   a pole of `a(x)`.
b. Similarly, a pole is called a non-movable pole if it is a pole of both
   `y(x)` and `a(x)`.

Then, the algorithm states that a rational solution exists only if -

a. Every pole of `a(x)` must be either a simple pole or a multiple pole
   of even order.
b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2.

This algorithm finds all possible rational solutions for the Riccati ODE.
If no rational solutions are found, it means that no rational solutions
exist.
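
The two conditions translate directly into a predicate on the valuation at `\infty` and the list of pole multiplicities. The sketch below is a hypothetical standalone restatement of the logic of ``check_necessary_conds`` further down in this module, applied to a few example inputs. For instance, `a(x) = 2/x^2` has a double pole at 0 and valuation 2 at `\infty`, so it passes. `a(x) = 1/x` fails because its valuation at `\infty` is 1.

```python
# Illustrative sketch: 'necessary_conds_hold' is a hypothetical name.

def necessary_conds_hold(val_inf, muls):
    # (i) every pole is simple or has even multiplicity
    poles_ok = all(mul == 1 or mul % 2 == 0 for mul in muls)
    # (ii) the valuation at infinity is even, or at least 2
    inf_ok = val_inf >= 2 or val_inf % 2 == 0
    return poles_ok and inf_ok

print(necessary_conds_hold(2, [2]))    # a(x) = 2/x**2      -> True
print(necessary_conds_hold(1, [1]))    # a(x) = 1/x         -> False
print(necessary_conds_hold(0, [3]))    # a triple pole      -> False
print(necessary_conds_hold(-2, []))    # a(x) = x**2 + 1    -> True
```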

The algorithm works for Riccati ODEs where the coefficients are rational
functions in the independent variable `x` with rational number coefficients
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
irrational numbers, symbols or any other kind of expression. The reasons
for this are -

1. When using symbols, different symbols could take the same value, and
   this would change the multiplicities of the poles.

2. An integer degree bound is required to calculate a polynomial solution
   to an auxiliary differential equation, which in turn gives the particular
   solution for the original ODE. If symbols/floats/irrational numbers are
   present, we cannot determine if the expression for the degree bound is an
   integer or not.

Solution
========

With these definitions, we can state a general form for the solution of
the equation. `y(x)` must have the form -

.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i

where `x_1, x_2, ..., x_n` are the poles of `a(x)` (the only possible
non-movable poles of `y(x)`), `\chi_1, \chi_2, ..., \chi_m` are the movable
poles of `y(x)`, and the values of `N, n, r_1, r_2, ..., r_n` can be
determined from `a(x)`. The coefficient vectors `(d_0, d_1, ..., d_N)` and
`(c_{i1}, c_{i2}, ..., c_{i r_i})` can also be determined from `a(x)`, but
each of these vectors has up to 2 possible values, and part of the procedure
is figuring out which of the 2 should be used to get the solution correctly.
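
The movable-pole sum is exactly the logarithmic derivative `p'(x)/p(x)` of the monic polynomial `p(x) = \prod_i (x - \chi_i)`. This is why the algorithm later searches for a polynomial `p`. The sketch below evaluates the general form numerically for the hypothetical example `a(x) = 2/x^2`. It uses the single non-movable pole 0 with `c_{11} = -1`, no polynomial part, and movable poles at the three cube roots of 8 (i.e. `p(x) = x^3 - 8`). A finite-difference derivative then confirms that `y' + y^2 - a(x)` vanishes at a sample point.

```python
import cmath

# Illustrative sketch: 'solution_form' is a hypothetical helper.

def solution_form(x, poles, c, movable, d):
    # y(x) = sum_i sum_j c[i][j-1]/(x - x_i)**j + sum_i 1/(x - chi_i) + sum_i d[i]*x**i
    y = 0
    for xi, ci in zip(poles, c):
        for j, cij in enumerate(ci, start=1):
            y += cij / (x - xi) ** j
    for chi in movable:
        y += 1 / (x - chi)
    for i, di in enumerate(d):
        y += di * x ** i
    return y

# Hypothetical example for a(x) = 2/x**2: the movable poles are the roots of x**3 - 8
movable = [2 * cmath.exp(2j * cmath.pi * k / 3) for k in range(3)]
y = lambda x: solution_form(x, [0], [[-1]], movable, [])
a = lambda x: 2 / x ** 2

x0, h = 0.7, 1e-6
dy = (y(x0 + h) - y(x0 - h)) / (2 * h)
residual = dy + y(x0) ** 2 - a(x0)
print(abs(residual) < 1e-6)
```

Summing over the complex-conjugate movable poles gives a real value, which equals `-1/x + 3x^2/(x^3 - 8)`.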

Implementation
==============

In this implementation, we use ``Poly`` to represent a rational function
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
represent rational functions directly using ``Poly``, we instead represent
a rational function with 2 ``Poly`` objects - one for its numerator and
the other for its denominator.

The code is written to match the steps given in the thesis (Pg 82).

Step 0 : Match the equation -
Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
an error.

Step 1 : Transform the equation to its normal form as explained in the
Background section.

Step 2 : Initialize an empty set of solutions, ``sol``.

Step 3 : If `a(x) = 0`, append `\frac{1}{x - C_1}` to ``sol``.

Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
to ``sol``.
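
Both closed forms are easy to check by hand. `y = 1/(x - C_1)` gives `y' = -y^2`. A constant `y = \pm\sqrt{a}` gives `y' = 0` and `y^2 = a`. The short numeric sketch below confirms this for the hypothetical values `C_1 = 3` and `a = 4`, again using a central finite difference for `y'`.

```python
import math

# Illustrative sketch: 'residual', 'step3' and 'step4' are hypothetical names.

def residual(y, a, x, h=1e-6):
    # y'(x) + y(x)**2 - a, with y' approximated by a central difference
    dy = (y(x + h) - y(x - h)) / (2 * h)
    return dy + y(x) ** 2 - a

C1 = 3.0
step3 = lambda x: 1 / (x - C1)                                  # Step 3: a(x) = 0
step4 = [lambda x, s=s: s * math.sqrt(4.0) for s in (1, -1)]    # Step 4: a(x) = 4

print(abs(residual(step3, 0.0, 1.0)) < 1e-6)
print(all(abs(residual(y, 4.0, 1.0)) < 1e-9 for y in step4))
```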

Step 5 : Find the poles of `a(x)` and their multiplicities. Let the
number of poles be `n`. Also find the valuation of `a(x)` at
`\infty` using ``val_at_inf``.

NOTE: Although the algorithm considers `\infty` as a pole, the thesis
does not state whether it is part of the set of finite poles. Here,
`\infty` is NOT a part of the set of finite poles. If a pole exists at
`\infty`, we use its multiplicity to find the Laurent series of
`a(x)` about `\infty`.
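
The Laurent series about `\infty` is computed by substituting `x \to 1/x` and expanding about 0, as ``inverse_transform_poly`` does. Let `a = S/T` have valuation `v = \deg T - \deg S` at `\infty`. Reversing the coefficient lists then gives `a(1/x) = x^v \tilde S(x) / \tilde T(x)`, where `\tilde S` and `\tilde T` are the reversed polynomials. Below is a minimal sketch of that bookkeeping on plain coefficient lists, with coefficients listed highest degree first. The function names are hypothetical.

```python
# Illustrative sketch: x -> 1/x on coefficient lists (highest degree first).
# 'strip' and 'inverse_transform' are hypothetical helpers, not module API.

def strip(p):
    # Remove leading zero coefficients, keeping at least one entry
    while len(p) > 1 and p[0] == 0:
        p = p[1:]
    return p

def inverse_transform(num, den):
    # Return (num, den) of a(1/x) for a(x) = num/den, both without leading zeros
    v = (len(den) - 1) - (len(num) - 1)    # valuation of a(x) at infinity
    rnum, rden = num[::-1], den[::-1]      # x**deg * P(1/x) reverses the coefficients
    if v >= 0:
        # the extra factor x**v ends up in the numerator
        return strip(rnum + [0] * v), strip(rden)
    # otherwise x**(-v) ends up in the denominator
    return strip(rnum), strip(rden + [0] * (-v))

# a(x) = x**2 + 4*x + 5: a(1/x) = (5*x**2 + 4*x + 1)/x**2
print(inverse_transform([1, 4, 5], [1]))
# a(x) = 2/x**2: a(1/x) = 2*x**2
print(inverse_transform([2], [1, 0, 0]))
```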

Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
``construct_c`` and ``construct_d``. Now, determine all the combinations
(at most ``2**(n + 1)``, since some vectors have only one choice) of
choosing between the choices for each of the `n` c-vectors and the d-vector.

NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printing
mistake. The term `- d_N` must be replaced with `-N d_N`. This is
also noted in the code.
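
The corrected term `-N d_N` can be traced as follows. In Case 4, we match the expansion of `y = \sum_{i=-1}^{N} d_i x^i` about `\infty` against `a(x) = \sum_k a_k x^k`. Among the powers `x^{2N-1}, ..., x^{N-1}` used by the recurrence, the only contribution of `y'` is `N d_N x^{N-1}`, and this term enters the equation for `d_{-1}`. The stdlib sketch below implements the Case 4 recurrence with exact fractions. For illustration, it assumes that `a_{2N}` is a perfect rational square. The name ``d_vector_case_4`` is hypothetical. Take `a(x) = x^4 + 2x^2 + 2x + 1 = (x^2 + 1)' + (x^2 + 1)^2`. The `+` choice recovers `y = x^2 + 1`, whereas the uncorrected `-d_N` term would give `d_{-1} = 1/2`.

```python
import math
from fractions import Fraction

# Illustrative sketch: hypothetical helpers, not part of this module's API.

def exact_sqrt(q):
    # Square root of a non-negative rational, assumed to be a perfect square
    q = Fraction(q)
    n, d = math.isqrt(q.numerator), math.isqrt(q.denominator)
    if n * n != q.numerator or d * d != q.denominator:
        raise ValueError('a_{2N} is not the square of a rational number')
    return Fraction(n, d)

def d_vector_case_4(ser, N, sign=1):
    # ser[k] is the coefficient of x**k in the expansion of a(x) about infinity.
    # Returns {i: d_i} for i = N, N-1, ..., 0, -1.
    d = {N: sign * exact_sqrt(ser.get(2 * N, 0))}
    for s in range(N - 1, -2, -1):
        sm = sum(d[j] * d[N + s - j] for j in range(s + 1, N))
        rhs = ser.get(N + s, 0) - sm
        if s == -1:
            rhs -= N * d[N]       # corrected term: -N*d_N rather than -d_N
        d[s] = rhs / (2 * d[N])
    return d

# a(x) = x**4 + 2*x**2 + 2*x + 1, so N = 2
ser = {4: 1, 3: 0, 2: 2, 1: 2, 0: 1}
print(d_vector_case_4(ser, 2))
print(d_vector_case_4(ser, 2, sign=-1))
```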

Step 7 : For each of the above combinations, do the following.

Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
the polynomial solution we must find for the auxiliary equation, and is
given by `m = d_{-1} - \sum_{i=1}^{n} c_{i1}`.

Step 9 : In ``compute_m_ybar``, also compute ``ybar``, the part of `y(x)`
that is determined by the poles and the chosen c- and d-vectors -

.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i
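
Below is a stdlib sketch of Steps 7 and 8 for the hypothetical example `a(x) = 2/x^2`. The double pole at 0 gives the c-vector choices `[2]` and `[-1]` (Case 1, `c = (1 \pm \sqrt{1 + 4 \cdot 2})/2`). The valuation 2 at `\infty` gives `d_{-1} \in \{2, -1\}` (Case 6, with `s_\infty = 2`). Following ``compute_m_ybar``, `m = d_{-1} - c_{11}`, and only combinations with a non-negative integer `m` are kept.

```python
from itertools import product

# Illustrative sketch: 'compute_m' is a hypothetical helper.

def compute_m(c_choice, d_choice):
    # m = d_{-1} - sum_i c_{i1}; the last entry of the d-vector holds d_{-1}
    return d_choice[-1] - sum(ci[0] for ci in c_choice)

c_vectors = [[[2], [-1]]]   # one list of choices per pole of a(x)
d_vector = [[2], [-1]]      # choices for the d-vector

kept = {}
for *c_choice, d_choice in product(*c_vectors, d_vector):
    m = compute_m(c_choice, d_choice)
    if m >= 0 and m == int(m):
        kept[(c_choice[0][0], d_choice[-1])] = m

print(kept)   # {(2, 2): 0, (-1, 2): 3, (-1, -1): 0}
```

The two `m = 0` combinations give the particular solutions `y = 2/x` and `y = -1/x` directly (`p = 1`), while `m = 3` asks for a cubic `p(x)` in Step 11.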

Step 10 : Check whether `m` is a non-negative integer. There are 2 cases
possible -

    a. `m` is a non-negative integer: We can proceed to Step 11 and solve
    for the coefficients of `p(x)`.

    b. `m` is not a non-negative integer: In this case, we cannot find
    a polynomial solution to the auxiliary equation, and hence, we ignore
    this value of `m`.

Step 11 : Find a monic polynomial solution `p(x)` of degree `m` of the
auxiliary equation (Eq 5.16 in the thesis, Pg 81), solving for its
coefficients using Undetermined Coefficients.
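
Substituting `y = \overline{y} + p'/p` into the normal form shows where the auxiliary equation comes from. The `(p'/p)^2` terms of `y'` and `y^2` cancel, and multiplying by `p` leaves a linear second order ODE for `p`. ``solve_aux_eq`` builds this equation multiplied through by ``dena*deny**2``:

.. math:: y' + y^2 = \overline{y}' + \overline{y}^2 + 2\overline{y}\frac{p'}{p} + \frac{p''}{p} = a \implies p'' + 2\overline{y}p' + \left(\overline{y}' + \overline{y}^2 - a\right)p = 0

As a hypothetical worked example, take `a(x) = 2/x^2` with `\overline{y} = -1/x` and `m = 3`. Then `\overline{y}' + \overline{y}^2 - a = 0`, so the auxiliary equation reduces to `x p'' - 2p' = 0`. For the monic cubic `p = x^3 + C_2 x^2 + C_1 x + C_0`, undetermined coefficients give

.. math:: x p'' - 2p' = -2C_2 x - 2C_1 = 0 \implies C_2 = C_1 = 0

so `p(x) = x^3 + C_0` with `C_0` arbitrary, and `y = -1/x + 3x^2/(x^3 + C_0)`.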
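
Step 12 : For each `p(x)` that exists, append `\overline{y} + \frac{p'(x)}{p(x)}`
to ``sol``.

Step 13 : For each solution in ``sol``, apply an inverse transformation,
so that the solutions of the original equation are found using the
solutions of the equation in its normal form. Before this, redundant
solutions are removed with ``remove_redundant_sols``, and if a general
solution is requested, it is computed from the particular solutions with
``get_gen_sol_from_part_sol``.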
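
Step 13 inverts the normal-form transformation. Solving `y = -b_2 w - b_2'/(2 b_2) - b_1/2` for `w` gives `w = -y/b_2 - b_2'/(2 b_2^2) - b_1/(2 b_2)`, which is what ``riccati_inverse_normal`` computes. The pointwise sketch below checks the round trip for the example `b_2 = x`, `b_1 = 0`, `w = 1/x` at `x = 3`. It uses hypothetical helper names, exact fractions and derivative values supplied by hand.

```python
from fractions import Fraction

# Illustrative sketch: hypothetical pointwise helpers, not module API.

def to_normal(w, b1, b2, db2):
    # y = -b2*w - b2'/(2*b2) - b1/2, evaluated at one point
    return -b2 * w - db2 / (2 * b2) - b1 / 2

def from_normal(y, b1, b2, db2):
    # w = -y/b2 - b2'/(2*b2**2) - b1/(2*b2), the inverse transformation
    return -y / b2 - db2 / (2 * b2 ** 2) - b1 / (2 * b2)

x = Fraction(3)
b1, b2, db2 = Fraction(0), x, Fraction(1)   # b1 = 0, b2 = x, b2' = 1
w = 1 / x
y = to_normal(w, b1, b2, db2)
print(y, from_normal(y, b1, b2, db2))   # -7/6 1/3
```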
"""


from itertools import product
from sympy.core import S
from sympy.core.add import Add
from sympy.core.numbers import oo, Float
from sympy.core.function import count_ops
from sympy.core.relational import Eq
from sympy.core.symbol import symbols, Symbol, Dummy
from sympy.functions import sqrt, exp
from sympy.functions.elementary.complexes import sign
from sympy.integrals.integrals import Integral
from sympy.polys.domains import ZZ
from sympy.polys.polytools import Poly
from sympy.polys.polyroots import roots
from sympy.solvers.solveset import linsolve


def riccati_normal(w, x, b1, b2):
    """
    Given a solution `w(x)` to the equation

    .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2

    and rational function coefficients `b_1(x)` and
    `b_2(x)`, this function transforms the solution to
    give a solution `y(x)` for its corresponding normal
    Riccati ODE

    .. math:: y'(x) + y(x)^2 = a(x)

    using the transformation

    .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
    """
    return -b2*w - b2.diff(x)/(2*b2) - b1/2


def riccati_inverse_normal(y, x, b1, b2, bp=None):
    """
    Inverse transform a solution of the normal
    Riccati ODE to get a solution of the original
    Riccati ODE.
    """
    # bp is the expression which is independent of the solution
    # and hence, it need not be computed again
    if bp is None:
        bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
    # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
    return -y/b2 + bp


def riccati_reduced(eq, f, x):
    """
    Convert a Riccati ODE into its corresponding
    normal Riccati ODE.
    """
    match, funcs = match_riccati(eq, f, x)
    # If equation is not a Riccati ODE, exit
    if not match:
        return False
    # Using the rational functions, find the expression for a(x)
    b0, b1, b2 = funcs
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
    return f(x).diff(x) + f(x)**2 - a


def linsolve_dict(eq, syms):
    """
    Get the output of linsolve as a dict
    """
    # Convert tuple type return value of linsolve
    # to a dictionary for ease of use
    sol = linsolve(eq, syms)
    if not sol:
        return {}
    return {k:v for k, v in zip(syms, list(sol)[0])}


def match_riccati(eq, f, x):
    """
    A function that matches and returns the coefficients
    if an equation is a Riccati ODE

    Parameters
    ==========

    eq: Equation to be matched
    f: Dependent variable
    x: Independent variable

    Returns
    =======

    match: True if equation is a Riccati ODE, False otherwise
    funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
    b0, b1 and b2 are rational functions which match the equation.
    """
    # Group terms based on f(x)
    if isinstance(eq, Eq):
        eq = eq.lhs - eq.rhs
    eq = eq.expand().collect(f(x))
    cf = eq.coeff(f(x).diff(x))

    # There must be an f(x).diff(x) term.
    # eq must be an Add object since we are using the expanded
    # equation and it must have at least 2 terms (b2 != 0)
    if cf != 0 and isinstance(eq, Add):

        # Divide all coefficients by the coefficient of f(x).diff(x)
        # and add the terms again to get the same equation
        eq = Add(*((term/cf).cancel() for term in eq.args)).collect(f(x))

        # Match the equation with the pattern
        b1 = -eq.coeff(f(x))
        b2 = -eq.coeff(f(x)**2)
        b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
        funcs = [b0, b1, b2]

        # Reject coefficients that contain floats or more than one symbol
        if any(len(coeff.atoms(Symbol)) > 1 or len(coeff.atoms(Float)) for coeff in funcs):
            return False, []

        # If b_0(x) contains f(x), it is not a Riccati ODE
        if len(b0.atoms(f)) or not all([b2 != 0, b0.is_rational_function(x),
            b1.is_rational_function(x), b2.is_rational_function(x)]):
            return False, []
        return True, funcs
    return False, []


def val_at_inf(num, den, x):
    # Valuation of a rational function at oo = deg(denom) - deg(numer)
    return den.degree(x) - num.degree(x)


def check_necessary_conds(val_inf, muls):
    """
    The necessary conditions for a rational solution
    to exist are as follows -

    i) Every pole of a(x) must be either a simple pole
    or a multiple pole of even order.

    ii) The valuation of a(x) at infinity must be even
    or be greater than or equal to 2.

    Here, a simple pole is a pole with multiplicity 1
    and a multiple pole is a pole with multiplicity
    greater than 1.
    """
    return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
        all([mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls])


def inverse_transform_poly(num, den, x):
    """
    A function to make the substitution
    x -> 1/x in a rational function that
    is represented using Poly objects for
    numerator and denominator.
    """
    # Declare for reuse
    one = Poly(1, x)
    xpoly = Poly(x, x)

    # pwr = deg(den) - deg(num), the valuation at infinity
    pwr = val_at_inf(num, den, x)
    if pwr >= 0:
        # Denominator has greater (or equal) degree. Substituting x
        # with 1/x would make the extra power go to the numerator
        if num.expr != 0:
            num = num.transform(one, xpoly) * x**pwr
            den = den.transform(one, xpoly)
    else:
        # Numerator has greater degree. Substituting x with
        # 1/x would make the extra power go to the denominator
        num = num.transform(one, xpoly)
        den = den.transform(one, xpoly) * x**(-pwr)
    return num.cancel(den, include=True)


def limit_at_inf(num, den, x):
    """
    Find the limit of a rational function
    at oo
    """
    # pwr = degree(num) - degree(den)
    pwr = -val_at_inf(num, den, x)
    # Numerator has a greater degree than denominator
    # Limit at infinity would depend on the sign of the
    # leading coefficients of numerator and denominator
    if pwr > 0:
        return oo*sign(num.LC()/den.LC())
    # Degree of numerator is equal to that of denominator
    # Limit at infinity is just the ratio of leading coeffs
    elif pwr == 0:
        return num.LC()/den.LC()
    # Degree of numerator is less than that of denominator
    # Limit at infinity is just 0
    else:
        return 0


def construct_c_case_1(num, den, x, pole):
    # Find the coefficient of 1/(x - pole)**2 in the
    # Laurent series expansion of a(x) about pole.
    num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
    r = (num1.subs(x, pole))/(den1.subs(x, pole))

    # If multiplicity is 2, the coefficient to be added
    # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
    if r != -S(1)/4:
        return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
    return [[S(1)/2]]


def construct_c_case_2(num, den, x, pole, mul):
    # Generate the coefficients using the recurrence
    # relation mentioned in (5.14) in the thesis (Pg 80)

    # r_i = mul/2
    ri = mul//2

    # Find the Laurent series coefficients about the pole
    ser = rational_laurent_series(num, den, x, pole, mul, 6)

    # Start with an empty memo to store the coefficients
    # This is for the plus case
    cplus = [0 for i in range(ri)]

    # Base Case
    cplus[ri-1] = sqrt(ser[2*ri])

    # Iterate backwards to find all coefficients
    s = ri - 1
    sm = 0
    for s in range(ri-1, 0, -1):
        sm = 0
        for j in range(s+1, ri):
            sm += cplus[j-1]*cplus[ri+s-j-1]
        if s != 1:
            cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])

    # Memo for the minus case
    cminus = [-c for c in cplus]

    # Find the 0th coefficient in the recurrence
    cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
    cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])

    # Add both the plus and minus cases' coefficients. Always return
    # a list of choices so that construct_c gets a consistent shape.
    if cplus != cminus:
        return [cplus, cminus]
    return [cplus]


def construct_c_case_3():
    # If multiplicity is 1, the coefficient to be added
    # in the c-vector is 1 (no choice)
    return [[1]]


def construct_c(num, den, x, poles, muls):
    """
    Helper function to calculate the coefficients
    in the c-vector for each pole.
    """
    c = []
    for pole, mul in zip(poles, muls):
        c.append([])

        # Case 3
        if mul == 1:
            # Add the coefficients from Case 3
            c[-1].extend(construct_c_case_3())

        # Case 1
        elif mul == 2:
            # Add the coefficients from Case 1
            c[-1].extend(construct_c_case_1(num, den, x, pole))

        # Case 2
        else:
            # Add the coefficients from Case 2
            c[-1].extend(construct_c_case_2(num, den, x, pole, mul))

    return c


def construct_d_case_4(ser, N):
    # Initialize an empty vector. The extra slot at index
    # N + 1 (i.e. index -1) stores d_{-1}.
    dplus = [0 for i in range(N+2)]
    # d_N = sqrt(a_{2*N})
    dplus[N] = sqrt(ser[2*N])

    # Use the recurrence relations to find
    # the value of d_s
    for s in range(N-1, -2, -1):
        sm = 0
        for j in range(s+1, N):
            sm += dplus[j]*dplus[N+s-j]
        if s != -1:
            dplus[s] = (ser[N+s] - sm)/(2*dplus[N])

    # Coefficients for the case of d_N = -sqrt(a_{2*N})
    dminus = [-d for d in dplus]

    # The third equation in Eq 5.15 of the thesis is WRONG!
    # d_N must be replaced with N*d_N in that equation.
    dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
    dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])

    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_5(ser):
    # List to store coefficients for plus case
    dplus = [0, 0]

    # d_0 = sqrt(a_0)
    dplus[0] = sqrt(ser[0])

    # d_(-1) = a_(-1)/(2*d_0)
    dplus[-1] = ser[-1]/(2*dplus[0])

    # Coefficients for the minus case are just the negative
    # of the coefficients for the positive case.
    dminus = [-d for d in dplus]

    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_6(num, den, x):
    # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
    # s_oo = lim x->oo x**2 * a(x)
    s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)

    # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
    if s_inf != -S(1)/4:
        return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
    return [[S(1)/2]]


def construct_d(num, den, x, val_inf):
    """
    Helper function to calculate the coefficients
    in the d-vector based on the valuation of the
    function at oo.
    """
    N = -val_inf//2
    # Multiplicity of oo as a pole
    mul = -val_inf if val_inf < 0 else 0
    ser = rational_laurent_series(num, den, x, oo, mul, 1)

    # Case 4
    if val_inf < 0:
        d = construct_d_case_4(ser, N)

    # Case 5
    elif val_inf == 0:
        d = construct_d_case_5(ser)

    # Case 6
    else:
        d = construct_d_case_6(num, den, x)

    return d


def rational_laurent_series(num, den, x, r, m, n):
    r"""
    The function computes the Laurent series coefficients
    of a rational function.

    Parameters
    ==========

    num: A Poly object that is the numerator of `f(x)`.
    den: A Poly object that is the denominator of `f(x)`.
    x: The variable of expansion of the series.
    r: The point of expansion of the series.
    m: Multiplicity of r if r is a pole of `f(x)`. Should
    be zero otherwise.
    n: Order of the term up to which the series is expanded.

    Returns
    =======

    series: A dictionary that has power of the term as key
    and coefficient of that term as value.

    Below is a basic outline of how the Laurent series of a
    rational function `f(x)` about `x_0` is being calculated -

    1. Substitute `x + x_0` in place of `x`. If `x_0`
    is a pole of `f(x)`, multiply the expression by `x^m`
    where `m` is the multiplicity of `x_0`. Denote the
    resulting expression as g(x). We do this substitution
    so that we can now find the Laurent series of g(x) about
    `x = 0`.

    2. We can then assume that the Laurent series of `g(x)`
    takes the following form -

    .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{k = 0}^{\infty} a_k x^k

    where `a_k` denotes the Laurent series coefficients.

    3. Multiply both sides of the equation by the denominator
    and form a recurrence relation for the coefficients `a_k`.
    """
    one = Poly(1, x, extension=True)

    if r == oo:
        # Series at x = oo is equal to first transforming
        # the function from x -> 1/x and finding the
        # series at x = 0
        num, den = inverse_transform_poly(num, den, x)
        r = S(0)

    if r:
        # For an expansion about a non-zero point, a
        # transformation from x -> x + r must be made
        num = num.transform(Poly(x + r, x, extension=True), one)
        den = den.transform(Poly(x + r, x, extension=True), one)

    # Remove the pole from the denominator if the series
    # expansion is about one of the poles
    num, den = (num*x**m).cancel(den, include=True)

    # Equate coefficients for the first terms (base case)
    maxdegree = 1 + max(num.degree(), den.degree())
    syms = symbols(f'a:{maxdegree}', cls=Dummy)
    diff = num - den * Poly(syms[::-1], x)
    coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
    (coeffs, ) = linsolve(coeff_diffs, syms)

    # Use the recursion relation for the rest
    recursion = den.all_coeffs()[::-1]
    div, rec_rhs = recursion[0], recursion[1:]
    series = list(coeffs)
    while len(series) < n:
        next_coeff = Add(*(c*series[-1-k] for k, c in enumerate(rec_rhs))) / div
        series.append(-next_coeff)
    series = {m - i: val for i, val in enumerate(series)}
    return series


def compute_m_ybar(x, poles, choice, N):
    """
    Helper function to calculate -

    1. m - The degree bound for the polynomial
    solution that must be found for the auxiliary
    differential equation.

    2. ybar - Part of the solution which can be
    computed using the poles, c and d vectors.
    """
    ybar = 0
    # m = d_{-1} - sum_i c_{i1}, where d_{-1} is the
    # last entry of the chosen d-vector (choice[-1])
    m = Poly(choice[-1][-1], x, extension=True)

    # Calculate the first (nested) summation for ybar
    # as given in Step 9 of the Thesis (Pg 82)
    for i in range(len(poles)):
        for j in range(len(choice[i])):
            ybar += choice[i][j]/(x - poles[i])**(j+1)
        m -= Poly(choice[i][0], x, extension=True)

    # Calculate the second summation for ybar
    for i in range(N+1):
        ybar += choice[-1][i]*x**i
    return (m.expr, ybar)


def solve_aux_eq(numa, dena, numy, deny, x, m):
    """
    Helper function to find a polynomial solution
    of degree m for the auxiliary differential
    equation.
    """
    # Assume that the solution is the monic polynomial
    # p(x) = x**m + C_0*x**(m-1) + ... + C_{m-1}
    psyms = symbols(f'C0:{m}', cls=Dummy)
    K = ZZ[psyms]
    psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)

    # Eq (5.16) in Thesis - Pg 81
    auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
    if m >= 1:
        px = psol.diff(x)
        auxeq += px*(2*numy*deny*dena)
    if m >= 2:
        auxeq += px.diff(x)*(deny**2*dena)
    if m != 0:
        # m is a positive integer. Find the constant terms using undetermined coefficients
        return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
    else:
        # m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation
        return S(1), auxeq, auxeq == 0


def remove_redundant_sols(sol1, sol2, x):
    """
    Helper function to remove redundant
    solutions to the differential equation.

    If the two solutions are redundant, the one
    that should be discarded is returned;
    otherwise None is returned.
    """
    # If y1 and y2 are redundant solutions, there is
    # some value of the arbitrary constant for which
    # they will be equal

    syms1 = sol1.atoms(Symbol, Dummy)
    syms2 = sol2.atoms(Symbol, Dummy)
    num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
    num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
    # Cross multiply
    e = num1*den2 - den1*num2
    # Check if there are any constants
    syms = list(e.atoms(Symbol, Dummy))
    if len(syms):
        # Find values of constants for which solutions are equal
        redn = linsolve(e.all_coeffs(), syms)
        if len(redn):
            # Keep the general solution: return the particular one for removal
            if len(syms1) > len(syms2):
                return sol2
            # If both have the same number of symbols, count_ops
            # is used to pick one of them for removal
            elif len(syms1) == len(syms2):
                return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2
            else:
                return sol1


def get_gen_sol_from_part_sol(part_sols, a, x):
    """
    Helper function which computes the general
    solution for a Riccati ODE from its particular
    solutions.

    There are 3 cases to find the general solution
    from the particular solutions for a Riccati ODE
    depending on the number of particular solution(s)
    we have - 1, 2 or 3.

    For more information, see Section 6 of
    "Methods of Solution of the Riccati Differential Equation"
    by D. R. Haaheim and F. M. Stein
    """

    # If no particular solutions are found, a general
    # solution cannot be found
    if len(part_sols) == 0:
        return []

    # In case of a single particular solution, the general
    # solution can be found by using the substitution
    # y = y1 + 1/z, which turns the normal form into the
    # linear ODE z' = 2*y1*z + 1 for z.
    elif len(part_sols) == 1:
        y1 = part_sols[0]
        i = exp(Integral(2*y1, x))
        # General solution of z' - 2*y1*z = 1, including an
        # arbitrary constant of integration
        C1 = Dummy('C1')
        z = i * (Integral(1/i, x) + C1)
        z = z.doit()
        if a == 0 or z == 0:
            return y1
        return y1 + 1/z

    # In case of 2 particular solutions, the general solution
    # can be found by solving a separable equation. This is
    # the most common case, i.e. most Riccati ODEs have 2
    # rational particular solutions.
    elif len(part_sols) == 2:
        y1, y2 = part_sols
        # One of them already has a constant
        if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
            u = exp(Integral(y2 - y1, x)).doit()
        # Introduce a constant
        else:
            C1 = Dummy('C1')
            u = C1*exp(Integral(y2 - y1, x)).doit()
        if u == 1:
            return y2
        return (y2*u - y1)/(u - 1)

    # In case of 3 particular solutions, a closed form
    # of the general solution can be obtained directly
    else:
        y1, y2, y3 = part_sols[:3]
        C1 = Dummy('C1')
        return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)


def solve_riccati(fx, x, b0, b1, b2, gensol=False):
    """
    The main function that gives particular/general
    solutions to Riccati ODEs that have at least 1
    rational particular solution.
    """
    # Step 1 : Convert to Normal Form
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    a_t = a.together()
    num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
    num, den = num.cancel(den, include=True)

    # Step 2
    presol = []

    # Step 3 : a(x) is 0
    if num == 0:
        presol.append(1/(x + Dummy('C1')))

    # Step 4 : a(x) is a non-zero constant
    elif x not in num.free_symbols.union(den.free_symbols):
        presol.extend([sqrt(a), -sqrt(a)])

    # Step 5 : Find poles and valuation at infinity
    poles = roots(den, x)
    poles, muls = list(poles.keys()), list(poles.values())
    val_inf = val_at_inf(num, den, x)

    if len(poles):
        # Check necessary conditions (outlined in the module docstring)
        if not check_necessary_conds(val_inf, muls):
            raise ValueError("Rational Solution doesn't exist")

        # Step 6
        # Construct c-vectors for each singular point
        c = construct_c(num, den, x, poles, muls)

        # Construct the d-vector from the behaviour of a(x) at oo
        d = construct_d(num, den, x, val_inf)

        # Step 7 : Iterate over all possible combinations and return solutions
        # For each possible combination, generate an array of 0's and 1's
        # where 0 means pick 1st choice and 1 means pick the second choice.

        # NOTE: We could exit from the loop if we find 3 particular solutions,
        # but it is not implemented here as -
        #   a. Finding 3 particular solutions is very rare. Most of the time,
        #      only 2 particular solutions are found.
        #   b. In case we exit after finding 3 particular solutions, it might
        #      happen that 1 or 2 of them are redundant solutions. So, instead of
        #      spending some more time in computing the particular solutions,
        #      we will end up computing the general solution from a single
        #      particular solution which is usually slower than computing the
        #      general solution from 2 or 3 particular solutions.
        c.append(d)
        choices = product(*c)
        for choice in choices:
            m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
            numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
            # Step 10 : Check if m is a non-negative integer
            if m.is_nonnegative == True and m.is_integer == True:

                # Step 11 : Find polynomial solutions of degree m for the auxiliary equation
                psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)

                # Step 12 : If valid polynomial solution exists, append solution.
                if exists:
                    # m == 0 case
                    if psol == 1 and coeffs == 0:
                        # p(x) = 1, so p'(x)/p(x) term need not be added
                        presol.append(ybar)
                    # m is a positive integer and there are valid coefficients
                    elif len(coeffs):
                        # Substitute the valid coefficients to get p(x)
                        psol = psol.xreplace(coeffs)
                        # y(x) = ybar(x) + p'(x)/p(x)
                        presol.append(ybar + psol.diff(x)/psol)

    # Remove redundant solutions from the list of existing solutions
    remove = set()
    for i in range(len(presol)):
        for j in range(i+1, len(presol)):
            rem = remove_redundant_sols(presol[i], presol[j], x)
            if rem is not None:
                remove.add(rem)
    sols = [y for y in presol if y not in remove]

    # Step 13 : Inverse transform the solutions of the equation in normal form
    bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)

    # If general solution is required, compute it from the particular solutions
    if gensol:
        sols = [get_gen_sol_from_part_sol(sols, a, x)]

    # Inverse transform the particular solutions
    presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
    return presol