1# -*- coding: utf-8 -*- 2 3from __future__ import (absolute_import, division, print_function) 4 5from math import log 6 7import numpy as np 8 9 10def _set_scale(cb, argstr): 11 if argstr.count(';') == 0: 12 cb(argstr) 13 else: 14 arg, kw = argstr.split(';') 15 cb(arg, **eval('dict(%s)' % kw)) 16 17 18def _latex_from_dimensionality(dim): 19 # see https://github.com/python-quantities/python-quantities/issues/148 20 from quantities.markup import format_units_latex 21 return format_units_latex(dim, mult=r'\\cdot') 22 23 24def plot_result(x, y, indices=None, plot_kwargs_cb=None, ax=None, 25 ls=('-', '--', ':', '-.'), 26 c=('tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 27 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan', 'black'), 28 m=('o', 'v', '8', 's', 'p', 'x', '+', 'd', 's'), 29 m_lim=-1, lines=None, interpolate=None, interp_from_deriv=None, 30 names=None, latex_names=None, xlabel=None, ylabel=None, 31 xscale=None, yscale=None, legend=False, yerr=None, labels=None, tex_lbl_fmt='$%s$', 32 fig_kw=None, xlim=None, ylim=None): 33 """ 34 Plot the depepndent variables vs. the independent variable 35 36 Parameters 37 ---------- 38 x : array_like 39 Values of the independent variable. 40 y : array_like 41 Values of the independent variable. This must hold 42 ``y.shape[0] == len(x)``, plot_results will draw 43 ``y.shape[1]`` lines. If ``interpolate != None`` 44 y is expected two be three dimensional, otherwise two dimensional. 45 indices : iterable of integers 46 What indices to plot (default: None => all). 47 plot : callback (default: None) 48 If None, use ``matplotlib.pyplot.plot``. 49 plot_kwargs_cb : callback(int) -> dict 50 Keyword arguments for plot for each index (0:len(y)-1). 51 ax : Axes 52 ls : iterable 53 Linestyles to cycle through (only used if plot and plot_kwargs_cb 54 are both None). 55 c : iterable 56 Colors to cycle through (only used if plot and plot_kwargs_cb 57 are both None). 58 m : iterable 59 Markers to cycle through (only used if plot and plot_kwargs_cb 60 are both None and m_lim > 0). 61 m_lim : int (default: -1) 62 Upper limit (exclusive, number of points) for using markers instead of 63 lines. 64 lines : None 65 default: draw between markers unless we are interpolating as well. 66 interpolate : bool or int (default: None) 67 Density-multiplier for grid of independent variable when interpolating 68 if True => 20. negative integer signifies log-spaced grid. 69 interp_from_deriv : callback (default: None) 70 When ``None``: ``scipy.interpolate.BPoly.from_derivatives`` 71 names : iterable of str 72 latex_names : iterable of str 73 labels : iterable of str 74 If ``None``, use ``latex_names`` or ``names`` (in that order). 75 76 """ 77 import matplotlib.pyplot as plt 78 79 if ax is None: 80 _fig, ax = plt.subplots(1, 1, **(fig_kw or {})) 81 if plot_kwargs_cb is None: 82 def plot_kwargs_cb(idx, lines=False, markers=False, labels=None): 83 84 kw = {'c': c[idx % len(c)]} 85 86 if lines: 87 kw['ls'] = ls[idx % len(ls)] 88 if isinstance(lines, float): 89 kw['alpha'] = lines 90 else: 91 kw['ls'] = 'None' 92 93 if markers: 94 kw['marker'] = m[idx % len(m)] 95 96 if labels: 97 kw['label'] = labels[idx] 98 return kw 99 else: 100 plot_kwargs_cb = plot_kwargs_cb or (lambda idx: {}) 101 102 if interpolate is None: 103 interpolate = y.ndim == 3 and y.shape[1] > 1 104 105 if interpolate and y.ndim == 3: 106 _y = y[:, 0, :] 107 else: 108 _y = y 109 110 if indices is None: 111 indices = range(_y.shape[-1]) # e.g. PartiallySolvedSys 112 if lines is None: 113 lines = interpolate in (None, False) 114 markers = len(x) < m_lim 115 116 if yerr is not None: 117 for idx in indices: 118 clr = plot_kwargs_cb(idx)['c'] 119 ax.fill_between(x, _y[:, idx] - yerr[:, idx], _y[:, idx] + yerr[:, idx], facecolor=clr, alpha=.3) 120 121 if isinstance(yscale, str) and 'linthreshy' in yscale: 122 arg, kw = yscale.split(';') 123 thresh = eval('dict(%s)' % kw)['linthreshy'] 124 ax.axhline(thresh, linewidth=.5, linestyle='--', color='k', alpha=.5) 125 ax.axhline(-thresh, linewidth=.5, linestyle='--', color='k', alpha=.5) 126 127 if labels is None: 128 labels = names if latex_names is None else [tex_lbl_fmt % ln.strip('$') for ln in latex_names] 129 130 for idx in indices: 131 ax.plot(x, _y[:, idx], **plot_kwargs_cb( 132 idx, lines=lines, labels=labels)) 133 if markers: 134 ax.plot(x, _y[:, idx], **plot_kwargs_cb( 135 idx, lines=False, markers=markers, labels=labels)) 136 137 if xlabel is None: 138 try: 139 ax.set_xlabel(_latex_from_dimensionality(x.dimensionality)) 140 except AttributeError: 141 pass 142 else: 143 ax.set_xlabel(xlabel) 144 145 if ylabel is None: 146 try: 147 ax.set_ylabel(_latex_from_dimensionality(_y.dimensionality)) 148 except AttributeError: 149 pass 150 else: 151 ax.set_ylabel(ylabel) 152 153 if interpolate: 154 if interpolate is True: 155 interpolate = 20 156 157 if isinstance(interpolate, int): 158 if interpolate > 0: 159 x_plot = np.concatenate( 160 [np.linspace(a, b, interpolate) 161 for a, b in zip(x[:-1], x[1:])]) 162 elif interpolate < 0: 163 x_plot = np.concatenate([ 164 np.logspace(np.log10(a), np.log10(b), 165 -interpolate) for a, b 166 in zip(x[:-1], x[1:])]) 167 else: 168 x_plot = interpolate 169 170 if interp_from_deriv is None: 171 import scipy.interpolate 172 interp_from_deriv = scipy.interpolate.BPoly.from_derivatives 173 174 y2 = np.empty((x_plot.size, _y.shape[-1])) 175 for idx in range(_y.shape[-1]): 176 interp_cb = interp_from_deriv(x, y[..., idx]) 177 y2[:, idx] = interp_cb(x_plot) 178 179 for idx in indices: 180 ax.plot(x_plot, y2[:, idx], **plot_kwargs_cb( 181 idx, lines=True, markers=False)) 182 return x_plot, y2 183 184 if xscale is not None: 185 _set_scale(ax.set_xscale, xscale) 186 if yscale is not None: 187 _set_scale(ax.set_yscale, yscale) 188 189 if legend is True: 190 ax.legend() 191 elif legend in (None, False): 192 pass 193 else: 194 ax.legend(**legend) 195 196 if xlim: 197 ax.set_xlim(xlim) 198 if ylim: 199 ax.set_ylim(ylim) 200 return ax 201 202 203def plot_phase_plane(x, y, indices=None, plot=None, names=None, ax=None, **kwargs): 204 """ Plot the phase portrait of two dependent variables 205 206 Parameters 207 ---------- 208 x: array_like 209 Values of the independent variable. 210 y: array_like 211 Values of the dependent variables. 212 indices: pair of integers (default: None) 213 What dependent variable to plot for (None => (0, 1)). 214 plot: callable (default: None) 215 Uses ``matplotlib.pyplot.plot`` if ``None`` 216 names: iterable of strings 217 Labels for x and y axis. 218 \\*\\*kwargs: 219 Keyword arguemtns passed to ``plot()``. 220 221 """ 222 if indices is None: 223 indices = (0, 1) 224 if len(indices) != 2: 225 raise ValueError('Only two phase variables supported at the moment') 226 227 if ax is None: 228 import matplotlib.pyplot as plt 229 ax = plt.subplot(1, 1, 1) 230 231 if names is not None: 232 ax.set_xlabel(names[indices[0]]) 233 ax.set_ylabel(names[indices[1]]) 234 235 ax.plot(y[:, indices[0]], y[:, indices[1]], **kwargs) 236 237 238def right_hand_ylabels(ax, labels): 239 ax2 = ax.twinx() 240 ylim = ax.get_ylim() 241 yspan = ylim[1]-ylim[0] 242 ax2.set_ylim(ylim) 243 yticks = [ylim[0] + (idx + 0.5)*yspan/len(labels) for idx in range(len(labels))] 244 ax2.tick_params(length=0) 245 ax2.set_yticks(yticks) 246 ax2.set_yticklabels(labels) 247 248 249def info_vlines(ax, xout, info, vline_colors=('maroon', 'purple'), 250 vline_keys=('steps', 'rhs_xvals', 'jac_xvals'), 251 post_proc=None, alpha=None, fpes=None, every=None): 252 """ Plot vertical lines in the background 253 254 Parameters 255 ---------- 256 ax : axes 257 xout : array_like 258 info : dict 259 vline_colors : iterable of str 260 vline_keys : iterable of str 261 Choose from ``'steps', 'rhs_xvals', 'jac_xvals', 262 'fe_underflow', 'fe_overflow', 'fe_invalid', 'fe_divbyzero'``. 263 vline_post_proc : callable 264 alpha : float 265 266 """ 267 268 nvk = len(vline_keys) 269 for idx, key in enumerate(vline_keys): 270 if key == 'steps': 271 vlines = xout 272 elif key.startswith('fe_'): 273 if fpes is None: 274 raise ValueError("Need fpes when vline_keys contain fe_*") 275 vlines = xout[info['fpes'] & fpes[key.upper()] > 0] 276 else: 277 vlines = info[key] if post_proc is None else post_proc(info[key]) 278 279 if alpha is None: 280 alpha = 0.01 + 1/log(len(vlines)+3) 281 282 if every is None: 283 ln_np1 = log(len(vlines)+1) 284 every = min(round((ln_np1 - 4)/log(2)), 1) 285 286 ax.vlines(vlines[::every], idx/nvk + 0.002, (idx+1)/nvk - 0.002, 287 colors=vline_colors[idx % len(vline_colors)], 288 alpha=alpha, transform=ax.get_xaxis_transform()) 289 right_hand_ylabels(ax, [k[3] if k.startswith('fe_') else k[0] for k in vline_keys]) 290