1# coding: utf-8
2"""
3Utilities for generating matplotlib plots.
4
5.. note::
6
7    Avoid importing matplotlib in the module namespace otherwise startup is very slow.
8"""
9import os
10import time
11import itertools
12import numpy as np
13
14from collections import OrderedDict, namedtuple
15from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt, get_axarray_fig_plt
16from .numtools import data_from_cplx_mode
17
18
# Public names exported by ``from <this module> import *``.
# ``get_ax_fig_plt`` and ``get_ax3d_fig_plt`` are re-exported from pymatgen.util.plotting
# and ``data_from_cplx_mode`` from .numtools (see the imports at the top of the file).
__all__ = [
    "set_axlims",
    "get_ax_fig_plt",
    "get_ax3d_fig_plt",
    "plot_array",
    "ArrayPlotter",
    "data_from_cplx_mode",
    "Marker",
    "plot_unit_cell",
    "GenericDataFilePlotter",
    "GenericDataFilesPlotter",
]
31
32
33# https://matplotlib.org/gallery/lines_bars_and_markers/linestyles.html
# Named dash patterns, keyed by a descriptive name. Each value is a
# ``(offset, dash_sequence)`` tuple; insertion order is preserved.
linestyles = OrderedDict((
    # Continuous line.
    ("solid", (0, ())),

    # Dotted family.
    ("loosely_dotted", (0, (1, 10))),
    ("dotted", (0, (1, 5))),
    ("densely_dotted", (0, (1, 1))),

    # Dashed family.
    ("loosely_dashed", (0, (5, 10))),
    ("dashed", (0, (5, 5))),
    ("densely_dashed", (0, (5, 1))),

    # Dash-dot family.
    ("loosely_dashdotted", (0, (3, 10, 1, 10))),
    ("dashdotted", (0, (3, 5, 1, 5))),
    ("densely_dashdotted", (0, (3, 1, 1, 1))),

    # Dash-dot-dot family.
    ("loosely_dashdotdotted", (0, (3, 10, 1, 10, 1, 10))),
    ("dashdotdotted", (0, (3, 5, 1, 5, 1, 5))),
    ("densely_dashdotdotted", (0, (3, 1, 1, 1, 1, 1))),
))
52
53
def ax_append_title(ax, title, loc="center", fontsize=None):
    """
    Append ``title`` to the title currently set on ``ax`` at position ``loc``.

    Args:
        ax: Axes-like object providing ``get_title`` and ``set_title``.
        title: String appended to the existing title.
        loc: Title location passed to ``get_title``/``set_title``.
        fontsize: Font size passed to ``set_title``.

    Return: the new (concatenated) title string.
    """
    combined = ax.get_title(loc=loc) + title
    ax.set_title(combined, loc=loc, fontsize=fontsize)
    return combined
60
61
def ax_share(xy_string, *ax_list):
    """
    Share x- and/or y-axis of two or more subplots after they are created.

    All the axes in ``ax_list`` are put in the same sharing group. Nothing is
    done if fewer than two axes are given.

    Args:
        xy_string: "x" to share the x-axis, "y" to share the y-axis, "xy" for both.
        ax_list: Axes to share.

    Example:

        ax_share("y", ax0, ax1)
        ax_share("xy", *(ax0, ax1, ax2))
    """
    if len(ax_list) < 2:
        return

    first = ax_list[0]
    for axis_name in "xy":
        if axis_name not in xy_string:
            continue

        grouper = getattr(first, "get_shared_%s_axes" % axis_name)()
        if hasattr(grouper, "join"):
            # Join *all* the axes in one call. Joining only the "other" axes
            # leaves a pair of axes unshared because a single-element join is a no-op.
            grouper.join(*ax_list)
        else:
            # Recent matplotlib versions return a read-only view without `join`:
            # use the public sharex/sharey API instead.
            # NOTE: sharex/sharey refuse axes that already share that axis.
            share_method = "share" + axis_name
            for other in ax_list[1:]:
                getattr(other, share_method)(first)
84
85
86#def set_grid(fig, boolean):
87#    if hasattr(fig, "axes"):
88#        for ax in fig.axes:
89#            if ax.grid: ax.grid.set_visible(boolean)
90#    else:
91#            if ax.grid: ax.grid.set_visible(boolean)
92
93
def set_axlims(ax, lims, axname):
    """
    Set the data limits of one axis of ``ax``.

    Args:
        ax: Axes-like object providing ``set_xlim`` and ``set_ylim``.
        lims: None (nothing is done), a sequence of two items ``(left, right)``,
            a sequence with one item (left only) or a scalar (left only).
            Sequences of any other length are ignored.
        axname: "x" for x-axis, "y" for y-axis.

    Return: (left, right)
    """
    if lims is None:
        return (None, None)

    left = right = None
    try:
        nitems = len(lims)
    except TypeError:
        # Objects without a length are treated as a scalar giving the left limit.
        left = float(lims)
    else:
        if nitems == 2:
            left, right = lims[0], lims[1]
        elif nitems == 1:
            left = lims[0]

    setter_name = {"x": "set_xlim", "y": "set_ylim"}[axname]
    setter = getattr(ax, setter_name)
    # Limits are applied only if the two values differ.
    if left != right:
        setter(left, right)

    return left, right
125
126
def set_ax_xylabels(ax, xlabel, ylabel, exchange_xy):
    """
    Label the x- and y-axis of ``ax``.

    If ``exchange_xy`` is true, ``xlabel`` goes to the y-axis and ``ylabel`` to the x-axis.
    """
    labels = (ylabel, xlabel) if exchange_xy else (xlabel, ylabel)
    ax.set_xlabel(labels[0])
    ax.set_ylabel(labels[1])
134
135
def set_visible(ax, boolean, *args):
    """
    Hide/Show the artists of axis ax listed in args.

    Args:
        ax: |matplotlib-Axes|
        boolean: True to show the artists, False to hide them.
        args: Names of the artists to act on. Recognized values:
            "legend", "title", "xlabel", "ylabel", "xticklabels", "yticklabels".
            Other values are ignored.
    """
    if "legend" in args:
        # Use get_legend: calling ax.legend() would build a brand new legend
        # (discarding the settings of the existing one) instead of returning it.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_visible(boolean)
    if "title" in args and ax.title:
        ax.title.set_visible(boolean)
    if "xlabel" in args and ax.xaxis.label:
        ax.xaxis.label.set_visible(boolean)
    if "ylabel" in args and ax.yaxis.label:
        ax.yaxis.label.set_visible(boolean)
    if "xticklabels" in args:
        for label in ax.get_xticklabels():
            label.set_visible(boolean)
    if "yticklabels" in args:
        for label in ax.get_yticklabels():
            label.set_visible(boolean)
154
155
def rotate_ticklabels(ax, rotation, axname="x"):
    """
    Rotate the tick labels of axis ``ax``.

    Args:
        ax: Axes-like object providing ``get_xticklabels``/``get_yticklabels``.
        rotation: Value passed to ``set_rotation`` of each tick label.
        axname: String containing "x" and/or "y" to select the axis (axes) to act on.
    """
    for name in ("x", "y"):
        if name not in axname:
            continue
        get_labels = getattr(ax, "get_%sticklabels" % name)
        for ticklabel in get_labels():
            ticklabel.set_rotation(rotation)
164
165
@add_fig_kwargs
def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
                     xlims=None, ylims=None, fontsize=12, **kwargs):
    """
    Plot the relation y = f(x), drawing one line for each distinct value of `hue`.
    Useful for convergence studies performed with respect to two parameters.

    Args:
        data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
        x: Name of the column used as x-value
        y: Name of the column used as y-value, or list/tuple of column names.
            In the latter case a grid of subplots is built (one per column)
            and the ``ax`` argument is not used.
        hue: Name of the column whose values define the subsets drawn on separate lines
        decimals: Number of decimal places used to round the `hue` column. Ignored if None
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
            or scalar e.g. `left`. If left (right) is None, default values are used
        fontsize: Legend fontsize.
        kwargs: Keyword arguments passed to ax.plot. If none is given, the 'o-' style is used.

    Returns: |matplotlib-Figure|
    """
    if isinstance(y, (list, tuple)):
        # Several y-columns: build a grid with two columns and fill it with recursive calls.
        num_plots = len(y)
        nrows = ncols = 1
        if num_plots > 1:
            ncols = 2
            nrows = sum(divmod(num_plots, ncols))

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Switch off the unused ax when the grid has one slot too many.
        if num_plots % ncols != 0:
            ax_list[-1].axis('off')

        for this_ax, yname in zip(ax_list, y):
            plot_xy_with_hue(data, x, str(yname), hue, decimals=decimals, ax=this_ax,
                             xlims=xlims, ylims=ylims, fontsize=fontsize, show=False, **kwargs)
        return fig

    # Explicit check on the column names because pandas error messages are a bit cryptic.
    missing = [name for name in (x, y, hue) if name not in data]
    if missing:
        raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(missing), str(data.keys())))

    # Round values in the hue column so that rows can be grouped.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, plt = get_ax_fig_plt(ax=ax)

    for hue_value, group in data.groupby(hue):
        # Order the points by increasing x (ties keep their original order).
        pairs = sorted(zip(group[x], group[y]), key=lambda t: t[0])
        xy = np.array(pairs)
        xvals = xy[:, 0]
        yvals = xy[:, 1]

        label = str(hue_value)
        if kwargs:
            ax.plot(xvals, yvals, label=label, **kwargs)
        else:
            ax.plot(xvals, yvals, 'o-', label=label)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig
235
236
@add_fig_kwargs
def plot_array(array, color_map=None, cplx_mode="abs", **kwargs):
    """
    Show a 1D or 2D array with pyplot ``imshow`` and add a colorbar.

    Example::

        plot_array(np.random.rand(10,10))

    See <http://stackoverflow.com/questions/7229971/2d-grid-data-visualization-in-python>

    Args:
        array: Array-like object (1D or 2D).
        color_map: color map. If None, a blue-black-red map with 256 levels is built.
        cplx_mode:
            Flag defining how to handle complex arrays. Possible values in ("re", "im", "abs", "angle")
            "re" for the real part, "im" for the imaginary part.
            "abs" means that the absolute value of the complex number is shown.
            "angle" will display the phase of the complex number in radians.

    Returns: |matplotlib-Figure|
    """
    # Vectors are promoted to 2D before selecting the part of the data to show.
    values = data_from_cplx_mode(cplx_mode, np.atleast_2d(array))

    import matplotlib as mpl
    from matplotlib import pyplot as plt

    cmap = color_map
    if cmap is None:
        # Default color map made of fixed colors.
        cmap = mpl.colors.LinearSegmentedColormap.from_list(
            'my_colormap', ['blue', 'black', 'red'], 256)

    image = plt.imshow(values, interpolation='nearest', cmap=cmap, origin='lower')
    plt.colorbar(image, cmap=cmap)
    plt.grid(True, color='white')

    # The image is drawn with the pyplot state machine: return the current figure.
    return plt.gcf()
280
281
class ArrayPlotter(object):
    """
    Ordered collection of labelled arrays that can be shown side by side,
    each one in its own subplot with its own colorbar.
    """

    def __init__(self, *labels_and_arrays):
        """
        Args:
            labels_and_arrays: List [("label1", arr1), ("label2", arr2)]
        """
        self._arr_dict = OrderedDict()
        for pair in labels_and_arrays:
            label, array = pair
            self.add_array(label, array)

    def __len__(self):
        return len(self._arr_dict)

    def __iter__(self):
        return iter(self._arr_dict)

    def keys(self):
        """Labels, in insertion order."""
        return self._arr_dict.keys()

    def items(self):
        """(label, array) pairs, in insertion order."""
        return self._arr_dict.items()

    def add_array(self, label, array):
        """Register ``array`` under ``label``. Raise ValueError if the label is already used."""
        if label in self._arr_dict:
            raise ValueError("%s is already in %s" % (label, list(self._arr_dict.keys())))

        self._arr_dict[label] = array

    def add_arrays(self, labels, arr_list):
        """
        Register several arrays at once.

        Args:
            labels: List of labels.
            arr_list: List of arrays (same length as ``labels``).
        """
        assert len(labels) == len(arr_list)
        for label, arr in zip(labels, arr_list):
            self.add_array(label, arr)

    @add_fig_kwargs
    def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs):
        """
        Plot all the arrays on a grid of subplots with two columns
        (a single ax if there is at most one array).

        Args:
            cplx_mode: "abs" for absolute value, "re", "im", "angle"
            colormap: matplotlib colormap.
            fontsize: fontsize used for the title of each subplot.

        Returns: |matplotlib-Figure|
        """
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MultipleLocator
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        # Shape of the grid.
        num_plots = len(self)
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + (num_plots % ncols)
        else:
            nrows = ncols = 1

        fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False)

        # With an odd number of arrays the last slot of the grid is empty: hide it.
        if num_plots % ncols != 0:
            ax_mat[-1, -1].axis("off")

        for ax, (label, arr) in zip(ax_mat.flat, self.items()):
            values = data_from_cplx_mode(cplx_mode, arr)
            # origin='lower' puts the [0, 0] element in the lower left corner of the ax.
            image = ax.matshow(values, interpolation='nearest', cmap=colormap,
                               origin='lower', aspect="auto")
            ax.set_title("(%s) %s" % (cplx_mode, label), fontsize=fontsize)

            # Colorbar in a new ax attached to the right of `ax` (10% of its width), see
            # http://stackoverflow.com/questions/18266642/multiple-imshow-subplots-each-with-colorbar
            divider = make_axes_locatable(ax)
            cbar_ax = divider.append_axes("right", size="10%", pad=0.05)
            # `ticks` fixes the tick locations, `format` the format of the tick labels.
            plt.colorbar(image, cax=cbar_ax, ticks=MultipleLocator(0.2), format="%.2f")

            # No x-axis decorations on the image.
            ax.xaxis.set_visible(False)

            ax.grid(True, color='white')

        fig.tight_layout()
        return fig
374
375
376#TODO use object and introduce c for color, client code should be able to customize it.
377# Rename it to ScatterData
class Marker(namedtuple("Marker", "x y s")):
    """
    Stores the position and the size of the marker.
    A marker is a list of tuple(x, y, s) where x, and y are the position
    in the graph and s is the size of the marker.
    Used for plotting purpose e.g. QP data, energy derivatives...

    A marker built from data stores x, y, s as numpy arrays and cannot be extended.
    A marker built without arguments stores three empty lists that can be filled with ``extend``.

    Example::

        x, y, s = [1, 2, 3], [4, 5, 6], [0.1, 0.2, -0.3]
        marker = Marker(x, y, s)
        pos, neg = marker.posneg_marker()

        marker = Marker()
        marker.extend((x, y, s))

    """
    def __new__(cls, *xys):
        """
        Extends the base class adding consistency check.

        Raises:
            TypeError: if arguments are given but their number is not 3.
            ValueError: if one of the sizes ``s`` is a complex number (non-zero imaginary part).
        """
        if not xys:
            # Empty marker: lists are used so that `extend` can grow them in place.
            return super().__new__(cls, [], [], [])

        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        x, y, s = (np.asarray(values) for values in xys)

        for size in s:
            if np.iscomplex(size):
                raise ValueError("Found ambiguous complex entry %s" % str(size))

        return super().__new__(cls, x, y, s)

    def __bool__(self):
        """A marker is true if it contains at least one size."""
        return bool(len(self.s))

    __nonzero__ = __bool__

    def extend(self, xys):
        """
        Extend the marker values in place.

        Args:
            xys: Sequence with three iterables (new x, new y, new s).

        Raises:
            TypeError: if ``xys`` does not have 3 entries or if x, y, s would end up
                with different lengths. In this case the marker is left untouched.
            AttributeError: if the marker stores arrays instead of lists
                (i.e. it was not created with ``Marker()``).
        """
        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        if not all(hasattr(values, "extend") for values in self):
            raise AttributeError(
                "Cannot extend a Marker whose x, y, s are not lists. "
                "Create an empty Marker() and extend it instead.")

        new_x, new_y, new_s = (list(values) for values in xys)

        # Validate before touching the data so that a failure does not leave
        # x, y, s with inconsistent lengths.
        lens = np.array((len(self.x) + len(new_x),
                         len(self.y) + len(new_y),
                         len(self.s) + len(new_s)))
        if np.any(lens != lens[0]):
            raise TypeError("x, y, s vectors should have same lengths but got %s" % str(lens))

        self.x.extend(new_x)
        self.y.extend(new_y)
        self.s.extend(new_s)

    def posneg_marker(self):
        """
        Split data into two sets: the first one contains all the points with positive
        (or zero) size. The second set contains all the points with negative size.

        Return: (positive_marker, negative_marker)
        """
        pos_x, pos_y, pos_s = [], [], []
        neg_x, neg_y, neg_s = [], [], []

        for x, y, s in zip(self.x, self.y, self.s):
            if s >= 0.0:
                pos_x.append(x)
                pos_y.append(y)
                pos_s.append(s)
            else:
                neg_x.append(x)
                neg_y.append(y)
                neg_s.append(s)

        cls = self.__class__
        return cls(pos_x, pos_y, pos_s), cls(neg_x, neg_y, neg_s)
451
452
class MplExpose(object): # pragma: no cover
    """
    Context manager that collects matplotlib figures and shows them.

    Example:

        with MplExpose() as e:
            e(obj.plot1(show=False))
            e(obj.plot2(show=False))
    """
    def __init__(self, slide_mode=False, slide_timeout=None, verbose=1):
        """
        Args:
            slide_mode: If true, iterate over figures. Default: Expose all figures at once.
            slide_timeout: Close figure after slide-timeout seconds. Block if None.
            verbose: verbosity level
        """
        self.figures = []
        self.slide_mode = bool(slide_mode)
        self.verbose = verbose

        # The timeout is stored in milliseconds (None means no timeout).
        self.timeout_ms = None
        if slide_timeout is not None:
            self.timeout_ms = int(slide_timeout * 1000)
            assert self.timeout_ms >= 0

        if self.verbose:
            if not self.slide_mode:
                print("\nLoading all matplotlib figures before showing them. It may take some time...")
            else:
                print("\nSliding matplotlib figures with slide timeout: %s [s]" % slide_timeout)

        self.start_time = time.time()

    def __call__(self, obj):
        """
        Add an object to MplExpose. Support mpl figure, list of figures or
        generator yielding figures.
        """
        import types
        many = isinstance(obj, (types.GeneratorType, list, tuple))
        if not many:
            self.add_fig(obj)
            return

        for fig in obj:
            self.add_fig(fig)

    def add_fig(self, fig):
        """
        Add a matplotlib figure. None is ignored.

        In slide mode the figure is shown immediately and then cleared,
        otherwise it is stored until ``expose`` is called.
        """
        if fig is None:
            return

        if not self.slide_mode:
            self.figures.append(fig)
            return

        import matplotlib.pyplot as plt
        if self.timeout_ms is not None:
            # The timer invokes plt.close(fig) after `timeout_ms` milliseconds
            # so that the window is closed automatically.
            timer = fig.canvas.new_timer(interval=self.timeout_ms)
            timer.add_callback(plt.close, fig)
            timer.start()

        plt.show()
        fig.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Activated at the end of the with statement."""
        self.expose()

    def expose(self):
        """Show all figures. Clear figures if needed."""
        if self.slide_mode:
            # Figures have been already shown one by one in add_fig.
            return

        elapsed = time.time() - self.start_time
        print("All figures in memory, elapsed time: %.3f s" % elapsed)
        import matplotlib.pyplot as plt
        plt.show()
        for fig in self.figures:
            fig.clear()
530
531
def plot_unit_cell(lattice, ax=None, **kwargs):
    """
    Adds the unit cell of the lattice to a matplotlib Axes3D

    The twelve edges of the cell are drawn as line segments and a
    cartesian frame is added at the origin.

    Args:
        lattice: Lattice object (must provide ``get_cartesian_coords``).
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black
            and linewidth to 3.

    Returns:
        matplotlib figure and ax
    """
    ax, fig, plt = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "k")
    kwargs.setdefault("linewidth", 3)

    # Corners of the cell in fractional coordinates. The index of each corner
    # is the one used in `edges` below.
    frac_corners = (
        [0.0, 0.0, 0.0],  # 0
        [1.0, 0.0, 0.0],  # 1
        [1.0, 1.0, 0.0],  # 2
        [0.0, 1.0, 0.0],  # 3
        [0.0, 1.0, 1.0],  # 4
        [1.0, 1.0, 1.0],  # 5
        [1.0, 0.0, 1.0],  # 6
        [0.0, 0.0, 1.0],  # 7
    )
    v = [lattice.get_cartesian_coords(frac) for frac in frac_corners]

    # Each edge of the cell appears exactly once.
    edges = (
        (0, 1), (1, 2), (2, 3), (0, 3),  # face at z = 0
        (4, 5), (5, 6), (6, 7), (7, 4),  # face at z = 1
        (0, 7), (1, 6), (2, 5), (3, 4),  # edges connecting the two faces
    )
    for i, j in edges:
        ax.plot(*zip(v[i], v[j]), **kwargs)

    # Plot cartesian frame
    ax_add_cartesian_frame(ax)

    return fig, ax
568
569
def ax_add_cartesian_frame(ax, start=(0, 0, 0)):
    """
    Add cartesian frame to 3d axis at point `start`.

    Three arrows of unit length (in data coordinates) are drawn from ``start``
    along the x, y and z directions.

    Args:
        ax: matplotlib :class:`Axes3D`.
        start: Cartesian coordinates of the origin of the frame.

    Return: ``ax``
    """
    # https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-matplotlibs-3d-plot
    from matplotlib.patches import FancyArrowPatch
    from mpl_toolkits.mplot3d import proj3d
    arrow_opts = {"color": "k"}
    arrow_opts.update(dict(lw=1, arrowstyle="-|>",))

    class Arrow3D(FancyArrowPatch):
        """2D arrow patch whose end points are the projection of two 3D points."""

        def __init__(self, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def do_3d_projection(self, renderer=None):
            # Recent matplotlib versions call this hook on the patches of a 3D axes
            # and no longer attach the projection matrix to the renderer:
            # take it from the renderer if available, else from the axes.
            proj_matrix = getattr(renderer, "M", None)
            if proj_matrix is None:
                proj_matrix = self.axes.M
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, proj_matrix)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            # The returned depth is used by the 3D axes to sort the artists.
            return np.min(zs)

        def draw(self, renderer):
            self.do_3d_projection(renderer)
            FancyArrowPatch.draw(self, renderer)

    start = np.array(start)
    for end in ((1, 0, 0), (0, 1, 0), (0, 0, 1)):
        end = start + np.array(end)
        xs, ys, zs = list(zip(start, end))
        p = Arrow3D(xs, ys, zs,
                    connectionstyle='arc3', mutation_scale=20,
                    alpha=0.8, **arrow_opts)
        ax.add_artist(p)

    return ax
601
602
def plot_structure(structure, ax=None, to_unit_cell=False, alpha=0.7,
                   style="points+labels", color_scheme="VESTA", **kwargs):
    """
    Minimalistic plot of a crystalline structure with matplotlib:
    unit cell, atomic sites drawn as circles and, optionally, chemical symbols.

    Args:
        structure: Structure object
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        to_unit_cell: True if sites should be wrapped into the first unit cell.
        alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
        style: String with the elements to show: "points" to draw the atomic sites,
            "labels" to add the chemical symbols. Default: "points+labels".
        color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA")

    Returns: |matplotlib-Figure|
    """
    fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1)

    from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
    from pymatgen.vis.structure_vtk import EL_COLORS

    show_labels = "labels" in style
    show_points = "points" in style

    # One row per site: cartesian x, y, z and covalent radius.
    site_data = np.empty((len(structure), 4))
    colors = []

    for isite, site in enumerate(structure):
        symbol = site.specie.symbol
        # Color components are divided by 255 to get the floats expected by matplotlib.
        rgb = EL_COLORS[color_scheme][symbol]
        color = tuple(component / 255 for component in rgb)
        radius = CovalentRadius.radius[symbol]
        if to_unit_cell and hasattr(site, "to_unit_cell"):
            site = site.to_unit_cell()
        # Cartesian coordinates of the site.
        x, y, z = site.coords
        site_data[isite] = (x, y, z, radius)
        colors.append(color)
        if show_labels:
            ax.text(x, y, z, symbol)

    # The marker size is not optimal: matplotlib expects points whereas we would like
    # something that depends on the radius (the factor 5000 seems to give reasonable plots).
    # For possible approaches, see
    # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352
    # https://gist.github.com/syrte/592a062c562cd2a98a83
    if show_points:
        x, y, z, radii = site_data.T.copy()
        sizes = 5000 * radii ** 2
        ax.scatter(x, y, zs=z, s=sizes, c=colors, alpha=alpha)

    ax.set_title(structure.composition.formula)
    ax.set_axis_off()

    return fig
650
651
652def _generic_parser_fh(fh):
653    """
654    Parse file with data in tabular format. Supports multi datasets a la gnuplot.
655    Mainly used for files without any schema, not even CSV
656
657    Args:
658        fh: File object
659
660    Returns:
661        OrderedDict title --> numpy array
662        where title is taken from the first (non-empty) line preceding the dataset
663    """
664    arr_list = [None]
665    data = []
666    head_list = []
667    count = -1
668    last_header = None
669    for l in fh:
670        l = l.strip()
671        if not l or l.startswith("#"):
672            count = -1
673            last_header = l
674            if arr_list[-1] is not None: arr_list.append(None)
675            continue
676
677        count += 1
678        if count == 0: head_list.append(last_header)
679        if arr_list[-1] is None: arr_list[-1] = []
680        data = arr_list[-1]
681        data.append(list(map(float, l.split())))
682
683    if len(head_list) != len(arr_list):
684        raise RuntimeError("len(head_list) != len(arr_list), %d != %d" % (len(head_list), len(arr_list)))
685
686    od = OrderedDict()
687    for key, data in zip(head_list, arr_list):
688        key = " ".join(key.split())
689        if key in od:
690            print("Header %s already in dictionary. Using new key %s" % (key, 2 * key))
691            key = 2 * key
692        od[key] = np.array(data).T.copy()
693
694    return od
695
696
class GenericDataFilePlotter(object):
    """
    Extract data from a generic text file with results
    in tabular format and plot data with matplotlib.
    Multiple datasets are supported.
    No attempt is made to handle metadata (e.g. column name)
    Mainly used to handle text files written without any schema.
    """
    def __init__(self, filepath):
        """
        Args:
            filepath: Path of the text file to parse.
        """
        with open(filepath, "rt") as fh:
            # Dictionary title --> array with one row per column of the dataset.
            self.od = _generic_parser_fh(fh)

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with verbosity level `verbose`."""
        lines = []
        for key, arr in self.od.items():
            lines.append("key: `%s` --> array shape: %s" % (key, str(arr.shape)))
        return "\n".join(lines)

    @add_fig_kwargs
    def plot(self, use_index=False, fontsize=8, **kwargs):
        """
        Plot all arrays. Use multiple axes if datasets.

        Args:
            use_index: By default, the x-values are taken from the first column
                and the other columns are plotted as y-values.
                If use_index is True, the x-values are the row index
                and all the columns are plotted as y-values.
            fontsize: fontsize for title.
            kwargs: options passed to ``ax.plot``.

        Return: |matplotlib-figure|
        """
        # build grid of plots.
        num_plots, ncols, nrows = len(self.od), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        for ax, (key, arr) in zip(ax_list, self.od.items()):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            if use_index:
                xs, ys_list = list(range(len(arr[0]))), arr
            else:
                xs, ys_list = arr[0], arr[1:]
            for ys in ys_list:
                ax.plot(xs, ys, **kwargs)

        return fig
753
754
class GenericDataFilesPlotter(object):
    """
    Collect the datasets extracted from several text files in tabular format
    and compare the datasets having the same title in all the files.
    """

    @classmethod
    def from_files(cls, filepaths):
        """
        Build object from a list of `filenames`.
        """
        new = cls()
        for filepath in filepaths:
            new.add_file(filepath)
        return new

    def __init__(self):
        # odlist[i] is the dictionary title --> array parsed from filepaths[i].
        self.odlist = []
        self.filepaths = []

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with verbosity level `verbose`."""
        lines = []
        app = lines.append
        for od, filepath in zip(self.odlist, self.filepaths):
            app("File: %s" % filepath)
            for key, arr in od.items():
                app("\tkey: `%s` --> array shape: %s" % (key, str(arr.shape)))

        return "\n".join(lines)

    def add_file(self, filepath):
        """Add data from `filepath`"""
        with open(filepath, "rt") as fh:
            self.odlist.append(_generic_parser_fh(fh))
            self.filepaths.append(filepath)

    @add_fig_kwargs
    def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs):
        """
        Plot all arrays. Use multiple axes if datasets.

        There is one ax for each title found in all the files. Each file has its own color;
        within a file, the columns are distinguished by the line style and the same
        column has the same line style in all the files.

        Args:
            use_index: By default, the x-values are taken from the first column
                and the other columns are plotted as y-values.
                If use_index is True, the x-values are the row index
                and all the columns are plotted as y-values.
            fontsize: fontsize for title.
            colormap: matplotlib color map.
            kwargs: options passed to ``ax.plot``. They take precedence over the
                color, linestyle and label computed by this method.

        Return: |matplotlib-figure| or None if there is no file or no common title.
        """
        if not self.odlist: return None

        # Titles present in all the files, in the order in which they appear in the
        # first file so that the layout of the figure is reproducible.
        first_od, other_ods = self.odlist[0], self.odlist[1:]
        keys = [key for key in first_od if all(key in od for od in other_ods)]
        if not keys:
            print("Warning: cannot find common keys in files. Check input data")
            return None

        # Build grid of plots.
        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0: ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)
        line_styles = ["-", ":", "--", "-."]

        # One ax for key, each ax may show multiple arrays
        # so we need different line styles that are consistent with input data.
        # Figure may be crowded but it's difficult to do better without metadata.
        for ax, key in zip(ax_list, keys):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            for iod, (od, filepath) in enumerate(zip(self.odlist, self.filepaths)):
                arr = od[key]
                color = cmap(iod / len(self.odlist))
                if use_index:
                    xvals, arr_list = list(range(len(arr[0]))), arr
                else:
                    xvals, arr_list = arr[0], arr[1:]

                # Restart the cycle for each file so that the i-th column
                # is always drawn with the same line style.
                for iarr, (ys, linestyle) in enumerate(zip(arr_list, itertools.cycle(line_styles))):
                    # Only the first line of each file enters the legend.
                    plot_opts = dict(color=color, linestyle=linestyle,
                                     label=os.path.relpath(filepath) if iarr == 0 else None)
                    plot_opts.update(kwargs)
                    ax.plot(xvals, ys, **plot_opts)

            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig
850