1# -*- coding: utf-8 -*-
2from __future__ import (absolute_import, division, print_function)
3
4import numpy as np
5
6from .plotting import plot_result, plot_phase_plane, info_vlines
7from .util import import_
8
9CubicSpline = import_('scipy.interpolate', 'CubicSpline')
10interp1d = import_('scipy.interpolate', 'interp1d')
11
12
class Result(object):
    """ Container for the outcome of an integration of an ODE system.

    Bundles the values handed to the constructor: the independent variable
    (``xout``), the dependent variables (``yout``), the parameters
    (``params``), the ``info`` dictionary and the ODE system (``odesys``).

    For backward compatibility an instance can be unpacked like the 3-tuple
    ``(xout, yout, info)``.

    Parameters
    ----------
    xout : array
    yout : array
    params : array
    info : dict
    odesys : object
        The system which produced the result. The methods of this class
        read attributes such as ``names``, ``param_names``, ``f_cb`` and
        ``numpy`` from it.

    """

    def __init__(self, xout, yout, params, info, odesys):
        self.xout = xout
        self.yout = yout
        self.params = params
        self.info = info
        self.odesys = odesys

    def copy(self):
        """ Returns a new Result holding copies of the data.

        ``xout``, ``yout``, ``params`` and ``info`` are duplicated through
        their own ``copy()`` methods, the ``odesys`` reference is shared.
        """
        return Result(self.xout.copy(), self.yout.copy(), self.params.copy(),
                      self.info.copy(), self.odesys)

    def __len__(self):
        """ Length of the legacy ``(xout, yout, info)`` tuple interface. """
        return 3

    def __getitem__(self, key):
        """ Legacy tuple access: 0 -> xout, 1 -> yout, 2 -> info.

        Key ``3`` raises ``StopIteration`` (which ends tuple unpacking),
        any other key raises ``KeyError``.
        """
        if key == 0:
            return self.xout
        elif key == 1:
            return self.yout
        elif key == 2:
            return self.info
        elif key == 3:
            raise StopIteration
        else:
            raise KeyError("Invalid key: %s (for backward compatibility reasons)." % str(key))

    def named_param(self, param_name):
        """ Returns the parameter whose name (in ``odesys.param_names``) is ``param_name``. """
        return self.params[self.odesys.param_names.index(param_name)]

    def named_dep(self, name):
        """ Returns the column of ``yout`` whose name (in ``odesys.names``) is ``name``. """
        return self.yout[..., self.odesys.names.index(name)]

    def between(self, lower, upper, xdata=None, ydata=None):
        """ Get results inside span for independent variable

        Parameters
        ----------
        lower : float
            Lower bound (exclusive).
        upper : float
            Upper bound (exclusive).
        xdata : array, optional
            Defaults to ``self.xout``.
        ydata : array, optional
            Defaults to ``self.yout``.

        Returns
        -------
        Length 2 tuple with the selected parts of ``xdata`` and ``ydata``.

        """
        if xdata is None:
            xdata = self.xout
        if ydata is None:
            ydata = self.yout
        select_u = xdata < upper
        xtmp, ytmp = xdata[..., select_u], ydata[..., select_u, :]
        select_l = xtmp > lower
        return xtmp[..., select_l], ytmp[..., select_l, :]

    def at(self, x, use_deriv=False, xdata=None, ydata=None, linear=False):
        """ Returns interpolated result at a given time and an interpolation error-estimate

        By default interpolation is performed using cubic splines.

        Parameters
        ----------
        x : array_like or float
        use_deriv : bool
            Calculate derivatives at spline knots for enhanced accuracy.
        xdata : array
            Defaults to ``self.xout``.
        ydata : array
            Defaults to ``self.yout``. If it has ``units`` and ``magnitude``
            attributes the interpolation is carried out on the magnitude and
            the results are multiplied with the units.
        linear : bool
            Will use (cheaper) linear interpolation. Useful when x is an array.
            Error estimate will be ``None`` in this case.

        Returns
        -------
        interpolated_y : array
        error_estim_y : array

        When ``linear`` is false and ``x`` has a length, a list with one
        ``(interpolated_y, error_estim_y)`` tuple per element of ``x`` is
        returned instead.

        Raises
        ------
        ValueError
            When a scalar ``x`` is outside the range spanned by ``xdata``
            (non-linear interpolation).

        """
        if xdata is None:
            xdata = self.xout
        if ydata is None:
            ydata = self.yout
        # Keep the (possibly unit carrying) input around: the per-element
        # recursion below must see the units too, otherwise its results
        # would silently come back without them.
        ydata_in = ydata
        yunit = getattr(ydata, 'units', 1)
        ydata = getattr(ydata, 'magnitude', ydata)

        if linear:
            return interp1d(xdata, ydata, axis=0)(x)*yunit, None

        try:
            len(x)
        except TypeError:
            pass  # scalar x, handled below
        else:
            return [self.at(_, use_deriv, xdata, ydata_in, linear) for _ in x]

        if x == xdata[0]:
            res = ydata[0, :]
            err = res*0
        elif x == xdata[-1]:
            res = ydata[-1, :]
            err = res*0
        else:
            # Index of first knot strictly larger than x; argmax gives 0 both
            # when x is below the first knot and when no knot is larger.
            idx = np.argmax(xdata > x)
            if idx == 0:
                raise ValueError("x outside bounds")
            # Local spline through (up to) four knots surrounding x.
            idx_l = max(0, idx - 2)
            idx_u = min(xdata.size, idx_l + 4)
            slc = slice(idx_l, idx_u)
            res_cub = CubicSpline(xdata[slc], ydata[slc, :])(x)
            x0, x1 = xdata[idx - 1], xdata[idx]
            y0, y1 = ydata[idx - 1, :], ydata[idx, :]
            xspan, yspan = x1 - x0, y1 - y0
            if use_deriv:
                # Cubic Hermite polynomial in the scaled variable
                # s = (x - x0)/xspan:
                #   y = a + b*s + c*s**2 + d*s**3
                #   dyds = b + 2*c*s + 3*d*s**2
                # The derivatives must be evaluated at the knots (x0, y0) and
                # (x1, y1), not at the interpolation point.
                y0p, y1p = [np.asarray(self.odesys.f_cb(xk, yk, self.params))*xspan
                            for xk, yk in ((x0, y0), (x1, y1))]
                lsx = (x - x0)/xspan
                d = y0p + y1p + 2*y0 - 2*y1
                c = -2*y0p - y1p - 3*y0 + 3*y1
                b, a = y0p, y0
                res_poly = a + b*lsx + c*lsx**2 + d*lsx**3
                res, err = res_poly, np.abs(res_poly - res_cub)
            else:
                # Error estimate: deviation of spline from linear interpolation.
                avgx, avgy = .5*(x0 + x1), .5*(y0 + y1)
                res_lin = avgy + yspan/xspan*(x - avgx)
                res, err = res_cub, np.abs(res_cub - np.asarray(res_lin))

        return res*yunit, err*yunit

    def _internal(self, key, override=None):
        """ Returns ``override`` if given, otherwise ``info['internal_' + key]``. """
        if override is None:
            return self.info['internal_' + key]
        else:
            return override

    def _internals(self):
        """ Returns the tuple (internal xout, internal yout, internal params).

        When ``odesys.append_iv`` is true the last ``odesys.ny`` entries of
        the internal parameters are dropped.
        """
        return (
            self._internal('xout'),
            self._internal('yout'),
            self._internal('params')[:-self.odesys.ny if self.odesys.append_iv else None]
        )

    def stiffness(self, xyp=None, eigenvals_cb=None):
        """ Running stiffness ratio from last integration.

        Calculate stiffness ratio, i.e. the ratio between the largest and
        smallest absolute eigenvalue of the jacobian matrix. The user may
        supply their own routine for calculating the eigenvalues, or they
        will be calculated from the SVD (singular value decomposition).
        Note that calculating the SVD for any but the smallest Jacobians may
        prove to be prohibitively expensive.

        Parameters
        ----------
        xyp : length 3 tuple (default: None)
            internal_xout, internal_yout, internal_params, taken
            from last integration if not specified.
        eigenvals_cb : callback (optional)
            Signature (x, y, p) (internal variables), when not provided
            ``self.odesys._jac_eigenvals_svd`` is used.

        Raises
        ------
        NotImplementedError
            When no ``eigenvals_cb`` is given and ``odesys.band`` is not None.

        """
        if eigenvals_cb is None:
            if self.odesys.band is not None:
                raise NotImplementedError
            eigenvals_cb = self.odesys._jac_eigenvals_svd

        if xyp is None:
            x, y, intern_p = self._internals()
        else:
            # Result has no pre_process of its own; delegate to the system.
            # NOTE(review): assumes odesys.pre_process(x, y, p) returns a
            # length 3 tuple (x, y, p) -- confirm against the ODE system.
            x, y, intern_p = self.odesys.pre_process(*xyp)

        magnitudes = np.abs([eigenvals_cb(xval, yvals, intern_p)
                             for xval, yvals in zip(x, y)])
        return magnitudes.max(axis=-1) / magnitudes.min(axis=-1)

    def _plot(self, cb, x=None, y=None, legend=None, **kwargs):
        """ Calls the plotting callback ``cb`` with defaults filled in.

        Parameters
        ----------
        cb : callback
            Called as ``cb(x, y, legend=legend, **kwargs)``.
        x : array, optional
            Defaults to ``self.xout``.
        y : array, optional
            Defaults to ``self.yout``.
        legend : bool, optional
            When ``None`` it is set to ``True`` if any names or latex names
            are available.
        \\*\\*kwargs:
            ``names`` given without ``indices`` are translated to indices
            (when the system has names), ``latex_names`` default to those
            of the system provided at least one of them is not ``None``.

        Returns
        -------
        Whatever ``cb`` returns.

        """
        if x is None:
            x = self.xout
        if y is None:
            y = self.yout

        if 'names' in kwargs:
            if 'indices' not in kwargs and getattr(self.odesys, 'names', None):
                # Select by name: convert to indices and pass the full list
                # of names of the system to the callback.
                kwargs['indices'] = [self.odesys.names.index(n) for n in kwargs['names']]
                kwargs['names'] = self.odesys.names
        else:
            kwargs['names'] = getattr(self.odesys, 'names', ())

        if 'latex_names' not in kwargs:
            _latex_names = getattr(self.odesys, 'latex_names', None)
            if _latex_names and any(ln is not None for ln in _latex_names):
                kwargs['latex_names'] = _latex_names
        if legend is None:
            if kwargs.get('latex_names') or kwargs['names']:
                legend = True
        return cb(x, y, legend=legend, **kwargs)

    def plot(self, info_vlines_kw=None, between=None, deriv=False, title_info=0, **kwargs):
        """ Plots the integrated dependent variables from last integration.

        Parameters
        ----------
        info_vlines_kw : dict
            Keyword arguments passed to :func:`.plotting.info_vlines`,
            an empty dict will be used if `True`. Need to pass `ax` when given.
        indices : iterable of int
        between : length 2 tuple
            Passed on to :meth:`between`, cannot be combined with ``x``/``y``.
        deriv : bool
            Plot derivatives (internal variables).
        title_info : int
            When non-zero a title with the number of steps (if available),
            function and jacobian evaluations and the success status is set.
            Values larger than 1 also add the CPU time (if available) and
            use a medium font size.
        names : iterable of str
        \\*\\*kwargs:
            See :func:`pyodesys.plotting.plot_result`

        Returns
        -------
        The return value of the plotting callback (used as axes object).

        Raises
        ------
        ValueError
            When ``between`` is combined with ``x``/``y`` or ``deriv`` with ``y``.

        """
        if between is not None:
            if 'x' in kwargs or 'y' in kwargs:
                raise ValueError("x/y & between given.")
            kwargs['x'], kwargs['y'] = self.between(*between)
        if info_vlines_kw is not None:
            if info_vlines_kw is True:
                info_vlines_kw = {}
            info_vlines(kwargs['ax'], self.xout, self.info, **info_vlines_kw)
            # Extra pass with wide, semi-transparent white lines, drawn
            # before the actual curves are plotted below.
            self._plot(plot_result, plot_kwargs_cb=lambda *args, **kwargs:
                       dict(c='w', ls='-', linewidth=7, alpha=.4), **kwargs)
        if deriv:
            if 'y' in kwargs:
                raise ValueError("Cannot give both deriv=True and y.")
            kwargs['y'] = self.odesys.f_cb(*self._internals())
        ax = self._plot(plot_result, **kwargs)
        if title_info:
            title = self.odesys.description or ''
            parts = []
            if self.info.get('n_steps', -1) >= 0:
                parts.append('%d steps' % self.info['n_steps'])
            parts.append('%d fev' % self.info['nfev'])
            parts.append('%d jev' % self.info['njev'])
            if title_info > 1 and self.info.get('time_cpu', -1) >= 0:
                parts.append('%.2g s CPU' % self.info['time_cpu'])
            title += ', '.join(parts)
            title += ', success' if self.info['success'] else ', failed'
            ax.set_title(title, {'fontsize': 'medium'} if title_info > 1 else {})
        return ax

    def plot_phase_plane(self, indices=None, **kwargs):
        """ Plots a phase portrait from last integration.

        Parameters
        ----------
        indices : iterable of int
        names : iterable of str
        \\*\\*kwargs:
            See :func:`pyodesys.plotting.plot_phase_plane`

        """
        return self._plot(plot_phase_plane, indices=indices, **kwargs)

    def calc_invariant_violations(self, xyp=None):
        """ Returns the invariants minus their values at the first point.

        Parameters
        ----------
        xyp : length 3 tuple, optional
            Arguments for the callback from ``odesys.get_invariants_callback()``,
            defaults to the internal x, y & params of this result.

        """
        invar = self.odesys.get_invariants_callback()
        val = invar(*(xyp or self._internals()))
        return val - val[0, :]

    def plot_invariant_violations(self, **kwargs):
        """ Plots the absolute invariant violations against the internal xout.

        Parameters
        ----------
        \\*\\*kwargs:
            Passed on to the plotting routine, ``latex_names`` defaults to
            ``odesys.all_invariant_names()``.

        """
        viol = self.calc_invariant_violations()
        abs_viol = np.abs(viol)
        invar_names = self.odesys.all_invariant_names()
        return self._plot(plot_result, x=self._internal('xout'), y=abs_viol, names=invar_names,
                          latex_names=kwargs.pop('latex_names', invar_names), indices=None, **kwargs)

    def extend_by_integration(self, xend, params=None, odesys=None, autonomous=None, npoints=1, **kwargs):
        """ Continues the integration up to ``xend`` and appends the outcome (in-place).

        Parameters
        ----------
        xend : float
            Final value of the independent variable.
        params : array_like, optional
            Parameters for the continued integration, ``self.params`` is
            used when ``None``.
        odesys : object, optional
            System to integrate, defaults to ``self.odesys``.
        autonomous : bool, optional
            When true the integration is performed from zero to
            ``xend - xout[-1]`` and the resulting x values are shifted by
            ``xout[-1]``. Defaults to ``odesys.autonomous_interface``.
        npoints : int
            The integration is performed over ``npoints + 1`` linearly
            spaced points (the first one being the current end point).
        \\*\\*kwargs:
            Passed on to ``odesys.integrate``.

        Returns
        -------
        self

        Notes
        -----
        ``xout``, ``yout`` and ``info`` of this instance are replaced. Lists
        in ``info`` (entries ending in ``_indices`` and collected strings)
        are rebuilt rather than modified in-place, so that shallow copies of
        the old ``info`` are left untouched. Entries starting with
        ``internal`` are only kept (and extended) when the same system is used.

        """
        odesys = odesys or self.odesys
        if autonomous is None:
            autonomous = odesys.autonomous_interface
        # Explicit None test: "params or self.params" fails for arrays with
        # more than one element (ambiguous truth value).
        if params is None:
            params = self.params
        same_odesys = odesys is self.odesys
        backend = self.odesys.numpy
        x0 = self.xout[-1]
        nx0 = self.xout.size
        # 0*x0 (rather than 0) so that the type of x0 is carried along.
        x_offset = x0 if autonomous else 0*x0
        if autonomous:
            xspan = backend.linspace((xend - x0)*0, (xend - x0), npoints+1)
        else:
            xspan = backend.linspace(x0, xend, npoints+1)
        res = odesys.integrate(xspan, self.yout[..., -1, :], params, **kwargs)
        # First point of the new result coincides with our last point: skip it.
        self.xout = backend.concatenate((self.xout, res.xout[1:] + x_offset))
        self.yout = backend.concatenate((self.yout, res.yout[..., 1:, :]))
        new_info = {k: v for k, v in self.info.items()
                    if same_odesys or not k.startswith('internal')}
        for k, v in res.info.items():
            if k.startswith('internal'):
                if same_odesys:
                    new_info[k] = backend.concatenate((new_info[k], v))
            elif k == 'success':
                new_info[k] = new_info[k] and v
            elif k.endswith('_xvals'):
                if len(v) > 0:
                    new_info[k] = backend.concatenate((new_info[k], v + x_offset))
            elif k.endswith('_indices'):
                # Indices of the new result refer to its own xout, whose
                # first point was dropped above.
                new_info[k] = list(new_info[k]) + [itm + nx0 - 1 for itm in v]
            elif isinstance(v, str):
                previous = new_info[k]
                collected = [previous] if isinstance(previous, str) else list(previous)
                new_info[k] = collected + [v]
            else:
                new_info[k] += v
        self.info = new_info
        return self
314