1import numpy as np
2from .base import OdeSolver, DenseOutput
3from .common import (validate_max_step, validate_tol, select_initial_step,
4                     norm, warn_extraneous, validate_first_step)
5from . import dop853_coefficients
6
# Safety factor: step sizes predicted from the asymptotic behaviour of the
# error estimate are multiplied by this value before being used.
SAFETY = 0.9

# Bounds on the multiplier applied to the step size in a single update.
MIN_FACTOR = 0.2  # Smallest multiplier (strongest allowed decrease).
MAX_FACTOR = 10  # Largest multiplier (strongest allowed increase).
12
13
def rk_step(fun, t, y, f, h, A, B, C, K):
    """Perform a single Runge-Kutta step.

    This function computes the prediction of an explicit Runge-Kutta method
    and stores every stage derivative in ``K``. The caller can then combine
    the rows of ``K`` to estimate the error of an embedded, less accurate
    method.

    Notation for Butcher tableau is as in [1]_.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system, called as ``fun(t, y)``.
    t : float
        Current time.
    y : ndarray, shape (n,)
        Current state.
    f : ndarray, shape (n,)
        Current value of the derivative, i.e., ``fun(t, y)``.
    h : float
        Step to use (signed).
    A : ndarray, shape (n_stages, m) with m >= n_stages - 1
        Coefficients for combining previous RK stages to compute the next
        stage. Only the part strictly below the main diagonal is read, so
        for explicit methods the remaining entries are irrelevant.
    B : ndarray, shape (n_stages,)
        Coefficients for combining RK stages for computing the final
        prediction.
    C : ndarray, shape (n_stages,)
        Coefficients for incrementing time for consecutive RK stages.
        ``C[0]`` is not read: the first stage is always `f`, evaluated at `t`.
    K : ndarray, shape (n_stages + 1, n)
        Storage array for the RK stages, modified in place. Rows
        ``0 .. n_stages - 1`` receive the stage derivatives (row 0 is `f`),
        and the last row receives ``f_new``, the derivative at the new point.

    Returns
    -------
    y_new : ndarray, shape (n,)
        Solution at t + h computed with a higher accuracy.
    f_new : ndarray, shape (n,)
        Derivative ``fun(t + h, y_new)``.

    References
    ----------
    .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential
           Equations I: Nonstiff Problems", Sec. II.4.
    """
    K[0] = f
    for s, (a, c) in enumerate(zip(A[1:], C[1:]), start=1):
        # Stage s only combines the stages already computed (rows 0..s-1).
        dy = np.dot(K[:s].T, a[:s]) * h
        K[s] = fun(t + c * h, y + dy)

    # K[:-1] holds the n_stages stage derivatives; the last row is filled
    # below with the derivative at the new point.
    y_new = y + h * np.dot(K[:-1].T, B)
    f_new = fun(t + h, y_new)

    K[-1] = f_new

    return y_new, f_new
72
73
class RungeKutta(OdeSolver):
    """Base class for explicit Runge-Kutta methods.

    Concrete methods supply their coefficients (``C``, ``A``, ``B``, ``E``,
    ``P``) and the integers ``order``, ``error_estimator_order`` and
    ``n_stages`` as class attributes. This class implements the adaptive
    step-size control shared by all of them and a polynomial dense output
    built from the stage derivatives of the last accepted step.
    """
    # Stage time increments as fractions of the step (Butcher tableau ``c``).
    C: np.ndarray = NotImplemented
    # Stage coupling coefficients (Butcher tableau ``a``); ``rk_step`` reads
    # only the part strictly below the main diagonal.
    A: np.ndarray = NotImplemented
    # Weights combining the stages into the propagated solution.
    B: np.ndarray = NotImplemented
    # Weights applied to every row of ``K`` (the stages plus the derivative at
    # the new point) to form the local error estimate.
    E: np.ndarray = NotImplemented
    # Matrix mapping ``K`` to the coefficients of the dense output polynomial.
    P: np.ndarray = NotImplemented
    # Order of the method; not read by the code in this class.
    order: int = NotImplemented
    # Order of the embedded error estimator; it determines the exponent used
    # when rescaling the step size.
    error_estimator_order: int = NotImplemented
    # Number of stages per step.
    n_stages: int = NotImplemented

    def __init__(self, fun, t0, y0, t_bound, max_step=np.inf,
                 rtol=1e-3, atol=1e-6, vectorized=False,
                 first_step=None, **extraneous):
        """Set up the solver state and choose the initial step size.

        Evaluates ``fun`` at the initial point to obtain ``self.f``. When
        `first_step` is None the initial step is chosen by
        ``select_initial_step``; otherwise `first_step` is checked by
        ``validate_first_step``. Unknown keyword arguments are only reported
        through ``warn_extraneous``.
        """
        warn_extraneous(extraneous)
        super().__init__(fun, t0, y0, t_bound, vectorized,
                         support_complex=True)
        self.y_old = None
        self.max_step = validate_max_step(max_step)
        self.rtol, self.atol = validate_tol(rtol, atol, self.n)
        # Derivative at the current point; passed to rk_step as the first
        # stage of the next step.
        self.f = self.fun(self.t, self.y)
        if first_step is None:
            self.h_abs = select_initial_step(
                self.fun, self.t, self.y, self.f, self.direction,
                self.error_estimator_order, self.rtol, self.atol)
        else:
            self.h_abs = validate_first_step(first_step, t0, t_bound)
        # Stage storage: n_stages stage derivatives followed by the
        # derivative at the end of the step (rk_step fills all rows).
        self.K = np.empty((self.n_stages + 1, self.n), dtype=self.y.dtype)
        # Step-size updates multiply by error_norm ** error_exponent.
        self.error_exponent = -1 / (self.error_estimator_order + 1)
        # Signed size of the last accepted step; set by _step_impl.
        self.h_previous = None

    def _estimate_error(self, K, h):
        """Return the per-component local error estimate ``h * K.T @ E``."""
        return np.dot(K.T, self.E) * h

    def _estimate_error_norm(self, K, h, scale):
        """Return ``norm`` (from ``.common``) of the error estimate / `scale`.

        ``_step_impl`` accepts the step when this value is below 1.
        """
        return norm(self._estimate_error(K, h) / scale)

    def _step_impl(self):
        """Take one step, shrinking the step size until it is accepted.

        Returns ``(True, None)`` after an accepted step, or
        ``(False, self.TOO_SMALL_STEP)`` once the step size falls below the
        minimum allowed at the current time. On success ``t``, ``y``, ``f``,
        ``y_old``, ``h_previous`` and ``h_abs`` (the proposed size of the
        next step) are updated.
        """
        t = self.t
        y = self.y

        max_step = self.max_step
        rtol = self.rtol
        atol = self.atol

        # Smallest usable step: ten times the gap between t and the next
        # representable float in the direction of integration.
        min_step = 10 * np.abs(np.nextafter(t, self.direction * np.inf) - t)

        # Start from the stored step size, clipped to max_step or raised to
        # min_step.
        if self.h_abs > max_step:
            h_abs = max_step
        elif self.h_abs < min_step:
            h_abs = min_step
        else:
            h_abs = self.h_abs

        step_accepted = False
        step_rejected = False

        while not step_accepted:
            if h_abs < min_step:
                return False, self.TOO_SMALL_STEP

            h = h_abs * self.direction
            t_new = t + h

            # Never step past t_bound.
            if self.direction * (t_new - self.t_bound) > 0:
                t_new = self.t_bound

            # Recompute the step from the (possibly clipped) end point.
            h = t_new - t
            h_abs = np.abs(h)

            y_new, f_new = rk_step(self.fun, t, y, self.f, h, self.A,
                                   self.B, self.C, self.K)
            # Per-component tolerance: atol + rtol * max(|y|, |y_new|).
            scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
            error_norm = self._estimate_error_norm(self.K, h, scale)

            if error_norm < 1:
                # Accept: rescale the next step by at most MAX_FACTOR.
                if error_norm == 0:
                    factor = MAX_FACTOR
                else:
                    factor = min(MAX_FACTOR,
                                 SAFETY * error_norm ** self.error_exponent)

                # After a rejection within this call, do not grow the step.
                if step_rejected:
                    factor = min(1, factor)

                h_abs *= factor

                step_accepted = True
            else:
                # Reject: shrink the step (by a factor of at least MIN_FACTOR)
                # and retry.
                h_abs *= max(MIN_FACTOR,
                             SAFETY * error_norm ** self.error_exponent)
                step_rejected = True

        self.h_previous = h
        self.y_old = y

        self.t = t_new
        self.y = y_new

        # Proposed size of the next step.
        self.h_abs = h_abs
        self.f = f_new

        return True, None

    def _dense_output_impl(self):
        """Build the dense output of the last step from ``Q = K.T @ P``."""
        Q = self.K.T.dot(self.P)
        return RkDenseOutput(self.t_old, self.t, self.y_old, Q)
181
182
class RK23(RungeKutta):
    """Explicit Runge-Kutta method of order 3(2).

    This uses the Bogacki-Shampine pair of formulas [1]_. The error is
    controlled assuming accuracy of the second-order method, but steps are
    taken using the third-order accurate formula (local extrapolation is
    done). A cubic Hermite polynomial is used for the dense output.

    Can be applied in the complex domain.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system. The calling signature is ``fun(t, y)``.
        Here ``t`` is a scalar and there are two options for ndarray ``y``.
        It can either have shape (n,), then ``fun`` must return array_like with
        shape (n,). Or alternatively it can have shape (n, k), then ``fun``
        must return array_like with shape (n, k), i.e. each column
        corresponds to a single column in ``y``. The choice between the two
        options is determined by `vectorized` argument (see below).
    t0 : float
        Initial time.
    y0 : array_like, shape (n,)
        Initial state.
    t_bound : float
        Boundary time - the integration won't continue beyond it. It also
        determines the direction of the integration.
    first_step : float or None, optional
        Initial step size. Default is ``None`` which means that the algorithm
        should choose.
    max_step : float, optional
        Maximum allowed step size. Default is np.inf, i.e., the step size is not
        bounded and determined solely by the solver.
    rtol, atol : float or array_like, optional
        Relative and absolute tolerances. The solver keeps the local error
        estimates less than ``atol + rtol * abs(y)``. Here, `rtol` controls a
        relative accuracy (number of correct digits). But if a component of `y`
        is approximately below `atol`, the error only needs to fall within
        the same `atol` threshold, and the number of correct digits is not
        guaranteed. If components of y have different scales, it might be
        beneficial to set different `atol` values for different components by
        passing array_like with shape (n,) for `atol`. Default values are
        1e-3 for `rtol` and 1e-6 for `atol`.
    vectorized : bool, optional
        Whether `fun` is implemented in a vectorized fashion. Default is False.

    Attributes
    ----------
    n : int
        Number of equations.
    status : string
        Current status of the solver: 'running', 'finished' or 'failed'.
    t_bound : float
        Boundary time.
    direction : float
        Integration direction: +1 or -1.
    t : float
        Current time.
    y : ndarray
        Current state.
    t_old : float
        Previous time. None if no steps were made yet.
    step_size : float
        Size of the last successful step. None if no steps were made yet.
    nfev : int
        Number of evaluations of the system's right-hand side.
    njev : int
        Number of evaluations of the Jacobian. Is always 0 for this solver
        as it does not use the Jacobian.
    nlu : int
        Number of LU decompositions. Is always 0 for this solver.

    References
    ----------
    .. [1] P. Bogacki, L.F. Shampine, "A 3(2) Pair of Runge-Kutta Formulas",
           Appl. Math. Lett. Vol. 2, No. 4. pp. 321-325, 1989.
    """
    order = 3
    error_estimator_order = 2
    n_stages = 3
    # Butcher tableau of the Bogacki-Shampine pair.
    C = np.array([0, 1/2, 3/4])
    A = np.array([
        [0, 0, 0],
        [1/2, 0, 0],
        [0, 3/4, 0]
    ])
    B = np.array([2/9, 1/3, 4/9])
    # Error weights for all n_stages + 1 rows of K; the last row is the
    # derivative at the new point.
    E = np.array([5/72, -1/12, -1/9, 1/8])
    # Dense output coefficients: one row per row of K, one column per power
    # of the normalized time in the interpolating polynomial.
    P = np.array([[1, -4 / 3, 5 / 9],
                  [0, 1, -2/3],
                  [0, 4/3, -8/9],
                  [0, -1, 1]])
274
275
class RK45(RungeKutta):
    """Explicit Runge-Kutta method of order 5(4).

    This uses the Dormand-Prince pair of formulas [1]_. The error is controlled
    assuming accuracy of the fourth-order method, but steps are taken
    using the fifth-order accurate formula (local extrapolation is done).
    A quartic interpolation polynomial is used for the dense output [2]_.

    Can be applied in the complex domain.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system. The calling signature is ``fun(t, y)``.
        Here ``t`` is a scalar, and there are two options for the ndarray ``y``:
        It can either have shape (n,); then ``fun`` must return array_like with
        shape (n,). Alternatively it can have shape (n, k); then ``fun``
        must return an array_like with shape (n, k), i.e., each column
        corresponds to a single column in ``y``. The choice between the two
        options is determined by `vectorized` argument (see below).
    t0 : float
        Initial time.
    y0 : array_like, shape (n,)
        Initial state.
    t_bound : float
        Boundary time - the integration won't continue beyond it. It also
        determines the direction of the integration.
    first_step : float or None, optional
        Initial step size. Default is ``None`` which means that the algorithm
        should choose.
    max_step : float, optional
        Maximum allowed step size. Default is np.inf, i.e., the step size is not
        bounded and determined solely by the solver.
    rtol, atol : float or array_like, optional
        Relative and absolute tolerances. The solver keeps the local error
        estimates less than ``atol + rtol * abs(y)``. Here `rtol` controls a
        relative accuracy (number of correct digits). But if a component of `y`
        is approximately below `atol`, the error only needs to fall within
        the same `atol` threshold, and the number of correct digits is not
        guaranteed. If components of y have different scales, it might be
        beneficial to set different `atol` values for different components by
        passing array_like with shape (n,) for `atol`. Default values are
        1e-3 for `rtol` and 1e-6 for `atol`.
    vectorized : bool, optional
        Whether `fun` is implemented in a vectorized fashion. Default is False.

    Attributes
    ----------
    n : int
        Number of equations.
    status : string
        Current status of the solver: 'running', 'finished' or 'failed'.
    t_bound : float
        Boundary time.
    direction : float
        Integration direction: +1 or -1.
    t : float
        Current time.
    y : ndarray
        Current state.
    t_old : float
        Previous time. None if no steps were made yet.
    step_size : float
        Size of the last successful step. None if no steps were made yet.
    nfev : int
        Number of evaluations of the system's right-hand side.
    njev : int
        Number of evaluations of the Jacobian. Is always 0 for this solver
        as it does not use the Jacobian.
    nlu : int
        Number of LU decompositions. Is always 0 for this solver.

    References
    ----------
    .. [1] J. R. Dormand, P. J. Prince, "A family of embedded Runge-Kutta
           formulae", Journal of Computational and Applied Mathematics, Vol. 6,
           No. 1, pp. 19-26, 1980.
    .. [2] L. W. Shampine, "Some Practical Runge-Kutta Formulas", Mathematics
           of Computation, Vol. 46, No. 173, pp. 135-150, 1986.
    """
    order = 5
    error_estimator_order = 4
    n_stages = 6
    # Butcher tableau of the Dormand-Prince pair. A has n_stages - 1 columns:
    # rk_step reads only the part strictly below the main diagonal.
    C = np.array([0, 1/5, 3/10, 4/5, 8/9, 1])
    A = np.array([
        [0, 0, 0, 0, 0],
        [1/5, 0, 0, 0, 0],
        [3/40, 9/40, 0, 0, 0],
        [44/45, -56/15, 32/9, 0, 0],
        [19372/6561, -25360/2187, 64448/6561, -212/729, 0],
        [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656]
    ])
    B = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84])
    # Error weights for all n_stages + 1 rows of K; the last row is the
    # derivative at the new point.
    E = np.array([-71/57600, 0, 71/16695, -71/1920, 17253/339200, -22/525,
                  1/40])
    # Dense output coefficients (one row per row of K, one column per power
    # of the normalized time).
    # Corresponds to the optimum value of c_6 from [2]_.
    P = np.array([
        [1, -8048581381/2820520608, 8663915743/2820520608,
         -12715105075/11282082432],
        [0, 0, 0, 0],
        [0, 131558114200/32700410799, -68118460800/10900136933,
         87487479700/32700410799],
        [0, -1754552775/470086768, 14199869525/1410260304,
         -10690763975/1880347072],
        [0, 127303824393/49829197408, -318862633887/49829197408,
         701980252875 / 199316789632],
        [0, -282668133/205662961, 2019193451/616988883, -1453857185/822651844],
        [0, 40617522/29380423, -110615467/29380423, 69997945/29380423]])
383
384
class DOP853(RungeKutta):
    """Explicit Runge-Kutta method of order 8.

    This is a Python implementation of "DOP853" algorithm originally written
    in Fortran [1]_, [2]_. Note that this is not a literal translation, but
    the algorithmic core and coefficients are the same.

    Can be applied in the complex domain.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system. The calling signature is ``fun(t, y)``.
        Here, ``t`` is a scalar, and there are two options for the ndarray ``y``:
        It can either have shape (n,); then ``fun`` must return array_like with
        shape (n,). Alternatively it can have shape (n, k); then ``fun``
        must return an array_like with shape (n, k), i.e. each column
        corresponds to a single column in ``y``. The choice between the two
        options is determined by `vectorized` argument (see below).
    t0 : float
        Initial time.
    y0 : array_like, shape (n,)
        Initial state.
    t_bound : float
        Boundary time - the integration won't continue beyond it. It also
        determines the direction of the integration.
    first_step : float or None, optional
        Initial step size. Default is ``None`` which means that the algorithm
        should choose.
    max_step : float, optional
        Maximum allowed step size. Default is np.inf, i.e. the step size is not
        bounded and determined solely by the solver.
    rtol, atol : float or array_like, optional
        Relative and absolute tolerances. The solver keeps the local error
        estimates less than ``atol + rtol * abs(y)``. Here `rtol` controls a
        relative accuracy (number of correct digits). But if a component of `y`
        is approximately below `atol`, the error only needs to fall within
        the same `atol` threshold, and the number of correct digits is not
        guaranteed. If components of y have different scales, it might be
        beneficial to set different `atol` values for different components by
        passing array_like with shape (n,) for `atol`. Default values are
        1e-3 for `rtol` and 1e-6 for `atol`.
    vectorized : bool, optional
        Whether `fun` is implemented in a vectorized fashion. Default is False.

    Attributes
    ----------
    n : int
        Number of equations.
    status : string
        Current status of the solver: 'running', 'finished' or 'failed'.
    t_bound : float
        Boundary time.
    direction : float
        Integration direction: +1 or -1.
    t : float
        Current time.
    y : ndarray
        Current state.
    t_old : float
        Previous time. None if no steps were made yet.
    step_size : float
        Size of the last successful step. None if no steps were made yet.
    nfev : int
        Number of evaluations of the system's right-hand side.
    njev : int
        Number of evaluations of the Jacobian. Is always 0 for this solver
        as it does not use the Jacobian.
    nlu : int
        Number of LU decompositions. Is always 0 for this solver.

    References
    ----------
    .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential
           Equations I: Nonstiff Problems", Sec. II.
    .. [2] `Page with original Fortran code of DOP853
            <http://www.unige.ch/~hairer/software.html>`_.
    """
    n_stages = dop853_coefficients.N_STAGES
    order = 8
    error_estimator_order = 7
    # Coefficients of the propagated method: the leading n_stages rows and
    # columns of the full tableau.
    A = dop853_coefficients.A[:n_stages, :n_stages]
    B = dop853_coefficients.B
    C = dop853_coefficients.C[:n_stages]
    # Two embedded error estimates, combined in ``_estimate_error_norm``.
    E3 = dop853_coefficients.E3
    E5 = dop853_coefficients.E5
    # Coefficients of the higher interpolation terms of the dense output.
    D = dop853_coefficients.D

    # Extra stages evaluated only when dense output is requested. Row
    # ``n_stages`` of the full tableau is skipped because that row of ``K``
    # holds the derivative at the end of the step.
    A_EXTRA = dop853_coefficients.A[n_stages + 1:]
    C_EXTRA = dop853_coefficients.C[n_stages + 1:]

    def __init__(self, fun, t0, y0, t_bound, max_step=np.inf,
                 rtol=1e-3, atol=1e-6, vectorized=False,
                 first_step=None, **extraneous):
        super().__init__(fun, t0, y0, t_bound, max_step, rtol, atol,
                         vectorized, first_step, **extraneous)
        # Storage for the regular stages plus the extra dense-output stages.
        self.K_extended = np.empty((dop853_coefficients.N_STAGES_EXTENDED,
                                    self.n), dtype=self.y.dtype)
        # ``self.K`` is a view of the leading rows, so rk_step writes the
        # regular stages directly into ``K_extended``.
        self.K = self.K_extended[:self.n_stages + 1]

    def _estimate_error(self, K, h):  # Left for testing purposes.
        """Return a per-component error estimate.

        The E5 estimate is scaled component-wise by
        ``|err5| / hypot(|err5|, 0.1 * |err3|)`` (by 1 where that denominator
        is zero). ``_step_impl`` uses ``_estimate_error_norm`` instead.
        """
        err5 = np.dot(K.T, self.E5)
        err3 = np.dot(K.T, self.E3)
        denom = np.hypot(np.abs(err5), 0.1 * np.abs(err3))
        correction_factor = np.ones_like(err5)
        mask = denom > 0
        correction_factor[mask] = np.abs(err5[mask]) / denom[mask]
        return h * err5 * correction_factor

    def _estimate_error_norm(self, K, h, scale):
        """Return the combined, scaled error norm used for step control.

        With ``e5 = K.T @ E5 / scale`` and ``e3 = K.T @ E3 / scale`` this is
        ``|h| * ||e5||**2 / sqrt((||e5||**2 + 0.01 * ||e3||**2) * n)``, or 0.0
        when both estimates vanish.
        """
        err5 = np.dot(K.T, self.E5) / scale
        err3 = np.dot(K.T, self.E3) / scale
        err5_norm_2 = np.linalg.norm(err5)**2
        err3_norm_2 = np.linalg.norm(err3)**2
        if err5_norm_2 == 0 and err3_norm_2 == 0:
            return 0.0
        denom = err5_norm_2 + 0.01 * err3_norm_2
        return np.abs(h) * err5_norm_2 / np.sqrt(denom * len(scale))

    def _dense_output_impl(self):
        """Evaluate the extra stages and build the DOP853 dense output.

        The extra stages are written into rows ``n_stages + 1`` onward of
        ``K_extended`` (each one calls ``fun``).
        """
        K = self.K_extended
        h = self.h_previous
        for s, (a, c) in enumerate(zip(self.A_EXTRA, self.C_EXTRA),
                                   start=self.n_stages + 1):
            dy = np.dot(K[:s].T, a[:s]) * h
            K[s] = self.fun(self.t_old + c * h, self.y_old + dy)

        F = np.empty((dop853_coefficients.INTERPOLATOR_POWER, self.n),
                     dtype=self.y_old.dtype)

        f_old = K[0]
        delta_y = self.y - self.y_old

        # The first three rows depend only on the end points and their
        # derivatives; the remaining rows use all stages through ``D``.
        F[0] = delta_y
        F[1] = h * f_old - delta_y
        F[2] = 2 * delta_y - h * (self.f + f_old)
        F[3:] = h * np.dot(self.D, K)

        return Dop853DenseOutput(self.t_old, self.t, self.y_old, F)
524
525
class RkDenseOutput(DenseOutput):
    """Polynomial interpolant for one step of a ``RungeKutta`` solver.

    The value at time ``t`` is ``y_old + h * Q @ [x, x**2, ..., x**(order+1)]``
    with ``x = (t - t_old) / h``.
    """

    def __init__(self, t_old, t, y_old, Q):
        super().__init__(t_old, t)
        self.y_old = y_old
        self.Q = Q
        self.h = t - t_old
        # One column of Q per power of x, so the degree is the column
        # count minus one.
        self.order = Q.shape[1] - 1

    def _call_impl(self, t):
        x = (t - self.t_old) / self.h
        # Rows of ``powers`` are x, x**2, ..., x**(order + 1); for an array
        # of times each row carries one column per requested time.
        stacked = np.broadcast_to(x, (self.order + 1,) + np.shape(x))
        powers = np.cumprod(stacked, axis=0)
        values = self.h * np.dot(self.Q, powers)
        values += self.y_old[:, None] if values.ndim == 2 else self.y_old
        return values
549
550
class Dop853DenseOutput(DenseOutput):
    """Interpolant for one DOP853 step, evaluated by nested multiplication.

    The rows of ``F`` are consumed from last to first; after each row is
    added, the running sum is multiplied alternately by ``x`` and ``1 - x``,
    with ``x = (t - t_old) / h``. Finally ``y_old`` is added.
    """

    def __init__(self, t_old, t, y_old, F):
        super().__init__(t_old, t)
        self.y_old = y_old
        self.F = F
        self.h = t - t_old

    def _call_impl(self, t):
        x = (t - self.t_old) / self.h

        if t.ndim != 0:
            # One row per requested time; x becomes a column so that it
            # broadcasts across the state components.
            x = x[:, None]
            acc = np.zeros((len(x), len(self.y_old)), dtype=self.y_old.dtype)
        else:
            acc = np.zeros_like(self.y_old)

        one_minus_x = 1 - x
        for k, coeff in enumerate(self.F[::-1]):
            acc += coeff
            acc *= x if k % 2 == 0 else one_minus_x
        acc += self.y_old

        return acc.T
576