# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
Utilities for generating nicer plots.
"""
import copy
import math
import sys
from itertools import groupby

import numpy as np
from matplotlib import cm, colors

from pymatgen.core.periodic_table import Element

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal


def pretty_plot(width=8, height=None, plt=None, dpi=None, color_cycle=("qualitative", "Set1_9")):
    """
    Provides a publication quality plot, with nice defaults for font sizes etc.

    Args:
        width (float): Width of plot in inches. Defaults to 8in.
        height (float): Height of plot in inches. Defaults to width * golden
            ratio (truncated to an integer).
        plt (matplotlib.pyplot): If plt is supplied, changes will be made to an
            existing plot. Otherwise, a new plot will be created.
        dpi (int): Sets dot per inch for a newly created figure. Defaults to
            None, i.e. matplotlib's default. Not used when plt is supplied.
        color_cycle (tuple): Set the color cycle for new plots to one of the
            color sets in palettable, given as (submodule of
            palettable.colorbrewer, palette name). Defaults to a qualitative
            Set1_9.

    Returns:
        Matplotlib plot object with properly sized fonts.
    """
    ticksize = int(width * 2.5)

    golden_ratio = (math.sqrt(5) - 1) / 2

    if not height:
        height = int(width * golden_ratio)

    if plt is None:
        import importlib

        import matplotlib.pyplot as plt

        mod = importlib.import_module("palettable.colorbrewer.%s" % color_cycle[0])
        # Named "palette" so the module-level matplotlib.colors import is not shadowed.
        palette = getattr(mod, color_cycle[1]).mpl_colors
        from cycler import cycler

        plt.figure(figsize=(width, height), facecolor="w", dpi=dpi)
        ax = plt.gca()
        ax.set_prop_cycle(cycler("color", palette))
    else:
        fig = plt.gcf()
        fig.set_size_inches(width, height)
    plt.xticks(fontsize=ticksize)
    plt.yticks(fontsize=ticksize)

    ax = plt.gca()
    ax.set_title(ax.get_title(), size=width * 4)

    labelsize = int(width * 3)

    ax.set_xlabel(ax.get_xlabel(), size=labelsize)
    ax.set_ylabel(ax.get_ylabel(), size=labelsize)

    return plt


def pretty_plot_two_axis(
    x, y1, y2, xlabel=None, y1label=None, y2label=None, width=8, height=None, dpi=300, **plot_kwargs
):
    """
    Variant of pretty_plot that does a dual axis plot. Adapted from matplotlib
    examples. Makes it easier to create plots with different axes.

    Args:
        x (np.ndarray/list): Data for x-axis.
        y1 (dict/np.ndarray/list): Data for y1 axis (left). If a dict, it will
            be interpreted as a {label: sequence}.
        y2 (dict/np.ndarray/list): Data for y2 axis (right). If a dict, it will
            be interpreted as a {label: sequence}.
        xlabel (str): If not None, this will be the label for the x-axis.
        y1label (str): If not None, this will be the label for the y1-axis.
        y2label (str): If not None, this will be the label for the y2-axis.
        width (float): Width of plot in inches. Defaults to 8in. Label and
            tick font sizes scale with the width.
        height (float): Height of plot in inches. Defaults to width * golden
            ratio (truncated to an integer).
        dpi (int): Sets dot per inch for figure. Defaults to 300.
        plot_kwargs: Passthrough kwargs to matplotlib's plot method, applied to
            the lines of both axes. E.g., linewidth, etc.

    Returns:
        matplotlib.pyplot
    """
    # pylint: disable=E1101
    import palettable.colorbrewer.diverging

    # Opposite ends of a diverging palette so the two axes are easy to tell apart.
    palette = palettable.colorbrewer.diverging.RdYlBu_4.mpl_colors
    c1 = palette[0]
    c2 = palette[-1]

    golden_ratio = (math.sqrt(5) - 1) / 2

    if not height:
        height = int(width * golden_ratio)

    import matplotlib.pyplot as plt

    # Font sizes derive from the requested width (same ratios as pretty_plot).
    labelsize = int(width * 3)
    ticksize = int(width * 2.5)
    # Cycled over successive series of a dict input. These must be valid
    # matplotlib *line* styles; "." is a marker and is rejected as a linestyle.
    styles = ["-", "--", "-.", ":"]

    fig, ax1 = plt.subplots()
    fig.set_size_inches((width, height))
    if dpi:
        fig.set_dpi(dpi)

    if isinstance(y1, dict):
        for i, (label, values) in enumerate(y1.items()):
            ax1.plot(x, values, c=c1, marker="s", ls=styles[i % len(styles)], label=label, **plot_kwargs)
        ax1.legend(fontsize=labelsize)
    else:
        ax1.plot(x, y1, c=c1, marker="s", ls="-", **plot_kwargs)

    if xlabel:
        ax1.set_xlabel(xlabel, fontsize=labelsize)

    if y1label:
        # Make the y-axis label, ticks and tick labels match the line color.
        ax1.set_ylabel(y1label, color=c1, fontsize=labelsize)

    ax1.tick_params("x", labelsize=ticksize)
    ax1.tick_params("y", colors=c1, labelsize=ticksize)

    ax2 = ax1.twinx()
    if isinstance(y2, dict):
        for i, (label, values) in enumerate(y2.items()):
            ax2.plot(x, values, c=c2, marker="o", ls=styles[i % len(styles)], label=label, **plot_kwargs)
        ax2.legend(fontsize=labelsize)
    else:
        ax2.plot(x, y2, c=c2, marker="o", ls="-", **plot_kwargs)

    if y2label:
        # Make the y-axis label, ticks and tick labels match the line color.
        ax2.set_ylabel(y2label, color=c2, fontsize=labelsize)

    ax2.tick_params("y", colors=c2, labelsize=ticksize)
    return plt


def pretty_polyfit_plot(x, y, deg=1, xlabel=None, ylabel=None, **kwargs):
    """
    Convenience method to plot data with trend lines based on polynomial fit.

    Args:
        x: Sequence of x data.
        y: Sequence of y data.
        deg (int): Degree of polynomial. Defaults to 1.
        xlabel (str): Label for x-axis.
        ylabel (str): Label for y-axis.
        **kwargs: Keyword args passed to pretty_plot.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(**kwargs)
    pp = np.polyfit(x, y, deg)
    # 200 evenly spaced points give a smooth trend line over the data range.
    xp = np.linspace(min(x), max(x), 200)
    plt.plot(xp, np.polyval(pp, xp), "k--", x, y, "o")
    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    return plt


def _decide_fontcolor(rgba: tuple) -> Literal["black", "white"]:
    """
    Pick a text color that stays readable on the given background color.

    Args:
        rgba (tuple): Background color as (red, green, blue, alpha) floats in
            [0, 1]; alpha is ignored.

    Returns:
        "black" for light backgrounds (weighted brightness above 186 on a
        0-255 scale, using 0.299/0.587/0.114 channel weights), else "white".
    """
    red, green, blue, _ = rgba
    if (red * 0.299 + green * 0.587 + blue * 0.114) * 255 > 186:
        return "black"

    return "white"


def periodic_table_heatmap(
    elemental_data,
    cbar_label="",
    cbar_label_size=14,
    show_plot=False,
    cmap="YlOrRd",
    cmap_range=None,
    blank_color="grey",
    edge_color="white",
    value_format=None,
    value_fontsize=10,
    symbol_fontsize=14,
    max_row=9,
    readable_fontcolor=False,
):
    """
    Generates a heat map overlaid on a periodic table.

    Args:
        elemental_data (dict): A dictionary with the element as a key and a
            value assigned to it, e.g. surface energy and frequency, etc.
            Elements missing in the elemental_data will be grey by default
            in the final table elemental_data={"Fe": 4.2, "O": 5.0}.
        cbar_label (string): Label of the colorbar. Default is "".
        cbar_label_size (float): Font size for the colorbar label. Default is 14.
        cmap_range (tuple): Minimum and maximum value of the colormap scale.
            If None, the colormap will automatically scale to the range of the
            data.
        show_plot (bool): Whether to show the heatmap. Default is False.
        value_format (str): Formatting string to show values. If None, no value
            is shown. Example: "%.4f" shows float to four decimals.
        value_fontsize (float): Font size for values. Default is 10.
        symbol_fontsize (float): Font size for element symbols. Default is 14.
        cmap (string): Color scheme of the heatmap. Default is 'YlOrRd'.
            Refer to the matplotlib documentation for other options.
        blank_color (string): Color assigned for the missing elements in
            elemental_data. Default is "grey".
        edge_color (string): Color assigned for the edge of elements in the
            periodic table. Default is "white".
        max_row (integer): Maximum number of rows of the periodic table to be
            shown. Default is 9, which means the periodic table heat map covers
            the first 9 rows of elements. Values above 9 are clipped to 9.
        readable_fontcolor (bool): Whether to use readable fontcolor depending
            on background color. Default is False.

    Returns:
        matplotlib.pyplot

    Raises:
        ValueError: If max_row is not positive, or if elemental_data is empty
            and no cmap_range is given.
    """
    # Range of the color scale.
    if cmap_range is not None:
        min_val, max_val = cmap_range[0], cmap_range[1]
    else:
        if not elemental_data:
            raise ValueError("elemental_data is empty; pass cmap_range to set the color scale explicitly.")
        max_val = max(elemental_data.values())
        min_val = min(elemental_data.values())

    max_row = min(max_row, 9)

    if max_row <= 0:
        raise ValueError("The input argument 'max_row' must be positive!")

    # One cell per (row, group); NaN marks positions without an element.
    value_table = np.full((max_row, 18), np.nan)
    # Elements missing from elemental_data get a value just below the color
    # scale, so they are drawn with the colormap's "under" color (blank_color).
    blank_value = min_val - 0.01

    for el in Element:
        if el.row > max_row:
            continue
        value = elemental_data.get(el.symbol, blank_value)
        value_table[el.row - 1, el.group - 1] = value

    # Initialize the plt object
    import matplotlib.pyplot as plt

    # Use a private copy of the colormap: setting the "under" color on a
    # registered colormap would leak blank_color into every later plot that
    # uses the same colormap.
    cmap_obj = copy.copy(plt.get_cmap(cmap))
    cmap_obj.set_under(blank_color)

    fig, ax = plt.subplots()
    fig.set_size_inches(12, 8)

    # We set nan type values to masked values (ie blank spaces)
    data_mask = np.ma.masked_invalid(value_table)
    heatmap = ax.pcolor(
        data_mask,
        cmap=cmap_obj,
        edgecolors=edge_color,
        linewidths=1,
        vmin=min_val - 0.001,
        vmax=max_val + 0.001,
    )
    cbar = fig.colorbar(heatmap)

    # Set the colorbar label and tick marks
    cbar.set_label(cbar_label, rotation=270, labelpad=25, size=cbar_label_size)
    cbar.ax.tick_params(labelsize=cbar_label_size)

    # Refine and make the table look nice
    ax.axis("off")
    ax.invert_yaxis()

    # Maps a cell value to the RGBA it is drawn with, to choose a readable font color.
    norm = colors.Normalize(vmin=min_val, vmax=max_val)
    scalar_cmap = cm.ScalarMappable(norm=norm, cmap=cmap_obj)

    # Label each block with corresponding element and value
    for i, row in enumerate(value_table):
        for j, value in enumerate(row):
            if np.isnan(value):
                continue  # no element at this position of the table
            symbol = Element.from_row_and_group(i + 1, j + 1).symbol
            fontcolor = _decide_fontcolor(scalar_cmap.to_rgba(value)) if readable_fontcolor else "black"
            ax.text(
                j + 0.5,
                i + 0.25,
                symbol,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=symbol_fontsize,
                color=fontcolor,
            )
            if value != blank_value and value_format is not None:
                ax.text(
                    j + 0.5,
                    i + 0.5,
                    value_format % value,
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=value_fontsize,
                    color=fontcolor,
                )

    fig.tight_layout()

    if show_plot:
        plt.show()

    return plt


def format_formula(formula):
    """
    Converts str of chemical formula into latex format for labelling purposes.

    Every maximal run of digits becomes a LaTeX subscript, e.g.
    "Fe2O3" -> "$Fe_{2}O_{3}$".

    Args:
        formula (str): Chemical formula

    Returns:
        str: The formatted formula wrapped in "$...$" (math mode).
    """
    parts = []
    for is_number, chars in groupby(formula, key=str.isdigit):
        run = "".join(chars)
        parts.append("_{%s}" % run if is_number else run)
    return r"$%s$" % "".join(parts)


def van_arkel_triangle(list_of_materials, annotate=True):
    """
    Generates a binary van Arkel-Ketelaar triangle to
    quantify the ionic, metallic and covalent character of a compound
    by plotting the electronegativity difference (y) vs average (x).
    See:
        A.E. van Arkel, Molecules and Crystals in Inorganic Chemistry,
            Interscience, New York (1956)
    and
        J.A.A Ketelaar, Chemical Constitution (2nd edn.), An Introduction
            to the Theory of the Chemical Bond, Elsevier, New York (1958)

    Args:
        list_of_materials (list): A list of computed entries of binary
            materials (any object with a ``composition`` attribute, e.g.
            ComputedEntry) or a list of lists containing two elements (str).
        annotate (bool): Whether or not to label the points on the
            triangle with reduced formula (if list of entries) or pair
            of elements (if list of list of str).

    Returns:
        matplotlib.pyplot
    """

    # F-Fr has the largest X difference. We set this
    # as our top corner of the triangle (most ionic)
    pt1 = np.array([(Element("F").X + Element("Fr").X) / 2, abs(Element("F").X - Element("Fr").X)])
    # Cs-Fr has the lowest average X. We set this as our
    # bottom left corner of the triangle (most metallic)
    pt2 = np.array([(Element("Cs").X + Element("Fr").X) / 2, abs(Element("Cs").X - Element("Fr").X)])
    # O-F has the highest average X. We set this as our
    # bottom right corner of the triangle (most covalent)
    pt3 = np.array([(Element("O").X + Element("F").X) / 2, abs(Element("O").X - Element("F").X)])

    # Lines of the triangle: left edge y = slope1 * x + b1 (through pt2, pt1),
    # right edge y = slope2 * x + b2 (through pt1, pt3).
    d = pt1 - pt2
    slope1 = d[1] / d[0]
    b1 = pt1[1] - slope1 * pt1[0]
    d = pt3 - pt1
    slope2 = d[1] / d[0]
    b2 = pt3[1] - slope2 * pt3[0]
    # Where the right edge meets y = 0 (bottom right corner of the plotted triangle).
    x_right = -b2 / slope2

    # Initialize the plt object
    import matplotlib.pyplot as plt

    # set labels and appropriate limits for plot
    plt.xlim(pt2[0] - 0.45, x_right + 0.45)
    plt.ylim(-0.45, pt1[1] + 0.45)
    plt.annotate("Ionic", xy=[pt1[0] - 0.3, pt1[1] + 0.05], fontsize=20)
    plt.annotate("Covalent", xy=[x_right - 0.65, -0.4], fontsize=20)
    plt.annotate("Metallic", xy=[pt2[0] - 0.4, -0.4], fontsize=20)
    plt.xlabel(r"$\frac{\chi_{A}+\chi_{B}}{2}$", fontsize=25)
    plt.ylabel(r"$|\chi_{A}-\chi_{B}|$", fontsize=25)

    # Set the lines of the triangle
    chi_list = [el.X for el in Element]
    chi_min = min(chi_list)
    plt.plot([chi_min, pt1[0]], [slope1 * chi_min + b1, pt1[1]], "k-", linewidth=3)
    plt.plot([pt1[0], x_right], [pt1[1], 0], "k-", linewidth=3)
    plt.plot([chi_min, x_right], [0, 0], "k-", linewidth=3)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    # Shade with appropriate colors corresponding to ionic, metallic and covalent
    ax = plt.gca()
    # ionic filling: the two halves under the left and right edges
    ax.fill_between(
        [chi_min, pt1[0]],
        [slope1 * chi_min + b1, pt1[1]],
        facecolor=[1, 1, 0],
        zorder=-5,
        edgecolor=[1, 1, 0],
    )
    ax.fill_between(
        [pt1[0], x_right],
        [pt1[1], 0],
        facecolor=[1, 1, 0],
        zorder=-5,
        edgecolor=[1, 1, 0],
    )
    # metal filling: a triangle with its base from chi_min to X(Pt)
    XPt = Element("Pt").X
    x_metal = (XPt + chi_min) / 2
    ax.fill_between(
        [chi_min, x_metal],
        [0, slope1 * x_metal + b1],
        facecolor=[1, 0, 0],
        zorder=-3,
        alpha=0.8,
    )
    ax.fill_between(
        [x_metal, XPt],
        [slope1 * x_metal + b1, 0],
        facecolor=[1, 0, 0],
        zorder=-3,
        alpha=0.8,
    )
    # covalent filling: a triangle with its base from x_metal to x_right
    x_cov = (x_metal + x_right) / 2
    ax.fill_between(
        [x_metal, x_cov],
        [0, slope2 * x_cov + b2],
        facecolor=[0, 1, 0],
        zorder=-4,
        alpha=0.8,
    )
    ax.fill_between(
        [x_cov, x_right],
        [slope2 * x_cov + b2, 0],
        facecolor=[0, 1, 0],
        zorder=-4,
        alpha=0.8,
    )

    # Label the triangle with datapoints
    for entry in list_of_materials:
        if hasattr(entry, "composition"):
            X_pair = [Element(el).X for el in entry.composition.as_dict()]
            formatted_formula = format_formula(entry.composition.reduced_formula)
        else:
            X_pair = [Element(el).X for el in entry]
            formatted_formula = "%s-%s" % tuple(entry)
        x_mean = np.mean(X_pair)
        x_diff = abs(X_pair[0] - X_pair[1])
        plt.scatter(x_mean, x_diff, c="b", s=100)
        if annotate:
            plt.annotate(formatted_formula, fontsize=15, xy=[x_mean + 0.005, x_diff])

    plt.tight_layout()
    return plt


def get_ax_fig_plt(ax=None, **kwargs):
    """
    Helper function used in plot functions supporting an optional Axes argument.
    If ax is None, we build the `matplotlib` figure and create the Axes else
    we return the figure the given Axes belongs to.

    Args:
        ax: matplotlib Axes, or None to create a new figure with one Axes.
        kwargs: keyword arguments are passed to plt.figure if ax is None.

    Returns:
        ax: :class:`Axes` object
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig = plt.figure(**kwargs)
        ax = fig.add_subplot(1, 1, 1)
    else:
        # The owning figure, which need not be the currently active one.
        fig = ax.figure

    return ax, fig, plt


def get_ax3d_fig_plt(ax=None, **kwargs):
    """
    Helper function used in plot functions supporting an optional Axes3D
    argument. If ax is None, we build the `matplotlib` figure and create the
    Axes3D else we return the figure the given Axes belongs to.

    Args:
        ax: matplotlib Axes3D, or None to create a new figure with one Axes3D.
        kwargs: keyword arguments are passed to plt.figure if ax is None.

    Returns:
        ax: :class:`Axes` object
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import axes3d  # noqa: F401  pylint: disable=unused-import  (registers "3d")

    if ax is None:
        fig = plt.figure(**kwargs)
        # Axes3D(fig) no longer attaches itself to the figure in recent
        # matplotlib; add_subplot with the 3d projection always does.
        ax = fig.add_subplot(1, 1, 1, projection="3d")
    else:
        fig = ax.figure

    return ax, fig, plt


def get_axarray_fig_plt(
    ax_array, nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw
):
    """
    Helper function used in plot functions that accept an optional array of Axes
    as argument. If ax_array is None, we build the `matplotlib` figure and
    create the array of Axes by calling plt.subplots else we return the
    figure the given Axes belong to.

    Args:
        ax_array: Sequence (possibly nested) of Axes, or None to create them.
        nrows (int): Number of rows of the subplot grid. A given ax_array is
            reshaped to (nrows, ncols).
        ncols (int): Number of columns of the subplot grid.
        sharex, sharey, subplot_kw, gridspec_kw, fig_kw: Passed to plt.subplots.
        squeeze (bool): Passed to plt.subplots; also applied to a given
            ax_array (1x1 -> single Axes, single row/column -> 1D array).

    Returns:
        ax: Array of :class:`Axes` objects
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt

    if ax_array is None:
        fig, ax_array = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=sharex,
            sharey=sharey,
            squeeze=squeeze,
            subplot_kw=subplot_kw,
            gridspec_kw=gridspec_kw,
            **fig_kw,
        )
    else:
        ax_array = np.reshape(np.array(ax_array), (nrows, ncols))
        fig = ax_array.flat[0].figure
        if squeeze:
            if ax_array.size == 1:
                # Mirror plt.subplots: a 1x1 grid squeezes to the bare Axes.
                ax_array = ax_array.flat[0]
            elif any(s == 1 for s in ax_array.shape):
                ax_array = ax_array.ravel()

    return ax_array, fig, plt


def add_fig_kwargs(func):
    """
    Decorator that adds keyword arguments for functions returning matplotlib
    figures.

    The function should return either a matplotlib figure or None to signal
    some sort of error/unexpected event.
    The list of supported options is appended to the decorated function's
    docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # pop the kwds used by the decorator.
        title = kwargs.pop("title", None)
        size_kwargs = kwargs.pop("size_kwargs", None)
        show = kwargs.pop("show", True)
        savefig = kwargs.pop("savefig", None)
        tight_layout = kwargs.pop("tight_layout", False)
        ax_grid = kwargs.pop("ax_grid", None)
        ax_annotate = kwargs.pop("ax_annotate", None)
        fig_close = kwargs.pop("fig_close", False)

        # Call func and return immediately if None is returned.
        fig = func(*args, **kwargs)
        if fig is None:
            return fig

        # Operate on matplotlib figure.
        if title is not None:
            fig.suptitle(title)

        if size_kwargs is not None:
            # Work on a copy: popping from the caller's dict would make it
            # unusable for a second call (KeyError on "w"/"h").
            size_kwargs = dict(size_kwargs)
            fig.set_size_inches(size_kwargs.pop("w"), size_kwargs.pop("h"), **size_kwargs)

        if ax_grid is not None:
            for ax in fig.axes:
                ax.grid(bool(ax_grid))

        if ax_annotate:
            from string import ascii_letters

            tags = ascii_letters
            if len(fig.axes) > len(tags):
                # Repeat the alphabet enough times to give every axes a tag.
                tags = (1 + len(fig.axes) // len(ascii_letters)) * ascii_letters
            for ax, tag in zip(fig.axes, tags):
                ax.annotate("(%s)" % tag, xy=(0.05, 0.95), xycoords="axes fraction")

        if tight_layout:
            try:
                fig.tight_layout()
            except Exception as exc:
                # For some unknown reason, this problem shows up only on travis.
                # https://stackoverflow.com/questions/22708888/valueerror-when-using-matplotlib-tight-layout
                print("Ignoring Exception raised by fig.tight_layout\n", str(exc))

        if savefig:
            fig.savefig(savefig)

        import matplotlib.pyplot as plt

        if show:
            plt.show()
        if fig_close:
            plt.close(fig=fig)

        return fig

    # Add docstring to the decorated method.
    s = (
        "\n\n"
        + """\
Keyword arguments controlling the display of the figure:

================  ====================================================
kwargs            Meaning
================  ====================================================
title             Title of the plot (Default: None).
show              True to show the figure (default: True).
savefig           "abc.png" or "abc.eps" to save the figure to a file.
size_kwargs       Dictionary with options passed to fig.set_size_inches
                  e.g. size_kwargs=dict(w=3, h=4)
tight_layout      True to call fig.tight_layout (default: False)
ax_grid           True (False) to add (remove) grid from all axes in fig.
                  Default: None i.e. fig is left unchanged.
ax_annotate       Add labels to subplots e.g. (a), (b).
                  Default: False
fig_close         Close figure. Default: False.
================  ====================================================

"""
    )

    if wrapper.__doc__ is not None:
        # Add s at the end of the docstring.
        wrapper.__doc__ += "\n" + s
    else:
        # Use s
        wrapper.__doc__ = s

    return wrapper