# coding: utf-8
"""
Utilities for generating matplotlib plots.

.. note::

    Avoid importing matplotlib in the module namespace otherwise startup is very slow.
"""
import os
import time
import itertools
import numpy as np

from collections import OrderedDict, namedtuple
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt, get_axarray_fig_plt
from .numtools import data_from_cplx_mode


__all__ = [
    "set_axlims",
    "get_ax_fig_plt",
    "get_ax3d_fig_plt",
    "plot_array",
    "ArrayPlotter",
    "data_from_cplx_mode",
    "Marker",
    "plot_unit_cell",
    "GenericDataFilePlotter",
    "GenericDataFilesPlotter",
]


# Named dash patterns (offset, on-off sequence) usable as matplotlib ``linestyle``.
# https://matplotlib.org/gallery/lines_bars_and_markers/linestyles.html
linestyles = OrderedDict(
    [('solid', (0, ())),
     ('loosely_dotted', (0, (1, 10))),
     ('dotted', (0, (1, 5))),
     ('densely_dotted', (0, (1, 1))),

     ('loosely_dashed', (0, (5, 10))),
     ('dashed', (0, (5, 5))),
     ('densely_dashed', (0, (5, 1))),

     ('loosely_dashdotted', (0, (3, 10, 1, 10))),
     ('dashdotted', (0, (3, 5, 1, 5))),
     ('densely_dashdotted', (0, (3, 1, 1, 1))),

     ('loosely_dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
     ('densely_dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]
)


def ax_append_title(ax, title, loc="center", fontsize=None):
    """
    Append ``title`` to the current title of ``ax`` at location ``loc``.

    Args:
        ax: |matplotlib-Axes|.
        title: String appended to the existing title.
        loc: Title location ("center", "left", "right").
        fontsize: Font size passed to ``ax.set_title``.

    Returns: the new title string.
    """
    new_title = ax.get_title(loc=loc) + title
    ax.set_title(new_title, loc=loc, fontsize=fontsize)
    return new_title


def ax_share(xy_string, *ax_list):
    """
    Share x- or y-axis of two or more subplots after they are created.

    Args:
        xy_string: "x" to share x-axis, "y" for y-axis, "xy" for both.
        ax_list: List of axes to share.

    Example:

        ax_share("y", ax0, ax1)
        ax_share("xy", *(ax0, ax1, ax2))
    """
    # Nothing to share with less than two axes.
    if len(ax_list) < 2:
        return

    ref = ax_list[0]
    for axname in ("x", "y"):
        if axname not in xy_string:
            continue

        grouper = getattr(ref, "get_shared_%s_axes" % axname)()
        if hasattr(grouper, "join"):
            # Older matplotlib: the shared-axes grouper is mutable.
            # Join ALL the axes (including ref) in a single call.
            grouper.join(*ax_list)
        else:
            # matplotlib >= 3.8 returns a read-only view: use Axes.sharex/sharey.
            for ax in ax_list[1:]:
                if not grouper.joined(ax, ref):
                    getattr(ax, "share" + axname)(ref)


def set_axlims(ax, lims, axname):
    """
    Set the data limits for the axis ax.

    Args:
        ax: |matplotlib-Axes|.
        lims: tuple(2) for (left, right), tuple(1) or scalar for left only.
            If None, the axis is not modified.
        axname: "x" for x-axis, "y" for y-axis.

    Return: (left, right)
    """
    left, right = None, None
    if lims is None:
        return (left, right)

    try:
        len_lims = len(lims)
    except TypeError:
        # Assume scalar: only the left limit is given.
        left = float(lims)
    else:
        if len_lims == 2:
            left, right = lims[0], lims[1]
        elif len_lims == 1:
            left = lims[0]

    set_lim = getattr(ax, {"x": "set_xlim", "y": "set_ylim"}[axname])
    # left == right covers e.g. (None, None): nothing to set.
    if left != right:
        set_lim(left, right)

    return left, right


def set_ax_xylabels(ax, xlabel, ylabel, exchange_xy):
    """
    Set the x- and the y-label of axis ax, exchanging x and y if exchange_xy.
    """
    if exchange_xy:
        xlabel, ylabel = ylabel, xlabel
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


def set_visible(ax, boolean, *args):
    """
    Hide/Show the artists of axis ax listed in args.

    Args:
        ax: |matplotlib-Axes|.
        boolean: True to show, False to hide.
        args: Names of the artists: "legend", "title", "xlabel", "ylabel",
            "xticklabels", "yticklabels".
    """
    if "legend" in args:
        # Use get_legend: calling ax.legend() would create a new legend.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_visible(boolean)
    if "title" in args and ax.title:
        ax.title.set_visible(boolean)
    if "xlabel" in args and ax.xaxis.label:
        ax.xaxis.label.set_visible(boolean)
    if "ylabel" in args and ax.yaxis.label:
        ax.yaxis.label.set_visible(boolean)
    if "xticklabels" in args:
        for label in ax.get_xticklabels():
            label.set_visible(boolean)
    if "yticklabels" in args:
        for label in ax.get_yticklabels():
            label.set_visible(boolean)


def rotate_ticklabels(ax, rotation, axname="x"):
    """Rotate the ticklabels of axis ``ax``. axname can be "x", "y" or "xy"."""
    if "x" in axname:
        for tick in ax.get_xticklabels():
            tick.set_rotation(rotation)
    if "y" in axname:
        for tick in ax.get_yticklabels():
            tick.set_rotation(rotation)


@add_fig_kwargs
def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
                     xlims=None, ylims=None, fontsize=12, **kwargs):
    """
    Plot y = f(x) relation for different values of `hue`.
    Useful for convergence tests done wrt to two parameters.

    Args:
        data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
        x: Name of the column used as x-value
        y: Name of the column(s) used as y-value. If a list/tuple, one subplot per entry is produced.
        hue: Variable that define subsets of the data, which will be drawn on separate lines
        decimals: Number of decimal places to round `hue` columns. Ignore if None
        ax: |matplotlib-Axes| or None if a new figure should be created.
        xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
            or scalar e.g. `left`. If left (right) is None, default values are used
        fontsize: Legend fontsize.
        kwargs: Keyword arguments are passed to ax.plot method.

    Returns: |matplotlib-Figure|
    """
    if isinstance(y, (list, tuple)):
        # Recursive call for each ax in ax_list (two columns of subplots).
        num_plots, ncols, nrows = len(y), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)

        ax_list = ax_list.ravel()
        # Hide the unused ax when num_plots is odd.
        if num_plots % ncols != 0:
            ax_list[-1].axis('off')

        for yname, ax in zip(y, ax_list):
            plot_xy_with_hue(data, x, str(yname), hue, decimals=decimals, ax=ax,
                             xlims=xlims, ylims=ylims, fontsize=fontsize, show=False, **kwargs)
        return fig

    # Check here because pandas error messages are a bit cryptic.
    miss = [k for k in (x, y, hue) if k not in data]
    if miss:
        raise ValueError("Cannot find `%s` in dataframe.\nAvailable keys are: %s" % (str(miss), str(data.keys())))

    # Truncate values in hue column so that we can group.
    if decimals is not None:
        data = data.round({hue: decimals})

    ax, fig, plt = get_ax_fig_plt(ax=ax)
    for key, grp in data.groupby(hue):
        # Sort xs and rearrange ys
        xy = np.array(sorted(zip(grp[x], grp[y]), key=lambda t: t[0]))
        xvals, yvals = xy[:, 0], xy[:, 1]

        label = "%s" % (str(key))
        if not kwargs:
            ax.plot(xvals, yvals, 'o-', label=label)
        else:
            ax.plot(xvals, yvals, label=label, **kwargs)

    ax.grid(True)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    set_axlims(ax, xlims, "x")
    set_axlims(ax, ylims, "y")
    ax.legend(loc="best", fontsize=fontsize, shadow=True)

    return fig


@add_fig_kwargs
def plot_array(array, color_map=None, cplx_mode="abs", **kwargs):
    """
    Use imshow for plotting 2D or 1D arrays.

    Example::

        plot_array(np.random.rand(10,10))

    See <http://stackoverflow.com/questions/7229971/2d-grid-data-visualization-in-python>

    Args:
        array: Array-like object (1D or 2D).
        color_map: color map.
        cplx_mode:
            Flag defining how to handle complex arrays. Possible values in ("re", "im", "abs", "angle")
            "re" for the real part, "im" for the imaginary part.
            "abs" means that the absolute value of the complex number is shown.
            "angle" will display the phase of the complex number in radians.

    Returns: |matplotlib-Figure|
    """
    # Handle vectors
    array = np.atleast_2d(array)
    array = data_from_cplx_mode(cplx_mode, array)

    import matplotlib as mpl
    from matplotlib import pyplot as plt
    if color_map is None:
        # make a color map of fixed colors
        color_map = mpl.colors.LinearSegmentedColormap.from_list('my_colormap',
                                                                 ['blue', 'black', 'red'], 256)

    img = plt.imshow(array, interpolation='nearest', cmap=color_map, origin='lower')

    # Make a color bar
    plt.colorbar(img, cmap=color_map)

    # Set grid
    plt.grid(True, color='white')

    fig = plt.gcf()
    return fig


class ArrayPlotter(object):
    """
    Ordered collection of labelled arrays that can be plotted with ``matshow``,
    one subplot per array.
    """

    def __init__(self, *labels_and_arrays):
        """
        Args:
            labels_and_arrays: List [("label1", arr1), ("label2", arr2")]
        """
        self._arr_dict = OrderedDict()
        for label, array in labels_and_arrays:
            self.add_array(label, array)

    def __len__(self):
        return len(self._arr_dict)

    def __iter__(self):
        return self._arr_dict.__iter__()

    def keys(self):
        return self._arr_dict.keys()

    def items(self):
        return self._arr_dict.items()

    def add_array(self, label, array):
        """Add array with the given name. Raise ValueError if label is already present."""
        if label in self._arr_dict:
            raise ValueError("%s is already in %s" % (label, list(self._arr_dict.keys())))

        self._arr_dict[label] = array

    def add_arrays(self, labels, arr_list):
        """
        Add a list of arrays

        Args:
            labels: List of labels.
            arr_list: List of arrays.

        Raises: ValueError if the two lists have different length.
        """
        if len(labels) != len(arr_list):
            raise ValueError("len(labels) != len(arr_list): %d != %d" % (len(labels), len(arr_list)))
        for label, arr in zip(labels, arr_list):
            self.add_array(label, arr)

    @add_fig_kwargs
    def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs):
        """
        Args:
            cplx_mode: "abs" for absolute value, "re", "im", "angle"
            colormap: matplotlib colormap.
            fontsize: legend and label fontsize.

        Returns: |matplotlib-Figure|
        """
        # Build grid of plots (two columns if more than one array).
        num_plots, ncols, nrows = len(self), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = num_plots // ncols + (num_plots % ncols)

        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        from matplotlib.ticker import MultipleLocator

        fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False)
        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0:
            ax_mat[-1, -1].axis("off")

        for ax, (label, arr) in zip(ax_mat.flat, self.items()):
            data = data_from_cplx_mode(cplx_mode, arr)
            # Use origin to place the [0, 0] index of the array in the lower left corner of the axes.
            img = ax.matshow(data, interpolation='nearest', cmap=colormap, origin='lower', aspect="auto")
            ax.set_title("(%s) %s" % (cplx_mode, label), fontsize=fontsize)

            # Make a color bar for this ax in an axes appended to the right (10% of the width).
            # http://stackoverflow.com/questions/18266642/multiple-imshow-subplots-each-with-colorbar
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="10%", pad=0.05)
            plt.colorbar(img, cax=cax, ticks=MultipleLocator(0.2), format="%.2f")
            # Remove xticks from ax
            ax.xaxis.set_visible(False)

            # Set grid
            ax.grid(True, color='white')

        fig.tight_layout()
        return fig


#TODO use object and introduce c for color, client code should be able to customize it.
# Rename it to ScatterData
class Marker(namedtuple("Marker", "x y s")):
    """
    Stores the position and the size of the marker.
    A marker is a list of tuple(x, y, s) where x, and y are the position
    in the graph and s is the size of the marker.
    Used for plotting purpose e.g. QP data, energy derivatives...

    An empty marker (``Marker()``) stores python lists and can be grown in place
    with ``extend``. A marker built from data stores numpy arrays and is fixed-size.

    Example::

        x, y, s = [1, 2, 3], [4, 5, 6], [0.1, 0.2, -0.3]
        marker = Marker(x, y, s)

        growing = Marker()
        growing.extend((x, y, s))
    """
    def __new__(cls, *xys):
        """Extends the base class adding consistency check."""
        if not xys:
            # Empty marker: use lists so that ``extend`` can grow it in place.
            return super().__new__(cls, [], [], [])

        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        x, y, s = (np.asarray(v) for v in xys)

        # Complex sizes are ambiguous: reject entries with a nonzero imaginary part.
        for size in s:
            if np.iscomplex(size):
                raise ValueError("Found ambiguous complex entry %s" % str(size))

        return super().__new__(cls, x, y, s)

    def __bool__(self):
        return bool(len(self.s))

    __nonzero__ = __bool__

    def extend(self, xys):
        """
        Extend the marker values in place.

        Args:
            xys: Sequence with the three iterables (x, y, s) to append.

        Raises:
            TypeError: if xys does not contain 3 entries, if the resulting x, y, s
                vectors would have different lengths, or if the marker stores numpy
                arrays (i.e. it was not created empty with ``Marker()``) and
                therefore cannot be resized in place.
        """
        if len(xys) != 3:
            raise TypeError("Expecting 3 entries in xys got %d" % len(xys))

        if not all(isinstance(v, list) for v in self):
            raise TypeError("Cannot extend a Marker storing numpy arrays in place. "
                            "Create it with Marker() or build a new Marker instead.")

        new_x, new_y, new_s = (list(v) for v in xys)

        # Validate before mutating so that bad input does not leave a half-extended marker.
        lens = np.array((len(self.x) + len(new_x), len(self.y) + len(new_y), len(self.s) + len(new_s)))
        if np.any(lens != lens[0]):
            raise TypeError("x, y, s vectors should have same lengths but got %s" % str(lens))

        self.x.extend(new_x)
        self.y.extend(new_y)
        self.s.extend(new_s)

    def posneg_marker(self):
        """
        Split data into two sets: the first one contains all the points with positive
        (or zero) size, the second set contains all the points with negative size.

        Returns: (positive_marker, negative_marker)
        """
        pos = ([], [], [])
        neg = ([], [], [])

        for x, y, s in zip(self.x, self.y, self.s):
            target = pos if s >= 0.0 else neg
            target[0].append(x)
            target[1].append(y)
            target[2].append(s)

        return self.__class__(*pos), self.__class__(*neg)


class MplExpose(object):  # pragma: no cover
    """
    Context manager that collects matplotlib figures and shows them.

    Example:

        with MplExpose() as e:
            e(obj.plot1(show=False))
            e(obj.plot2(show=False))
    """
    def __init__(self, slide_mode=False, slide_timeout=None, verbose=1):
        """
        Args:
            slide_mode: If true, iterate over figures. Default: Expose all figures at once.
            slide_timeout: Close figure after slide-timeout seconds. Block if None.
            verbose: verbosity level
        """
        self.figures = []
        self.slide_mode = bool(slide_mode)
        self.timeout_ms = slide_timeout
        self.verbose = verbose
        if self.timeout_ms is not None:
            self.timeout_ms = int(self.timeout_ms * 1000)
            if self.timeout_ms < 0:
                raise ValueError("slide_timeout should be >= 0 but got %s" % slide_timeout)

        if self.verbose:
            if self.slide_mode:
                print("\nSliding matplotlib figures with slide timeout: %s [s]" % slide_timeout)
            else:
                print("\nLoading all matplotlib figures before showing them. It may take some time...")

        self.start_time = time.time()

    def __call__(self, obj):
        """
        Add an object to MplExpose. Support mpl figure, list of figures or
        generator yielding figures.
        """
        import types
        if isinstance(obj, (types.GeneratorType, list, tuple)):
            for fig in obj:
                self.add_fig(fig)
        else:
            self.add_fig(obj)

    def add_fig(self, fig):
        """Add a matplotlib figure. None is ignored."""
        if fig is None:
            return

        if not self.slide_mode:
            self.figures.append(fig)
        else:
            import matplotlib.pyplot as plt
            if self.timeout_ms is not None:
                # The timer calls plt.close after interval milliseconds to close the window.
                timer = fig.canvas.new_timer(interval=self.timeout_ms)
                timer.add_callback(plt.close, fig)
                timer.start()

            plt.show()
            fig.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Activated at the end of the with statement. """
        self.expose()

    def expose(self):
        """Show all figures. Clear figures if needed."""
        if not self.slide_mode:
            if self.verbose:
                print("All figures in memory, elapsed time: %.3f s" % (time.time() - self.start_time))
            import matplotlib.pyplot as plt
            plt.show()
            for fig in self.figures:
                fig.clear()


def plot_unit_cell(lattice, ax=None, **kwargs):
    """
    Adds the unit cell of the lattice to a matplotlib Axes3D

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black
            and linewidth to 3.

    Returns:
        matplotlib figure and ax
    """
    ax, fig, plt = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "k")
    kwargs.setdefault("linewidth", 3)

    # Corners of the unit cell in reduced coordinates, converted to cartesian.
    frac_corners = [
        [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 0.0, 1.0],
    ]
    v = [lattice.get_cartesian_coords(fc) for fc in frac_corners]

    # The 12 edges of the parallelepiped (bottom face, top face, vertical edges).
    edges = ((0, 1), (1, 2), (2, 3), (0, 3),
             (7, 6), (6, 5), (5, 4), (4, 7),
             (0, 7), (1, 6), (2, 5), (3, 4))
    for i, j in edges:
        ax.plot(*zip(v[i], v[j]), **kwargs)

    # Plot cartesian frame
    ax_add_cartesian_frame(ax)

    return fig, ax


def ax_add_cartesian_frame(ax, start=(0, 0, 0)):
    """
    Add cartesian frame (three unit arrows along x, y, z) to 3d axis at point `start`.

    Returns: ax
    """
    # https://stackoverflow.com/questions/22867620/putting-arrowheads-on-vectors-in-matplotlibs-3d-plot
    from matplotlib.patches import FancyArrowPatch
    from mpl_toolkits.mplot3d import proj3d
    arrow_opts = dict(color="k", lw=1, arrowstyle="-|>")

    class Arrow3D(FancyArrowPatch):
        """FancyArrowPatch whose end points are given in 3D data coordinates."""

        def __init__(self, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def _project(self, renderer=None):
            # Old matplotlib attaches the projection matrix to the renderer,
            # newer versions (>= 3.5) only expose it as Axes3D.M.
            M = getattr(renderer, "M", None)
            if M is None:
                M = self.axes.M
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            return zs

        def do_3d_projection(self, renderer=None):
            # Called by Axes3D.draw (matplotlib >= 3.5) to depth-sort patches.
            return np.min(self._project(renderer))

        def draw(self, renderer):
            self._project(renderer)
            FancyArrowPatch.draw(self, renderer)

    start = np.array(start)
    for end in ((1, 0, 0), (0, 1, 0), (0, 0, 1)):
        end = start + np.array(end)
        xs, ys, zs = list(zip(start, end))
        p = Arrow3D(xs, ys, zs,
                    connectionstyle='arc3', mutation_scale=20,
                    alpha=0.8, **arrow_opts)
        ax.add_artist(p)

    return ax


def plot_structure(structure, ax=None, to_unit_cell=False, alpha=0.7,
                   style="points+labels", color_scheme="VESTA", **kwargs):
    """
    Plot structure with matplotlib (minimalistic version)

    Args:
        structure: Structure object
        ax: matplotlib :class:`Axes3D` or None if a new figure should be created.
        alpha: The alpha blending value, between 0 (transparent) and 1 (opaque)
        to_unit_cell: True if sites should be wrapped into the first unit cell.
        style: "points+labels" to show atoms sites with labels.
        color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA")
        kwargs: Accepted but currently unused.

    Returns: |matplotlib-Figure|
    """
    fig, ax = plot_unit_cell(structure.lattice, ax=ax, linewidth=1)

    from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
    from pymatgen.vis.structure_vtk import EL_COLORS
    xyzs, colors = np.empty((len(structure), 4)), []

    for i, site in enumerate(structure):
        symbol = site.specie.symbol
        # EL_COLORS stores 0-255 RGB values, matplotlib wants floats in [0, 1].
        color = tuple(c / 255 for c in EL_COLORS[color_scheme][symbol])
        radius = CovalentRadius.radius[symbol]
        if to_unit_cell and hasattr(site, "to_unit_cell"):
            site = site.to_unit_cell()
        # Use cartesian coordinates.
        x, y, z = site.coords
        xyzs[i] = (x, y, z, radius)
        colors.append(color)
        if "labels" in style:
            ax.text(x, y, z, symbol)

    # The definition of sizes is not optimal because matplotlib uses points
    # whereas we would like something that depends on the radius (5000 seems to give reasonable plots)
    # For possible approaches, see
    # https://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352
    # https://gist.github.com/syrte/592a062c562cd2a98a83
    if "points" in style:
        x, y, z, s = xyzs.T.copy()
        s = 5000 * s ** 2
        ax.scatter(x, y, zs=z, s=s, c=colors, alpha=alpha)

    ax.set_title(structure.composition.formula)
    ax.set_axis_off()

    return fig


def _generic_parser_fh(fh):
    """
    Parse file with data in tabular format. Supports multi datasets a la gnuplot.
    Mainly used for files without any schema, not even CSV

    Datasets are separated by blank or comment (``#``) lines. The key of each dataset
    is the (whitespace-normalized) blank or comment line immediately preceding it,
    including the leading ``#``; an empty string if the file starts with data.

    Args:
        fh: File object

    Returns:
        OrderedDict key --> numpy array with shape (num_columns, num_rows)
    """
    arr_list = [None]
    head_list = []
    count = -1
    last_header = ""
    for line in fh:
        line = line.strip()
        if not line or line.startswith("#"):
            # Blank/comment line: close the current dataset (if any) and remember
            # the line as the header of the next one.
            count = -1
            last_header = line
            if arr_list[-1] is not None:
                arr_list.append(None)
            continue

        count += 1
        if count == 0:
            head_list.append(last_header)
        if arr_list[-1] is None:
            arr_list[-1] = []
        arr_list[-1].append([float(tok) for tok in line.split()])

    # Trailing blank/comment lines leave an unused placeholder at the end.
    if arr_list and arr_list[-1] is None:
        arr_list.pop()

    if len(head_list) != len(arr_list):
        raise RuntimeError("len(head_list) != len(arr_list), %d != %d" % (len(head_list), len(arr_list)))

    od = OrderedDict()
    for key, data in zip(head_list, arr_list):
        key = " ".join(key.split())
        if key in od:
            # Keep the historical "doubled key" name when possible, otherwise
            # add a numeric suffix so that no dataset is silently overwritten.
            new_key, i = 2 * key, 1
            while not new_key or new_key in od:
                new_key = "%s_%d" % (key, i)
                i += 1
            print("Header %s already in dictionary. Using new key %s" % (key, new_key))
            key = new_key
        od[key] = np.array(data).T.copy()

    return od


class GenericDataFilePlotter(object):
    """
    Extract data from a generic text file with results
    in tabular format and plot data with matplotlib.
    Multiple datasets are supported.
    No attempt is made to handle metadata (e.g. column name)
    Mainly used to handle text files written without any schema.
    """
    def __init__(self, filepath):
        with open(filepath, "rt") as fh:
            self.od = _generic_parser_fh(fh)

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with verbosity level `verbose`."""
        lines = []
        for key, arr in self.od.items():
            lines.append("key: `%s` --> array shape: %s" % (key, str(arr.shape)))
        return "\n".join(lines)

    @add_fig_kwargs
    def plot(self, use_index=False, fontsize=8, **kwargs):
        """
        Plot all arrays. Use multiple axes if datasets.

        Args:
            use_index: By default, the x-values are taken from the first column
                and the other columns are plotted as y-values.
                If use_index is True, the x-values are the row index and all columns are plotted.
            fontsize: fontsize for title.
            kwargs: options passed to ``ax.plot``.

        Return: |matplotlib-figure|
        """
        # build grid of plots.
        num_plots, ncols, nrows = len(self.od), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        for ax, (key, arr) in zip(ax_list, self.od.items()):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            if use_index:
                xs, ys_list = list(range(len(arr[0]))), arr
            else:
                xs, ys_list = arr[0], arr[1:]
            for ys in ys_list:
                ax.plot(xs, ys, **kwargs)

        return fig


class GenericDataFilesPlotter(object):
    """
    Compare the datasets stored in multiple generic text files (see GenericDataFilePlotter).
    """

    @classmethod
    def from_files(cls, filepaths):
        """
        Build object from a list of `filepaths`.
        """
        new = cls()
        for filepath in filepaths:
            new.add_file(filepath)
        return new

    def __init__(self):
        self.odlist = []
        self.filepaths = []

    def __str__(self):
        return self.to_string()

    def to_string(self, verbose=0):
        """String representation with verbosity level `verbose`."""
        lines = []
        for od, filepath in zip(self.odlist, self.filepaths):
            lines.append("File: %s" % filepath)
            for key, arr in od.items():
                lines.append("\tkey: `%s` --> array shape: %s" % (key, str(arr.shape)))

        return "\n".join(lines)

    def add_file(self, filepath):
        """Add data from `filepath`"""
        with open(filepath, "rt") as fh:
            self.odlist.append(_generic_parser_fh(fh))
            self.filepaths.append(filepath)

    @add_fig_kwargs
    def plot(self, use_index=False, fontsize=8, colormap="viridis", **kwargs):
        """
        Plot all arrays. Use multiple axes if datasets.

        Args:
            use_index: By default, the x-values are taken from the first column
                and the other columns are plotted as y-values.
                If use_index is True, the x-values are the row index and all columns are plotted.
            fontsize: fontsize for title.
            colormap: matplotlib color map.
            kwargs: options passed to ``ax.plot`` (they override the default color/linestyle/label).

        Return: |matplotlib-figure| or None if there is nothing to plot.
        """
        if not self.odlist:
            return None

        # Keys common to all files, in the order of the first file
        # so that the layout of the figure is reproducible.
        keys = [k for k in self.odlist[0] if all(k in od for od in self.odlist[1:])]
        if not keys:
            print("Warning: cannot find common keys in files. Check input data")
            return None

        # Build grid of plots.
        num_plots, ncols, nrows = len(keys), 1, 1
        if num_plots > 1:
            ncols = 2
            nrows = (num_plots // ncols) + (num_plots % ncols)

        ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
                                                sharex=False, sharey=False, squeeze=False)
        ax_list = ax_list.ravel()

        # Don't show the last ax if num_plots is odd.
        if num_plots % ncols != 0:
            ax_list[-1].axis("off")

        cmap = plt.get_cmap(colormap)
        line_styles = ["-", ":", "--", "-."]

        # One ax for key, each ax may show multiple arrays.
        # Color identifies the file, line style identifies the column: the style cycle
        # restarts for each file so that the same column always gets the same style.
        for ax, key in zip(ax_list, keys):
            ax.set_title(key, fontsize=fontsize)
            ax.grid(True)
            for iod, (od, filepath) in enumerate(zip(self.odlist, self.filepaths)):
                arr = od[key]
                color = cmap(iod / len(self.odlist))
                if use_index:
                    xvals, arr_list = list(range(len(arr[0]))), arr
                else:
                    xvals, arr_list = arr[0], arr[1:]
                for iarr, (ys, linestyle) in enumerate(zip(arr_list, itertools.cycle(line_styles))):
                    plot_kwargs = dict(color=color, linestyle=linestyle,
                                       label=os.path.relpath(filepath) if iarr == 0 else None)
                    plot_kwargs.update(kwargs)
                    ax.plot(xvals, ys, **plot_kwargs)

            ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig