r"""
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
a function which gives all rational particular solutions to first order
Riccati ODEs. A general first order Riccati ODE is given by -

.. math:: w' = b_0(x) + b_1(x)w + b_2(x)w^2

where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation
is a Bernoulli ODE. The algorithm presented below can find rational
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
or prove that no rational solution exists for the equation.

Background
==========

A Riccati equation can be transformed to its normal form

.. math:: y' + y^2 = a(x)

using the transformation

.. math:: y = -b_2(x) w - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}

where `a(x)` is given by

.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2

Thus, we can develop an algorithm to solve the Riccati equation
in its normal form, which would in turn give us the solution for
the original Riccati equation.

Algorithm
=========

The algorithm implemented here is presented in the Ph.D. thesis
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
by N. Thieu Vo. The entire thesis can be found here -
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf

We have only implemented the Rational Riccati solver (Algorithm 11,
Pg 78-82 in Thesis). Before we proceed towards the implementation
of the algorithm, a few definitions are needed -

1. Valuation of a Rational Function at `\infty`:
    The valuation of a rational function `p(x)` at `\infty` is equal
    to the degree of the denominator of `p(x)` minus the degree of its
    numerator.

    NOTE: A general definition of valuation of a rational function
    at any value of `x` can be found in Pg 63 of the thesis, but
    is not of any interest for this algorithm.

2. Zeros and Poles of a Rational Function:
    Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
    of `x`. Then -

    a. The Zeros of `a(x)` are the roots of `S(x)`.
    b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
    can also be a pole of `a(x)`. We say that `a(x)` has a pole at
    `\infty` if `a(\frac{1}{x})` has a pole at 0.

Every pole is associated with an order that is equal to the multiplicity
of its appearance as a root of `T(x)`. A pole is called a simple pole if
it has an order 1. Similarly, a pole is called a multiple pole if it has
an order `\ge` 2.

Necessary Conditions
====================

For a Riccati equation in its normal form,

.. math:: y' + y^2 = a(x)

we can define

a. A pole is called a movable pole if it is a pole of `y(x)` and is not
a pole of `a(x)`.
b. Similarly, a pole is called a non-movable pole if it is a pole of both
`y(x)` and `a(x)`.

Then, the algorithm states that a rational solution exists only if -

a. Every pole of `a(x)` must be either a simple pole or a multiple pole
of even order.
b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2.

This algorithm finds all possible rational solutions for the Riccati ODE.
If no rational solutions are found, it means that no rational solutions
exist.

The algorithm works for Riccati ODEs where the coefficients are rational
functions in the independent variable `x` with rational number coefficients,
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
irrational numbers, symbols or any other kind of expression. The reasons
for this are -

1. When using symbols, different symbols could take the same value, which
would change the multiplicities of the poles.

2. An integer degree bound is required to calculate a polynomial solution
to an auxiliary differential equation, which in turn gives the particular
solution for the original ODE. If symbols/floats/irrational numbers are
present, we cannot determine if the expression for the degree bound is an
integer or not.

Solution
========

With these definitions, we can state a general form for the solution of
the equation. `y(x)` must have the form -

.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i

where `x_1, x_2, ..., x_n` are non-movable poles of `a(x)`,
`\chi_1, \chi_2, ..., \chi_m` are movable poles of `y(x)`, and the values
of `N, n, r_1, r_2, ..., r_n` can be determined from `a(x)`. The
coefficient vectors `(d_0, d_1, ..., d_N)` and `(c_{i1}, c_{i2}, ..., c_{i r_i})`
can be determined from `a(x)`. We will have 2 choices for each of these
vectors, and part of the procedure is figuring out which of the 2 should be
used to get the solution correctly.

Implementation
==============

In this implementation, we use ``Poly`` to represent a rational function
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
represent rational functions directly using ``Poly``, we instead represent
a rational function with 2 ``Poly`` objects - one for its numerator and
the other for its denominator.

The code is written to match the steps given in the thesis (Pg 82).

Step 0 : Match the equation -
Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
an error.

Step 1 : Transform the equation to its normal form as explained in the
Background section.

Step 2 : Initialize an empty set of solutions, ``sol``.
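
Step 1 rests on the identity behind the normal form. As a minimal sketch
(plain SymPy, with `b_0, b_1, b_2` and `w` left as arbitrary functions of
`x`; the helper ``normal_form_a`` is hypothetical and only spells out the
expression for `a(x)` from the Background section), one can substitute the
transformation into `y' + y^2 - a(x)`, eliminate `w'` using the original
Riccati equation, and check that the residual vanishes:

```python
from sympy import Function, simplify, symbols

x = symbols('x')
w = Function('w')(x)
b0, b1, b2 = [Function(name)(x) for name in ('b0', 'b1', 'b2')]


def normal_form_a(b0, b1, b2, x):
    # Hypothetical helper: a(x) = t**2/4 - t'/2 - b0*b2,
    # where t = b2'/b2 + b1
    t = b2.diff(x)/b2 + b1
    return t**2/4 - t.diff(x)/2 - b0*b2


# Transformation from the Riccati ODE to its normal form
y = -b2*w - b2.diff(x)/(2*b2) - b1/2

# Residual of y' + y**2 = a(x), with w' replaced using
# the original equation w' = b0 + b1*w + b2*w**2
residual = (y.diff(x) + y**2 - normal_form_a(b0, b1, b2, x)).subs(
    w.diff(x), b0 + b1*w + b2*w**2)

print(simplify(residual))  # 0
```

The check is purely symbolic, so it confirms the formulas for `y(x)` and
`a(x)` without assuming anything about the coefficients beyond `b_2 \ne 0`.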

Step 3 : If `a(x) = 0`, append `\frac{1}{x - C_1}` to ``sol``.

Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
to ``sol``.

Step 5 : Find the poles and their multiplicities of `a(x)`. Let
the number of poles be `n`. Also find the valuation of `a(x)` at
`\infty` using ``val_at_inf``.

NOTE: Although the algorithm considers `\infty` as a pole, it does
not state whether it is part of the set of finite poles. Here, `\infty`
is NOT a part of the set of finite poles. If a pole exists at
`\infty`, we use its multiplicity to find the Laurent series of
`a(x)` about `\infty`.

Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)``
combinations of choosing between 2 choices for each of the `n` c-vectors
and 1 d-vector.

NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printing
mistake. The term `- d_N` must be replaced with `-N d_N`. This
correction is also noted in the code.

Step 7 : For each of these above combinations, do

Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
the polynomial solution we must find for the auxiliary equation.

Step 9 : In ``compute_m_ybar``, compute ``ybar`` as well, where ``ybar`` is
one part of `y(x)` -

.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i

Step 10 : If `m` is a non-negative integer -

Step 11 : Find a polynomial solution of degree `m` for the auxiliary equation.

There are 2 cases possible -

    a. `m` is a non-negative integer: We can solve for the coefficients
    in `p(x)` using Undetermined Coefficients.

    b. `m` is not a non-negative integer: In this case, we cannot find
    a polynomial solution to the auxiliary equation, and hence, we ignore
    this value of `m`.
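
Step 12 relies on the fact that `\overline{y} + \frac{p'}{p}` solves the
normal form exactly when `p \ne 0` solves the auxiliary equation
`p'' + 2\overline{y}p' + (\overline{y}' + \overline{y}^2 - a)p = 0`, which is
what ``solve_aux_eq`` assembles after clearing the denominators of `a` and
`\overline{y}`. A minimal symbolic sketch of that identity, treating
`\overline{y}`, `p` and `a` as arbitrary functions of `x` (an illustration
only, not code used by this module):

```python
from sympy import Function, simplify, symbols

x = symbols('x')
p = Function('p')(x)
ybar = Function('ybar')(x)
a = Function('a')(x)

# Candidate solution of the normal form y' + y**2 = a(x)
y = ybar + p.diff(x)/p

# Left-hand side of the auxiliary equation for p(x)
aux = p.diff(x, 2) + 2*ybar*p.diff(x) + (ybar.diff(x) + ybar**2 - a)*p

# y' + y**2 - a coincides with aux/p, so y solves the normal form
# exactly when p(x) solves the auxiliary equation
print(simplify(y.diff(x) + y**2 - a - aux/p))  # 0
```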

Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}`
to ``sol``.

Step 13 : For each solution in ``sol``, apply an inverse transformation,
so that the solutions of the original equation are found using the
solutions of the equation in its normal form.
"""


from itertools import product
from sympy.core import S
from sympy.core.add import Add
from sympy.core.numbers import oo, Float
from sympy.core.function import count_ops
from sympy.core.relational import Eq
from sympy.core.symbol import symbols, Symbol, Dummy
from sympy.functions import sqrt, exp
from sympy.functions.elementary.complexes import sign
from sympy.integrals.integrals import Integral
from sympy.polys.domains import ZZ
from sympy.polys.polytools import Poly
from sympy.polys.polyroots import roots
from sympy.solvers.solveset import linsolve


def riccati_normal(w, x, b1, b2):
    """
    Given a solution `w(x)` to the equation

    .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2

    and rational function coefficients `b_1(x)` and
    `b_2(x)`, this function transforms the solution to
    give a solution `y(x)` for its corresponding normal
    Riccati ODE

    .. math:: y'(x) + y(x)^2 = a(x)

    using the transformation

    .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
    """
    return -b2*w - b2.diff(x)/(2*b2) - b1/2


def riccati_inverse_normal(y, x, b1, b2, bp=None):
    """
    Inverse-transform a solution of the normal Riccati
    ODE to obtain a solution of the original Riccati ODE.
    """
    # bp is the expression which is independent of the solution
    # and hence, it need not be computed again
    if bp is None:
        bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
    # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
    return -y/b2 + bp


def riccati_reduced(eq, f, x):
    """
    Convert a Riccati ODE into its corresponding
    normal Riccati ODE.
    """
    match, funcs = match_riccati(eq, f, x)
    # If equation is not a Riccati ODE, exit
    if not match:
        return False
    # Using the rational functions, find the expression for a(x)
    b0, b1, b2 = funcs
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
    return f(x).diff(x) + f(x)**2 - a


def linsolve_dict(eq, syms):
    """
    Get the output of linsolve as a dict
    """
    # Convert tuple type return value of linsolve
    # to a dictionary for ease of use
    sol = linsolve(eq, syms)
    if not sol:
        return {}
    return dict(zip(syms, list(sol)[0]))


def match_riccati(eq, f, x):
    """
    A function that matches and returns the coefficients
    if an equation is a Riccati ODE

    Parameters
    ==========

    eq: Equation to be matched
    f: Dependent variable
    x: Independent variable

    Returns
    =======

    match: True if equation is a Riccati ODE, False otherwise
    funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
    b0, b1 and b2 are rational functions which match the equation.
    """
    # Group terms based on f(x)
    if isinstance(eq, Eq):
        eq = eq.lhs - eq.rhs
    eq = eq.expand().collect(f(x))
    cf = eq.coeff(f(x).diff(x))

    # There must be an f(x).diff(x) term.
    # eq must be an Add object since we are using the expanded
    # equation and it must have at least 2 terms (b2 != 0)
    if cf != 0 and isinstance(eq, Add):

        # Divide all coefficients by the coefficient of f(x).diff(x)
        # and add the terms again to get the same equation
        eq = Add(*((term/cf).cancel() for term in eq.args)).collect(f(x))

        # Match the equation with the pattern
        b1 = -eq.coeff(f(x))
        b2 = -eq.coeff(f(x)**2)
        b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
        funcs = [b0, b1, b2]

        # Reject coefficients that contain more than one symbol, or floats
        if any(len(b.atoms(Symbol)) > 1 or len(b.atoms(Float)) for b in funcs):
            return False, []

        # If b_0(x) contains f(x), it is not a Riccati ODE
        if len(b0.atoms(f)) or not all([b2 != 0, b0.is_rational_function(x),
                b1.is_rational_function(x), b2.is_rational_function(x)]):
            return False, []
        return True, funcs
    return False, []


def val_at_inf(num, den, x):
    # Valuation of a rational function at oo = deg(denom) - deg(numer)
    return den.degree(x) - num.degree(x)


def check_necessary_conds(val_inf, muls):
    """
    The necessary conditions for a rational solution
    to exist are as follows -

    i) Every pole of a(x) must be either a simple pole
    or a multiple pole of even order.

    ii) The valuation of a(x) at infinity must be even
    or be greater than or equal to 2.

    Here, a simple pole is a pole with multiplicity 1
    and a multiple pole is a pole with multiplicity
    greater than 1.
    """
    return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
        all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls)


def inverse_transform_poly(num, den, x):
    """
    A function to make the substitution
    x -> 1/x in a rational function that
    is represented using Poly objects for
    numerator and denominator.
    """
    # Declare for reuse
    one = Poly(1, x)
    xpoly = Poly(x, x)

    # pwr = deg(den) - deg(num), i.e. the valuation at infinity
    pwr = val_at_inf(num, den, x)
    if pwr >= 0:
        # Denominator has greater (or equal) degree. Substituting x with
        # 1/x would make the extra power go to the numerator
        if num.expr != 0:
            num = num.transform(one, xpoly) * x**pwr
            den = den.transform(one, xpoly)
    else:
        # Numerator has greater degree. Substituting x with
        # 1/x would make the extra power go to the denominator
        num = num.transform(one, xpoly)
        den = den.transform(one, xpoly) * x**(-pwr)
    return num.cancel(den, include=True)


def limit_at_inf(num, den, x):
    """
    Find the limit of a rational function
    at oo
    """
    # pwr = degree(num) - degree(den)
    pwr = -val_at_inf(num, den, x)
    # Numerator has a greater degree than denominator
    # Limit at infinity would depend on the sign of the
    # leading coefficients of numerator and denominator
    if pwr > 0:
        return oo*sign(num.LC()/den.LC())
    # Degree of numerator is equal to that of denominator
    # Limit at infinity is just the ratio of leading coeffs
    elif pwr == 0:
        return num.LC()/den.LC()
    # Degree of numerator is less than that of denominator
    # Limit at infinity is just 0
    else:
        return 0


def construct_c_case_1(num, den, x, pole):
    # Find the coefficient of 1/(x - pole)**2 in the
    # Laurent series expansion of a(x) about pole.
    num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
    r = (num1.subs(x, pole))/(den1.subs(x, pole))

    # If multiplicity is 2, the coefficient to be added
    # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
    if r != -S(1)/4:
        return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
    return [[S(1)/2]]


def construct_c_case_2(num, den, x, pole, mul):
    # Generate the coefficients using the recurrence
    # relation mentioned in (5.14) in the thesis (Pg 80)

    # r_i = mul/2
    ri = mul//2

    # Find the Laurent series coefficients about the pole
    ser = rational_laurent_series(num, den, x, pole, mul, 6)

    # Start with an empty memo to store the coefficients
    # This is for the plus case
    cplus = [0]*ri

    # Base Case
    cplus[ri-1] = sqrt(ser[2*ri])

    # Iterate backwards to find all coefficients
    s = ri - 1
    sm = 0
    for s in range(ri-1, 0, -1):
        sm = 0
        for j in range(s+1, ri):
            sm += cplus[j-1]*cplus[ri+s-j-1]
        if s != 1:
            cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])

    # Memo for the minus case
    cminus = [-c for c in cplus]

    # Find the 0th coefficient in the recurrence
    cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
    cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])

    # Add both the plus and minus cases' coefficients. Each choice
    # is a list of coefficients, as in the other cases.
    if cplus != cminus:
        return [cplus, cminus]
    return [cplus]


def construct_c_case_3():
    # If multiplicity is 1, the coefficient to be added
    # in the c-vector is 1 (no choice)
    return [[1]]


def construct_c(num, den, x, poles, muls):
    """
    Helper function to calculate the coefficients
    in the c-vector for each pole.
    """
    c = []
    for pole, mul in zip(poles, muls):
        c.append([])

        # Case 3
        if mul == 1:
            # Add the coefficients from Case 3
            c[-1].extend(construct_c_case_3())

        # Case 1
        elif mul == 2:
            # Add the coefficients from Case 1
            c[-1].extend(construct_c_case_1(num, den, x, pole))

        # Case 2
        else:
            # Add the coefficients from Case 2
            c[-1].extend(construct_c_case_2(num, den, x, pole, mul))

    return c


def construct_d_case_4(ser, N):
    # Initialize an empty vector
    dplus = [0]*(N+2)
    # d_N = sqrt(a_{2*N})
    dplus[N] = sqrt(ser[2*N])

    # Use the recurrence relations to find
    # the value of d_s
    for s in range(N-1, -2, -1):
        sm = 0
        for j in range(s+1, N):
            sm += dplus[j]*dplus[N+s-j]
        if s != -1:
            dplus[s] = (ser[N+s] - sm)/(2*dplus[N])

    # Coefficients for the case of d_N = -sqrt(a_{2*N})
    dminus = [-d for d in dplus]

    # The third equation in Eq 5.15 of the thesis is WRONG!
    # d_N must be replaced with N*d_N in that equation.
    dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
    dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])

    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_5(ser):
    # List to store coefficients for plus case
    dplus = [0, 0]

    # d_0 = sqrt(a_0)
    dplus[0] = sqrt(ser[0])

    # d_(-1) = a_(-1)/(2*d_0)
    dplus[-1] = ser[-1]/(2*dplus[0])

    # Coefficients for the minus case are just the negative
    # of the coefficients for the positive case.
    dminus = [-d for d in dplus]

    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_6(num, den, x):
    # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
    # s_oo = lim x->oo x**2 * a(x)
    s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)

    # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
    if s_inf != -S(1)/4:
        return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
    return [[S(1)/2]]


def construct_d(num, den, x, val_inf):
    """
    Helper function to calculate the coefficients
    in the d-vector based on the valuation of the
    function at oo.
    """
    N = -val_inf//2
    # Multiplicity of oo as a pole
    mul = -val_inf if val_inf < 0 else 0
    ser = rational_laurent_series(num, den, x, oo, mul, 1)

    # Case 4
    if val_inf < 0:
        d = construct_d_case_4(ser, N)

    # Case 5
    elif val_inf == 0:
        d = construct_d_case_5(ser)

    # Case 6
    else:
        d = construct_d_case_6(num, den, x)

    return d


def rational_laurent_series(num, den, x, r, m, n):
    r"""
    The function computes the Laurent series coefficients
    of a rational function.

    Parameters
    ==========

    num: A Poly object that is the numerator of `f(x)`.
    den: A Poly object that is the denominator of `f(x)`.
    x: The variable of expansion of the series.
    r: The point of expansion of the series.
    m: Multiplicity of r if r is a pole of `f(x)`. Should
    be zero otherwise.
    n: Order of the term up to which the series is expanded.

    Returns
    =======

    series: A dictionary that has power of the term as key
    and coefficient of that term as value.

    Below is a basic outline of how the Laurent series of a
    rational function `f(x)` about `x_0` is being calculated -

    1. Substitute `x + x_0` in place of `x`. If `x_0`
    is a pole of `f(x)`, multiply the expression by `x^m`
    where `m` is the multiplicity of `x_0`. Denote
    the resulting expression as `g(x)`. We do this substitution
    so that we can now find the Laurent series of `g(x)` about
    `x = 0`.

    2. We can then assume that the Laurent series of `g(x)`
    takes the following form -

    .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{k = 0}^{\infty} a_k x^k

    where `a_k` denotes the Laurent series coefficients.

    3. Multiply both sides of the equation by the denominator
    and form a recurrence relation for the coefficients `a_k`.
    """
    one = Poly(1, x, extension=True)

    if r == oo:
        # Series at x = oo is equal to first transforming
        # the function from x -> 1/x and finding the
        # series at x = 0
        num, den = inverse_transform_poly(num, den, x)
        r = S(0)

    if r:
        # For an expansion about a non-zero point, a
        # transformation from x -> x + r must be made
        num = num.transform(Poly(x + r, x, extension=True), one)
        den = den.transform(Poly(x + r, x, extension=True), one)

    # Remove the pole from the denominator if the series
    # expansion is about one of the poles
    num, den = (num*x**m).cancel(den, include=True)

    # Equate coefficients for the first terms (base case)
    maxdegree = 1 + max(num.degree(), den.degree())
    syms = symbols(f'a:{maxdegree}', cls=Dummy)
    diff = num - den * Poly(syms[::-1], x)
    coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
    (coeffs, ) = linsolve(coeff_diffs, syms)

    # Use the recursion relation for the rest
    recursion = den.all_coeffs()[::-1]
    div, rec_rhs = recursion[0], recursion[1:]
    series = list(coeffs)
    while len(series) < n:
        next_coeff = Add(*(c*series[-1-k] for k, c in enumerate(rec_rhs))) / div
        series.append(-next_coeff)
    series = {m - i: val for i, val in enumerate(series)}
    return series


def compute_m_ybar(x, poles, choice, N):
    """
    Helper function to calculate -

    1. m - The degree bound for the polynomial
    solution that must be found for the auxiliary
    differential equation.

    2. ybar - Part of the solution which can be
    computed using the poles, c and d vectors.
    """
    ybar = 0
    m = Poly(choice[-1][-1], x, extension=True)

    # Calculate the first (nested) summation for ybar
    # as given in Step 9 of the Thesis (Pg 82)
    for i in range(len(poles)):
        for j in range(len(choice[i])):
            ybar += choice[i][j]/(x - poles[i])**(j+1)
        m -= Poly(choice[i][0], x, extension=True)

    # Calculate the second summation for ybar
    for i in range(N+1):
        ybar += choice[-1][i]*x**i
    return (m.expr, ybar)


def solve_aux_eq(numa, dena, numy, deny, x, m):
    """
    Helper function to find a polynomial solution
    of degree m for the auxiliary differential
    equation.
    """
    # Assume that the solution is a monic polynomial of degree m
    # p(x) = x**m + (undetermined coefficients)*(lower powers of x)
    psyms = symbols(f'C0:{m}', cls=Dummy)
    K = ZZ[psyms]
    psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)

    # Eq (5.16) in Thesis - Pg 81
    auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
    if m >= 1:
        px = psol.diff(x)
        auxeq += px*(2*numy*deny*dena)
    if m >= 2:
        auxeq += px.diff(x)*(deny**2*dena)
    if m != 0:
        # m is a positive integer. Find the constant terms using undetermined coefficients
        return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
    else:
        # m == 0. Check if 1 (x**0) is a solution to the auxiliary equation
        return S(1), auxeq, auxeq == 0


def remove_redundant_sols(sol1, sol2, x):
    """
    Helper function to remove redundant
    solutions to the differential equation.
    """
    # If y1 and y2 are redundant solutions, there is
    # some value of the arbitrary constant for which
    # they will be equal. The solution returned by this
    # function is the one that should be discarded.

    syms1 = sol1.atoms(Symbol, Dummy)
    syms2 = sol2.atoms(Symbol, Dummy)
    num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
    num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
    # Cross multiply
    e = num1*den2 - den1*num2
    # Check if there are any constants
    syms = list(e.atoms(Symbol, Dummy))
    if len(syms):
        # Find values of constants for which solutions are equal
        redn = linsolve(e.all_coeffs(), syms)
        if len(redn):
            # Keep the general solution and discard the particular one
            if len(syms1) > len(syms2):
                return sol2
            # If both have the same number of constants, keep the
            # less complex solution and discard the other one
            elif len(syms1) == len(syms2):
                return sol1 if count_ops(sol1) >= count_ops(sol2) else sol2
            else:
                return sol1


def get_gen_sol_from_part_sol(part_sols, a, x):
    """
    Helper function which computes the general
    solution for a Riccati ODE from its particular
    solutions.

    There are 3 cases to find the general solution
    from the particular solutions for a Riccati ODE
    depending on the number of particular solution(s)
    we have - 1, 2 or 3.

    For more information, see Section 6 of
    "Methods of Solution of the Riccati Differential Equation"
    by D. R. Haaheim and F. M. Stein
    """

    # If no particular solutions are found, a general
    # solution cannot be found
    if len(part_sols) == 0:
        return []

    # In case of a single particular solution, the general
    # solution can be found by using the substitution
    # y = y1 + 1/z and solving a Bernoulli ODE to find z.
    elif len(part_sols) == 1:
        y1 = part_sols[0]
        i = exp(Integral(2*y1, x))
        z = i * Integral(a/i, x)
        z = z.doit()
        if a == 0 or z == 0:
            return y1
        return y1 + 1/z

    # In case of 2 particular solutions, the general solution
    # can be found by solving a separable equation. This is
    # the most common case, i.e. most Riccati ODEs have 2
    # rational particular solutions.
    elif len(part_sols) == 2:
        y1, y2 = part_sols
        # One of them already has a constant
        if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
            u = exp(Integral(y2 - y1, x)).doit()
        # Introduce a constant
        else:
            C1 = Dummy('C1')
            u = C1*exp(Integral(y2 - y1, x)).doit()
        if u == 1:
            return y2
        return (y2*u - y1)/(u - 1)

    # In case of 3 particular solutions, a closed form
    # of the general solution can be obtained directly
    else:
        y1, y2, y3 = part_sols[:3]
        C1 = Dummy('C1')
        return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)


def solve_riccati(fx, x, b0, b1, b2, gensol=False):
    """
    The main function that gives particular/general
    solutions to Riccati ODEs that have at least 1
    rational particular solution.
    """
    # Step 1 : Convert to Normal Form
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    a_t = a.together()
    num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
    num, den = num.cancel(den, include=True)

    # Step 2
    presol = []

    # Step 3 : a(x) is 0
    if num == 0:
        presol.append(1/(x + Dummy('C1')))

    # Step 4 : a(x) is a non-zero constant
    elif x not in num.free_symbols.union(den.free_symbols):
        presol.extend([sqrt(a), -sqrt(a)])

    # Step 5 : Find poles and valuation at infinity
    poles = roots(den, x)
    poles, muls = list(poles.keys()), list(poles.values())
    val_inf = val_at_inf(num, den, x)

    if len(poles):
        # Check necessary conditions (outlined in the module docstring)
        if not check_necessary_conds(val_inf, muls):
            raise ValueError("Rational Solution doesn't exist")

        # Step 6
        # Construct c-vectors for each singular point
        c = construct_c(num, den, x, poles, muls)

        # Construct the d-vector from the valuation at infinity
        d = construct_d(num, den, x, val_inf)

        # Step 7 : Iterate over all possible combinations and return solutions
        # itertools.product picks one of the (at most 2) choices from each
        # c-vector and from the d-vector, giving every possible combination.

        # NOTE: We could exit from the loop if we find 3 particular solutions,
        # but it is not implemented here as -
        #   a. Finding 3 particular solutions is very rare. Most of the time,
        #   only 2 particular solutions are found.
        #   b. In case we exit after finding 3 particular solutions, it might
        #   happen that 1 or 2 of them are redundant solutions. So, instead of
        #   spending some more time in computing the particular solutions,
        #   we will end up computing the general solution from a single
        #   particular solution which is usually slower than computing the
        #   general solution from 2 or 3 particular solutions.
        c.append(d)
        choices = product(*c)
        for choice in choices:
            m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
            numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
            # Step 10 : Check if a valid solution exists. If yes, also check
            # if m is a non-negative integer
            if m.is_nonnegative == True and m.is_integer == True:

                # Step 11 : Find polynomial solutions of degree m for the auxiliary equation
                psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)

                # Step 12 : If valid polynomial solution exists, append solution.
                if exists:
                    # m == 0 case
                    if psol == 1 and coeffs == 0:
                        # p(x) = 1, so p'(x)/p(x) term need not be added
                        presol.append(ybar)
                    # m is a positive integer and there are valid coefficients
                    elif len(coeffs):
                        # Substitute the valid coefficients to get p(x)
                        psol = psol.xreplace(coeffs)
                        # y(x) = ybar(x) + p'(x)/p(x)
                        presol.append(ybar + psol.diff(x)/psol)

    # Remove redundant solutions from the list of existing solutions
    remove = set()
    for i in range(len(presol)):
        for j in range(i+1, len(presol)):
            rem = remove_redundant_sols(presol[i], presol[j], x)
            if rem is not None:
                remove.add(rem)
    sols = [sol for sol in presol if sol not in remove]

    # Step 13 : Inverse transform the solutions of the equation in normal form
    bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)

    # If general solution is required, compute it from the particular solutions
    if gensol:
        sols = [get_gen_sol_from_part_sol(sols, a, x)]

    # Inverse transform the particular solutions
    presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
    return presol