1# -*- coding: utf-8 -*-
2
3from __future__ import (absolute_import, division, print_function)
4
5from math import log
6
7import numpy as np
8
9
def _set_scale(cb, argstr):
    """Invoke a scale setter (e.g. ``ax.set_yscale``) from a compact string.

    ``argstr`` is either a bare scale name such as ``'log'``, or a scale name
    and keyword arguments separated by a single ``';'``, such as
    ``'symlog;linthreshy=1e-3'``. The keyword part is evaluated as the body
    of a ``dict(...)`` call, so it must only ever come from trusted input.
    More than one ``';'`` raises ``ValueError`` (tuple unpacking fails).
    """
    if ';' not in argstr:
        cb(argstr)
        return
    scale_name, kw_src = argstr.split(';')
    scale_kwargs = eval('dict(%s)' % kw_src)
    cb(scale_name, **scale_kwargs)
16
17
def _latex_from_dimensionality(dim):
    """Return a LaTeX string for the units described by ``dim``.

    ``dim`` is handed to ``quantities.markup.format_units_latex`` (typically
    the ``.dimensionality`` of a ``quantities`` array). ``quantities`` is
    imported lazily, so it is only required when this helper is used.
    """
    # see https://github.com/python-quantities/python-quantities/issues/148
    from quantities.markup import format_units_latex as fmt_latex
    # NOTE(review): the doubled backslash looks deliberate (quantities
    # presumably post-processes ``mult``); confirm before "simplifying" it.
    cdot = r'\\cdot'
    return fmt_latex(dim, mult=cdot)
22
23
def _interpolation_grid(x, interpolate):
    """Build the independent-variable grid used for the interpolated curves.

    Parameters
    ----------
    x : array_like
        Values of the independent variable.
    interpolate : int or array_like
        A positive integer ``n`` gives ``n`` linearly spaced points in every
        interval ``[x[i], x[i+1]]``; a negative integer ``-n`` gives ``n``
        log-spaced points per interval (``x`` must then be positive).
        Anything else is used as an explicit grid.
    """
    if isinstance(interpolate, (int, np.integer)):
        intervals = list(zip(x[:-1], x[1:]))
        if interpolate > 0:
            return np.concatenate([np.linspace(a, b, interpolate)
                                   for a, b in intervals])
        if interpolate < 0:
            return np.concatenate([np.logspace(np.log10(a), np.log10(b), -interpolate)
                                   for a, b in intervals])
        raise ValueError("interpolate must be a non-zero integer or an explicit grid")
    return np.asanyarray(interpolate)


def plot_result(x, y, indices=None, plot_kwargs_cb=None, ax=None,
                ls=('-', '--', ':', '-.'),
                c=('tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                   'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan', 'black'),
                m=('o', 'v', '8', 's', 'p', 'x', '+', 'd', 's'),
                m_lim=-1, lines=None, interpolate=None, interp_from_deriv=None,
                names=None, latex_names=None, xlabel=None, ylabel=None,
                xscale=None, yscale=None, legend=False, yerr=None, labels=None, tex_lbl_fmt='$%s$',
                fig_kw=None, xlim=None, ylim=None):
    """
    Plot the dependent variables vs. the independent variable

    Parameters
    ----------
    x : array_like
        Values of the independent variable.
    y : array_like
        Values of the dependent variables. This must hold
        ``y.shape[0] == len(x)``; plot_result draws ``y.shape[-1]`` lines.
        When interpolating, ``y`` is expected to be three dimensional:
        ``y[:, 0, :]`` holds the values and ``y[..., i]`` is passed on to
        ``interp_from_deriv`` (with the default, ``BPoly.from_derivatives``,
        the entries along axis 1 are successive derivatives). Otherwise
        ``y`` is two dimensional.
    indices : iterable of integers
        What indices to plot (default: None => all).
    plot_kwargs_cb : callback (default: None)
        Returns a dict of keyword arguments for ``ax.plot`` for a given index.
        It is called as ``plot_kwargs_cb(idx)`` (its ``'c'`` entry colours the
        ``yerr`` bands) and as ``plot_kwargs_cb(idx, lines=..., markers=...,
        labels=...)``, so it must accept those keyword arguments.
        If None, a default cycling through ``ls``, ``c`` and ``m`` is used.
    ax : Axes (default: None)
        Axes to draw on. If None, a new figure is created (see ``fig_kw``).
    ls : iterable
        Linestyles to cycle through (only used if plot_kwargs_cb is None).
    c : iterable
        Colors to cycle through (only used if plot_kwargs_cb is None).
    m : iterable
        Markers to cycle through (only used if plot_kwargs_cb is None and
        markers are drawn, see ``m_lim``).
    m_lim : int (default: -1)
        Markers are drawn (in addition to any lines) when ``len(x) < m_lim``.
    lines : bool or float (default: None)
        Whether to draw lines between the data points; a float also sets the
        line alpha. None => draw lines unless we are interpolating.
    interpolate : bool, int or array_like (default: None)
        Density-multiplier for the grid of the independent variable when
        interpolating; True => 20, a negative integer gives a log-spaced grid.
        An array is used as an explicit grid. None => interpolate when ``y``
        is three dimensional with ``y.shape[1] > 1``.
    interp_from_deriv : callback (default: None)
        ``interp_from_deriv(x, y[..., i])`` must return a callable that
        evaluates the interpolant on the grid.
        When ``None``: ``scipy.interpolate.BPoly.from_derivatives``
    names : iterable of str
    latex_names : iterable of str
    xlabel, ylabel : str
        Axis labels. None => derived from the ``.dimensionality`` attribute
        of ``x`` / ``y`` when present (e.g. ``quantities`` arrays).
    xscale, yscale : str
        Scale name, optionally followed by ``';'`` and keyword arguments,
        e.g. ``'symlog;linthreshy=1e-3'``. The keyword part is ``eval``-ed,
        so only pass trusted strings.
    legend : bool or dict
        True => ``ax.legend()``; a dict is passed as keyword arguments.
    yerr : array_like
        Draw shaded bands ``y +/- yerr`` (indexed like the plotted ``y``).
    labels : iterable of str
        If ``None``, use ``latex_names`` or ``names`` (in that order).
    tex_lbl_fmt : str
        Format string applied to each ``latex_names`` entry (``$`` stripped).
    fig_kw : dict
        Keyword arguments for ``plt.subplots`` when ``ax`` is None.
    xlim, ylim : pair of floats
        Axis limits.

    Returns
    -------
    ``(x_plot, y_interp)`` when interpolating, otherwise ``ax``.

    """
    if ax is None:
        # Imported lazily: callers passing their own Axes need no pyplot.
        import matplotlib.pyplot as plt
        _fig, ax = plt.subplots(1, 1, **(fig_kw or {}))
    if plot_kwargs_cb is None:
        def plot_kwargs_cb(idx, lines=False, markers=False, labels=None):

            kw = {'c': c[idx % len(c)]}

            if lines:
                kw['ls'] = ls[idx % len(ls)]
                if isinstance(lines, float):
                    kw['alpha'] = lines
            else:
                kw['ls'] = 'None'

            if markers:
                kw['marker'] = m[idx % len(m)]

            if labels:
                kw['label'] = labels[idx]
            return kw
    else:
        # The callback is invoked with keyword arguments below, so the
        # fallback must accept them too.
        plot_kwargs_cb = plot_kwargs_cb or (lambda idx, **kwargs: {})

    if interpolate is None:
        interpolate = y.ndim == 3 and y.shape[1] > 1
    # Scalars (bool/int) switch interpolation on/off and set its density;
    # anything else is an explicit grid. Truth-testing an array grid would
    # raise, so the flag is computed once here and used everywhere below.
    explicit_grid = not isinstance(interpolate, (bool, int, np.bool_, np.integer))
    interpolating = explicit_grid or bool(interpolate)

    if interpolating and y.ndim == 3:
        _y = y[:, 0, :]
    else:
        _y = y

    if indices is None:
        indices = range(_y.shape[-1])  # e.g. PartiallySolvedSys
    if lines is None:
        lines = not interpolating
    markers = len(x) < m_lim

    if yerr is not None:
        for idx in indices:
            clr = plot_kwargs_cb(idx)['c']
            ax.fill_between(x, _y[:, idx] - yerr[:, idx], _y[:, idx] + yerr[:, idx], facecolor=clr, alpha=.3)

    if isinstance(yscale, str) and 'linthreshy' in yscale:
        # Mark the linear region of a symlog axis (trusted input, see docstring).
        arg, kw = yscale.split(';')
        thresh = eval('dict(%s)' % kw)['linthreshy']
        ax.axhline(thresh, linewidth=.5, linestyle='--', color='k', alpha=.5)
        ax.axhline(-thresh, linewidth=.5, linestyle='--', color='k', alpha=.5)

    if labels is None:
        labels = names if latex_names is None else [tex_lbl_fmt % ln.strip('$') for ln in latex_names]

    for idx in indices:
        ax.plot(x, _y[:, idx], **plot_kwargs_cb(
            idx, lines=lines, labels=labels))
        if markers:
            ax.plot(x, _y[:, idx], **plot_kwargs_cb(
                idx, lines=False, markers=markers, labels=labels))

    if xlabel is None:
        try:
            ax.set_xlabel(_latex_from_dimensionality(x.dimensionality))
        except AttributeError:
            pass
    else:
        ax.set_xlabel(xlabel)

    if ylabel is None:
        try:
            ax.set_ylabel(_latex_from_dimensionality(_y.dimensionality))
        except AttributeError:
            pass
    else:
        ax.set_ylabel(ylabel)

    result = ax
    if interpolating:
        if isinstance(interpolate, (bool, np.bool_)):
            interpolate = 20  # only True reaches here
        x_plot = _interpolation_grid(x, interpolate)

        if interp_from_deriv is None:
            import scipy.interpolate
            interp_from_deriv = scipy.interpolate.BPoly.from_derivatives

        y2 = np.empty((x_plot.size, _y.shape[-1]))
        for idx in range(_y.shape[-1]):
            interp_cb = interp_from_deriv(x, y[..., idx])
            y2[:, idx] = interp_cb(x_plot)

        for idx in indices:
            ax.plot(x_plot, y2[:, idx], **plot_kwargs_cb(
                idx, lines=True, markers=False))
        result = x_plot, y2

    # Axis decorations apply whether or not we interpolated (previously an
    # early return skipped them when interpolating).
    if xscale is not None:
        _set_scale(ax.set_xscale, xscale)
    if yscale is not None:
        _set_scale(ax.set_yscale, yscale)

    if legend is True:
        ax.legend()
    elif legend in (None, False):
        pass
    else:
        ax.legend(**legend)

    if xlim:
        ax.set_xlim(xlim)
    if ylim:
        ax.set_ylim(ylim)
    return result
201
202
def plot_phase_plane(x, y, indices=None, plot=None, names=None, ax=None, **kwargs):
    """ Plot the phase portrait of two dependent variables

    Parameters
    ----------
    x: array_like
        Values of the independent variable (currently unused).
    y: array_like
        Values of the dependent variables, indexed as ``y[:, i]``.
    indices: pair of integers (default: None)
        What dependent variables to plot (None => (0, 1)); the first goes on
        the x-axis, the second on the y-axis.
    plot: callable (default: None)
        Called as ``plot(y[:, i0], y[:, i1], **kwargs)``. Uses ``ax.plot``
        if ``None``.
    names: indexable of strings
        Names of the dependent variables; the entries at ``indices`` become
        the x and y axis labels.
    ax: Axes (default: None)
        Axes used for the labels (and for drawing when ``plot`` is None).
        Created with ``matplotlib.pyplot.subplot(1, 1, 1)`` if ``None``.
    \\*\\*kwargs:
        Keyword arguments passed to ``plot()``.

    Returns
    -------
    The Axes used for the axis labels.

    Raises
    ------
    ValueError
        If ``indices`` does not hold exactly two entries.

    """
    if indices is None:
        indices = (0, 1)
    if len(indices) != 2:
        raise ValueError('Only two phase variables supported at the moment')
    i0, i1 = indices[0], indices[1]

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.subplot(1, 1, 1)

    if names is not None:
        ax.set_xlabel(names[i0])
        ax.set_ylabel(names[i1])

    # Honour the documented ``plot`` argument (it used to be ignored).
    draw = ax.plot if plot is None else plot
    draw(y[:, i0], y[:, i1], **kwargs)
    return ax
236
237
def right_hand_ylabels(ax, labels):
    """Show ``labels`` as evenly spaced tick labels on a twin right-hand y-axis.

    The twin axis (from ``ax.twinx()``) copies the current y-limits of ``ax``.
    That span is split into ``len(labels)`` equal bands, and each label sits
    at the centre of its band. Tick marks are hidden (zero length).
    """
    twin = ax.twinx()
    limits = ax.get_ylim()
    bottom, top = limits[0], limits[1]
    span = top - bottom
    twin.set_ylim(limits)
    count = len(labels)
    centres = []
    for band in range(count):
        centres.append(bottom + (band + 0.5)*span/count)
    twin.tick_params(length=0)
    twin.set_yticks(centres)
    twin.set_yticklabels(labels)
247
248
def info_vlines(ax, xout, info, vline_colors=('maroon', 'purple'),
                vline_keys=('steps', 'rhs_xvals', 'jac_xvals'),
                post_proc=None, alpha=None, fpes=None, every=None):
    """ Plot vertical lines in the background

    Each key in ``vline_keys`` gets its own horizontal band (in axes
    coordinates along y) in which vertical lines are drawn at the
    corresponding x-values. A one-letter label per band is put on a
    right-hand y-axis.

    Parameters
    ----------
    ax : axes
    xout : array_like
        Values of the independent variable at each step (used for
        ``'steps'`` and, masked by ``info['fpes']``, for ``'fe_*'`` keys).
    info : dict
        Must contain the entries named in ``vline_keys`` (other than
        ``'steps'`` and ``'fe_*'``), and ``'fpes'`` for ``'fe_*'`` keys.
    vline_colors : iterable of str
        Colors cycled through per key.
    vline_keys : iterable of str
        Choose from ``'steps', 'rhs_xvals', 'jac_xvals',
        'fe_underflow', 'fe_overflow', 'fe_invalid', 'fe_divbyzero'``.
    post_proc : callable
        Applied to ``info[key]`` for keys other than ``'steps'``/``'fe_*'``.
    alpha : float
        Opacity of the lines. None => chosen per key from its number of lines.
    fpes : dict
        Maps upper-cased key names (e.g. ``'FE_INVALID'``) to bit flags that
        are tested against ``info['fpes']``. Required for ``'fe_*'`` keys.
    every : int
        Draw only every ``every``-th line. None => chosen per key (at least 1)
        so that dense sets of lines are thinned out.

    Raises
    ------
    ValueError
        If an ``'fe_*'`` key is requested without ``fpes``.

    """

    nvk = len(vline_keys)
    for idx, key in enumerate(vline_keys):
        if key == 'steps':
            vlines = xout
        elif key.startswith('fe_'):
            if fpes is None:
                raise ValueError("Need fpes when vline_keys contain fe_*")
            vlines = xout[info['fpes'] & fpes[key.upper()] > 0]
        else:
            vlines = info[key] if post_proc is None else post_proc(info[key])

        # Defaults depend on this key's number of lines, so they are derived
        # per key instead of being frozen at the first key's values.
        key_alpha = 0.01 + 1/log(len(vlines) + 3) if alpha is None else alpha
        if every is None:
            # Step ~ log2((n+1)/e**4), clamped to >= 1: the old ``min`` gave
            # 0 ("slice step cannot be zero") or negative (reversed) steps.
            key_every = max(int(round((log(len(vlines) + 1) - 4)/log(2))), 1)
        else:
            key_every = every

        ax.vlines(vlines[::key_every], idx/nvk + 0.002, (idx+1)/nvk - 0.002,
                  colors=vline_colors[idx % len(vline_colors)],
                  alpha=key_alpha, transform=ax.get_xaxis_transform())
    right_hand_ylabels(ax, [k[3] if k.startswith('fe_') else k[0] for k in vline_keys])
290