1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4"""
5Utilities for generating nicer plots.
6"""
7import math
8import sys
9from matplotlib import colors, cm
10
11import numpy as np
12
13from pymatgen.core.periodic_table import Element
14
15if sys.version_info >= (3, 8):
16    from typing import Literal
17else:
18    from typing_extensions import Literal
19
20
def pretty_plot(width=8, height=None, plt=None, dpi=None, color_cycle=("qualitative", "Set1_9")):
    """
    Provides a publication quality plot, with nice defaults for font sizes etc.

    Args:
        width (float): Width of plot in inches. Defaults to 8in. Tick, label
            and title font sizes scale with it.
        height (float): Height of plot in inches. Defaults to width * golden
            ratio, truncated to an integer number of inches.
        plt (matplotlib.pyplot): If plt is supplied, changes will be made to an
            existing plot. Otherwise, a new plot will be created.
        dpi (int): Sets dot per inch for figure. Defaults to None, i.e. the
            matplotlib default for a new figure and the current resolution
            of a supplied plot.
        color_cycle (tuple): Set the color cycle for new plots to one of the
            color sets in palettable, given as (palette type, palette name).
            Only used when a new plot is created. Defaults to a qualitative
            Set1_9.

    Returns:
        Matplotlib plot object with properly sized fonts.
    """
    ticksize = int(width * 2.5)

    golden_ratio = (math.sqrt(5) - 1) / 2

    if not height:
        # Truncated to whole inches; kept for backward compatibility.
        height = int(width * golden_ratio)

    if plt is None:
        import importlib

        import matplotlib.pyplot as plt
        from cycler import cycler

        # e.g. ("qualitative", "Set1_9") -> palettable.colorbrewer.qualitative.Set1_9
        mod = importlib.import_module("palettable.colorbrewer.%s" % color_cycle[0])
        # Named `palette` so the module-level matplotlib `colors` is not shadowed.
        palette = getattr(mod, color_cycle[1]).mpl_colors

        plt.figure(figsize=(width, height), facecolor="w", dpi=dpi)
        ax = plt.gca()
        ax.set_prop_cycle(cycler("color", palette))
    else:
        fig = plt.gcf()
        fig.set_size_inches(width, height)
        if dpi:
            # Honor an explicit dpi for an existing plot as well.
            fig.set_dpi(dpi)

    plt.xticks(fontsize=ticksize)
    plt.yticks(fontsize=ticksize)

    ax = plt.gca()
    ax.set_title(ax.get_title(), size=width * 4)

    labelsize = int(width * 3)

    ax.set_xlabel(ax.get_xlabel(), size=labelsize)
    ax.set_ylabel(ax.get_ylabel(), size=labelsize)

    return plt
72
73
def pretty_plot_two_axis(
    x, y1, y2, xlabel=None, y1label=None, y2label=None, width=8, height=None, dpi=300, **plot_kwargs
):
    """
    Variant of pretty_plot that does a dual axis plot. Adapted from matplotlib
    examples. Makes it easier to create plots with different axes.

    Args:
        x (np.ndarray/list): Data for x-axis.
        y1 (dict/np.ndarray/list): Data for y1 axis (left). If a dict, it will
            be interpreted as a {label: sequence}.
        y2 (dict/np.ndarray/list): Data for y2 axis (right). If a dict, it will
            be interpreted as a {label: sequence}.
        xlabel (str): If not None, this will be the label for the x-axis.
        y1label (str): If not None, this will be the label for the y1-axis.
        y2label (str): If not None, this will be the label for the y2-axis.
        width (float): Width of plot in inches. Defaults to 8in. Label and
            tick font sizes scale with it.
        height (float): Height of plot in inches. Defaults to width * golden
            ratio, truncated to an integer number of inches.
        dpi (int): Sets dot per inch for figure. Defaults to 300. A falsy value
            keeps the matplotlib default.
        plot_kwargs: Passthrough kwargs to matplotlib's plot method, applied to
            both the y1 and y2 series. E.g., linewidth, etc.

    Returns:
        matplotlib.pyplot
    """
    # pylint: disable=E1101
    import matplotlib.pyplot as plt
    import palettable.colorbrewer.diverging

    # Opposite ends of a diverging palette, so each axis gets a distinct color.
    palette = palettable.colorbrewer.diverging.RdYlBu_4.mpl_colors
    c1 = palette[0]
    c2 = palette[-1]

    golden_ratio = (math.sqrt(5) - 1) / 2

    if not height:
        height = int(width * golden_ratio)

    # Font sizes follow the requested width. A hard-coded width of 12 used to
    # override the argument here, so `width` had no effect.
    labelsize = int(width * 3)
    ticksize = int(width * 2.5)
    # Line styles cycled over the series of a dict input. They must all be
    # valid matplotlib line styles: "." is a marker and is rejected as `ls`.
    styles = ["-", "--", "-.", ":"]

    fig, ax1 = plt.subplots()
    fig.set_size_inches((width, height))
    if dpi:
        fig.set_dpi(dpi)
    if isinstance(y1, dict):
        for i, (k, v) in enumerate(y1.items()):
            ax1.plot(x, v, c=c1, marker="s", ls=styles[i % len(styles)], label=k, **plot_kwargs)
        ax1.legend(fontsize=labelsize)
    else:
        ax1.plot(x, y1, c=c1, marker="s", ls="-", **plot_kwargs)

    if xlabel:
        ax1.set_xlabel(xlabel, fontsize=labelsize)

    if y1label:
        # Make the y-axis label, ticks and tick labels match the line color.
        ax1.set_ylabel(y1label, color=c1, fontsize=labelsize)

    ax1.tick_params("x", labelsize=ticksize)
    ax1.tick_params("y", colors=c1, labelsize=ticksize)

    # Right-hand y-axis sharing the same x-axis.
    ax2 = ax1.twinx()
    if isinstance(y2, dict):
        for i, (k, v) in enumerate(y2.items()):
            ax2.plot(x, v, c=c2, marker="o", ls=styles[i % len(styles)], label=k, **plot_kwargs)
        ax2.legend(fontsize=labelsize)
    else:
        ax2.plot(x, y2, c=c2, marker="o", ls="-", **plot_kwargs)

    if y2label:
        # Make the y-axis label, ticks and tick labels match the line color.
        ax2.set_ylabel(y2label, color=c2, fontsize=labelsize)

    ax2.tick_params("y", colors=c2, labelsize=ticksize)
    return plt
154
155
def pretty_polyfit_plot(x, y, deg=1, xlabel=None, ylabel=None, **kwargs):
    """
    Plot data points together with a polynomial trend line fitted to them.

    Args:
        x: Sequence of x data.
        y: Sequence of y data.
        deg (int): Degree of the fitted polynomial. Defaults to 1 (linear).
        xlabel (str): Label for x-axis. Not set when None or empty.
        ylabel (str): Label for y-axis. Not set when None or empty.
        **kwargs: Keyword args passed to pretty_plot.

    Returns:
        matplotlib.pyplot object.
    """
    plt = pretty_plot(**kwargs)
    coeffs = np.polyfit(x, y, deg)
    # Evaluate the fit on 200 evenly spaced points spanning the data range.
    x_fit = np.linspace(min(x), max(x), 200)
    y_fit = np.polyval(coeffs, x_fit)
    # Dashed black trend line first, then the raw data as circles.
    plt.plot(x_fit, y_fit, "k--", x, y, "o")
    for text, set_label in ((xlabel, plt.xlabel), (ylabel, plt.ylabel)):
        if text:
            set_label(text)
    return plt
180
181
182def _decide_fontcolor(rgba: tuple) -> Literal["black", "white"]:
183    red, green, blue, _ = rgba
184    if (red * 0.299 + green * 0.587 + blue * 0.114) * 255 > 186:
185        return "black"
186
187    return "white"
188
189
def periodic_table_heatmap(
    elemental_data,
    cbar_label="",
    cbar_label_size=14,
    show_plot=False,
    cmap="YlOrRd",
    cmap_range=None,
    blank_color="grey",
    edge_color="white",
    value_format=None,
    value_fontsize=10,
    symbol_fontsize=14,
    max_row=9,
    readable_fontcolor=False,
):
    """
    A static method that generates a heat map overlayed on a periodic table.

    Args:
        elemental_data (dict): A dictionary with the element as a key and a
            value assigned to it, e.g. surface energy and frequency, etc.
            Elements missing in the elemental_data will be grey by default
            in the final table elemental_data={"Fe": 4.2, "O": 5.0}.
        cbar_label (string): Label of the colorbar. Default is "".
        cbar_label_size (float): Font size for the colorbar label. Default is 14.
        cmap_range (tuple): Minimum and maximum value of the colormap scale.
            If None, the colormap will automatically scale to the range of the
            data.
        show_plot (bool): Whether to show the heatmap. Default is False.
        value_format (str): Formatting string to show values. If None, no value
            is shown. Example: "%.4f" shows float to four decimals.
        value_fontsize (float): Font size for values. Default is 10.
        symbol_fontsize (float): Font size for element symbols. Default is 14.
        cmap (string): Color scheme of the heatmap. Default is 'YlOrRd'.
            Refer to the matplotlib documentation for other options. The
            colormap is copied, so the registered colormap (or a Colormap
            instance passed in) is never modified.
        blank_color (string): Color assigned for the missing elements in
            elemental_data. Default is "grey".
        edge_color (string): Color assigned for the edge of elements in the
            periodic table. Default is "white".
        max_row (integer): Maximum number of rows of the periodic table to be
            shown. Default is 9, which means the periodic table heat map covers
            the first 9 rows of elements. Values above 9 are clamped to 9.
        readable_fontcolor (bool): Whether to use readable fontcolor depending
            on background color. Default is False.

    Returns:
        matplotlib.pyplot module with the heat map drawn on a new figure.

    Raises:
        ValueError: If max_row is not positive.
    """
    import copy

    # Colormap limits: the explicit range if given, otherwise the data range.
    if cmap_range is not None:
        max_val = cmap_range[1]
        min_val = cmap_range[0]
    else:
        max_val = max(elemental_data.values())
        min_val = min(elemental_data.values())

    max_row = min(max_row, 9)

    if max_row <= 0:
        raise ValueError("The input argument 'max_row' must be positive!")

    # NaN marks grid cells without an element; they are masked (left blank).
    value_table = np.full((max_row, 18), np.nan)
    # Sentinel below the colormap's lower limit, so elements missing from
    # elemental_data are drawn with the colormap's "under" color.
    blank_value = min_val - 0.01

    for el in Element:
        if el.row > max_row:
            continue
        value = elemental_data.get(el.symbol, blank_value)
        value_table[el.row - 1, el.group - 1] = value

    # Initialize the plt object
    import matplotlib.pyplot as plt

    # Work on a private copy of the colormap. Calling set_under on the shared
    # registered instance would leak blank_color into later plots using the
    # same colormap. Using this copy for both the heatmap and the font-color
    # lookup below also makes missing elements resolve to blank_color there.
    cmap_obj = copy.copy(plt.get_cmap(cmap))
    cmap_obj.set_under(blank_color)

    fig, ax = plt.subplots()
    fig.set_size_inches(12, 8)

    # We set nan type values to masked values (ie blank spaces)
    data_mask = np.ma.masked_invalid(value_table)
    heatmap = ax.pcolor(
        data_mask,
        cmap=cmap_obj,
        edgecolors=edge_color,
        linewidths=1,
        vmin=min_val - 0.001,
        vmax=max_val + 0.001,
    )
    cbar = fig.colorbar(heatmap)

    # Set the colorbar label and tick marks
    cbar.set_label(cbar_label, rotation=270, labelpad=25, size=cbar_label_size)
    cbar.ax.tick_params(labelsize=cbar_label_size)

    # Refine and make the table look nice
    ax.axis("off")
    ax.invert_yaxis()

    # Scalar mappable used to look up each cell's background color for the
    # readable-font-color decision.
    norm = colors.Normalize(vmin=min_val, vmax=max_val)
    scalar_cmap = cm.ScalarMappable(norm=norm, cmap=cmap_obj)

    # Label each block with corresponding element and value
    for i, row in enumerate(value_table):
        for j, el in enumerate(row):
            if np.isnan(el):
                continue
            symbol = Element.from_row_and_group(i + 1, j + 1).symbol
            rgba = scalar_cmap.to_rgba(el)
            fontcolor = _decide_fontcolor(rgba) if readable_fontcolor else "black"
            ax.text(
                j + 0.5,
                i + 0.25,
                symbol,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=symbol_fontsize,
                color=fontcolor,
            )
            # Missing elements carry the sentinel value; show no number for them.
            if el != blank_value and value_format is not None:
                ax.text(
                    j + 0.5,
                    i + 0.5,
                    value_format % el,
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=value_fontsize,
                    color=fontcolor,
                )

    plt.tight_layout()

    if show_plot:
        plt.show()

    return plt
324
325
def format_formula(formula):
    """
    Converts str of chemical formula into
    latex format for labelling purposes.

    Every amount becomes a LaTeX subscript. This includes decimal amounts,
    e.g. "Fe2O3" -> "$Fe_{2}O_{3}$" and "Li0.5CoO2" -> "$Li_{0.5}CoO_{2}$".

    Args:
        formula (str): Chemical formula

    Returns:
        str: The formula in math mode (wrapped in "$...$") with subscripts.
    """
    import re

    # An amount is a run of digits, optionally followed by ".digits". Treating
    # the decimal point as part of the number keeps fractional amounts in one
    # subscript instead of splitting them (e.g. "Li_{0}._{5}").
    formatted_formula = re.sub(r"(\d+(?:\.\d+)?)", r"_{\1}", formula)
    return r"$%s$" % formatted_formula
353
354
def van_arkel_triangle(list_of_materials, annotate=True):
    """
    A static method that generates a binary van Arkel-Ketelaar triangle to
        quantify the ionic, metallic and covalent character of a compound
        by plotting the electronegativity difference (y) vs average (x).
        See:
            A.E. van Arkel, Molecules and Crystals in Inorganic Chemistry,
                Interscience, New York (1956)
        and
            J.A.A Ketelaar, Chemical Constitution (2nd edn.), An Introduction
                to the Theory of the Chemical Bond, Elsevier, New York (1958)

    Args:
        list_of_materials (list): A list of entries of binary materials (any
            object with a ``composition`` attribute, e.g. ComputedEntry,
            ComputedStructureEntry or their subclasses) or a list of lists
            containing two elements (str).
        annotate (bool): Whether or not to label the points on the
            triangle with reduced formula (if list of entries) or pair
            of elements (if list of list of str).

    Returns:
        matplotlib.pyplot module; the triangle is drawn on the current axes.
    """
    # F-Fr has the largest X difference. We set this
    # as our top corner of the triangle (most ionic)
    pt1 = np.array([(Element("F").X + Element("Fr").X) / 2, abs(Element("F").X - Element("Fr").X)])
    # Cs-Fr has the lowest average X. We set this as our
    # bottom left corner of the triangle (most metallic)
    pt2 = np.array(
        [
            (Element("Cs").X + Element("Fr").X) / 2,
            abs(Element("Cs").X - Element("Fr").X),
        ]
    )
    # O-F has the highest average X. We set this as our
    # bottom right corner of the triangle (most covalent)
    pt3 = np.array([(Element("O").X + Element("F").X) / 2, abs(Element("O").X - Element("F").X)])

    # Lines y = slope * x + b for the two upper edges of the triangle:
    # metallic-ionic edge (pt2 -> pt1) and ionic-covalent edge (pt1 -> pt3).
    d = pt1 - pt2
    slope1 = d[1] / d[0]
    b1 = pt1[1] - slope1 * pt1[0]
    d = pt3 - pt1
    slope2 = d[1] / d[0]
    b2 = pt3[1] - slope2 * pt3[0]
    # Where the ionic-covalent edge meets y = 0 (right end of the base).
    x_right = -b2 / slope2

    # Initialize the plt object
    import matplotlib.pyplot as plt

    # set labels and appropriate limits for plot
    plt.xlim(pt2[0] - 0.45, x_right + 0.45)
    plt.ylim(-0.45, pt1[1] + 0.45)
    plt.annotate("Ionic", xy=[pt1[0] - 0.3, pt1[1] + 0.05], fontsize=20)
    plt.annotate("Covalent", xy=[x_right - 0.65, -0.4], fontsize=20)
    plt.annotate("Metallic", xy=[pt2[0] - 0.4, -0.4], fontsize=20)
    plt.xlabel(r"$\frac{\chi_{A}+\chi_{B}}{2}$", fontsize=25)
    plt.ylabel(r"$|\chi_{A}-\chi_{B}|$", fontsize=25)

    # Smallest electronegativity over all elements: left end of the base.
    chi_min = min(el.X for el in Element)

    # Set the lines of the triangle
    plt.plot(
        [chi_min, pt1[0]],
        [slope1 * chi_min + b1, pt1[1]],
        "k-",
        linewidth=3,
    )
    plt.plot([pt1[0], x_right], [pt1[1], 0], "k-", linewidth=3)
    plt.plot([chi_min, x_right], [0, 0], "k-", linewidth=3)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    # Shade with appropriate colors corresponding to ionic, metallic and covalent
    ax = plt.gca()
    # ionic filling
    ax.fill_between(
        [chi_min, pt1[0]],
        [slope1 * chi_min + b1, pt1[1]],
        facecolor=[1, 1, 0],
        zorder=-5,
        edgecolor=[1, 1, 0],
    )
    # NOTE(review): the second y value evaluates to ~0 only because slope2 is
    # close to -slope1 for these reference points; 0 (the base) looks like the
    # intent. The expression is kept unchanged.
    ax.fill_between(
        [pt1[0], x_right],
        [pt1[1], slope2 * chi_min - b1],
        facecolor=[1, 1, 0],
        zorder=-5,
        edgecolor=[1, 1, 0],
    )
    # metal filling: from the left end of the base up to Pt's electronegativity
    XPt = Element("Pt").X
    x_mid = (XPt + chi_min) / 2
    ax.fill_between(
        [chi_min, x_mid],
        [0, slope1 * x_mid + b1],
        facecolor=[1, 0, 0],
        zorder=-3,
        alpha=0.8,
    )
    ax.fill_between(
        [x_mid, XPt],
        [slope1 * x_mid + b1, 0],
        facecolor=[1, 0, 0],
        zorder=-3,
        alpha=0.8,
    )
    # covalent filling: peak halfway between x_mid and the base's right end
    x_cov = (x_mid + x_right) / 2
    ax.fill_between(
        [x_mid, x_cov],
        [0, slope2 * x_cov + b2],
        facecolor=[0, 1, 0],
        zorder=-4,
        alpha=0.8,
    )
    ax.fill_between(
        [x_cov, x_right],
        [slope2 * x_cov + b2, 0],
        facecolor=[0, 1, 0],
        zorder=-4,
        alpha=0.8,
    )

    # Label the triangle with datapoints
    for entry in list_of_materials:
        # Duck-type entries by their composition so subclasses of
        # ComputedEntry/ComputedStructureEntry are accepted too.
        if hasattr(entry, "composition"):
            X_pair = [Element(el).X for el in entry.composition.as_dict().keys()]
            formatted_formula = format_formula(entry.composition.reduced_formula)
        else:
            X_pair = [Element(el).X for el in entry]
            formatted_formula = "%s-%s" % tuple(entry)
        x_avg = np.mean(X_pair)
        x_diff = abs(X_pair[0] - X_pair[1])
        plt.scatter(x_avg, x_diff, c="b", s=100)
        if annotate:
            plt.annotate(
                formatted_formula,
                fontsize=15,
                xy=[x_avg + 0.005, x_diff],
            )

    plt.tight_layout()
    return plt
490
491
def get_ax_fig_plt(ax=None, **kwargs):
    """
    Helper function used in plot functions supporting an optional Axes argument.
    If ax is None, we build the `matplotlib` figure and create the Axes else
    we return the figure that ax belongs to.

    Args:
        ax (matplotlib.axes.Axes): Optional Axes to draw on. If None, a new
            figure with a single subplot is created.
        kwargs: keyword arguments are passed to plt.figure if ax is None.

    Returns:
        ax: :class:`Axes` object
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig = plt.figure(**kwargs)
        ax = fig.add_subplot(1, 1, 1)
    else:
        # Use the Axes' own figure: plt.gcf() returns whichever figure is
        # currently active, which need not be the one containing ax.
        fig = ax.figure

    return ax, fig, plt
515
516
def get_ax3d_fig_plt(ax=None, **kwargs):
    """
    Helper function used in plot functions supporting an optional Axes3D
    argument. If ax is None, we build the `matplotlib` figure and create the
    Axes3D else we return the figure that ax belongs to.

    Args:
        ax (Axes3D): Optional 3D Axes to draw on. If None, a new figure with a
            single 3D subplot is created.
        kwargs: keyword arguments are passed to plt.figure if ax is None.

    Returns:
        ax: :class:`Axes` object
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt

    # Imported for its side effect of registering the "3d" projection on
    # older matplotlib versions.
    from mpl_toolkits.mplot3d import axes3d  # noqa: F401 pylint: disable=unused-import

    if ax is None:
        fig = plt.figure(**kwargs)
        # Axes3D(fig) stopped attaching itself to the figure in newer
        # matplotlib (deprecated since 3.4), which left the figure empty.
        ax = fig.add_subplot(1, 1, 1, projection="3d")
    else:
        # Use the Axes' own figure rather than whichever one is active.
        fig = ax.figure

    return ax, fig, plt
541
542
def get_axarray_fig_plt(
    ax_array, nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw
):
    """
    Helper function used in plot functions that accept an optional array of Axes
    as argument. If ax_array is None, we build the `matplotlib` figure and
    create the array of Axes by calling plt.subplots else we return the
    figure that owns the given Axes.

    Args:
        ax_array: Axes (or array/nested list of Axes) to reuse, or None to
            create new ones. A given ax_array is reshaped to (nrows, ncols).
        nrows (int): Number of rows of the subplot grid.
        ncols (int): Number of columns of the subplot grid.
        sharex, sharey, subplot_kw, gridspec_kw, fig_kw: Passed to
            plt.subplots when ax_array is None.
        squeeze (bool): As in plt.subplots: a single Axes is returned as a
            scalar and a single row/column as a 1D array. Also applied to a
            user-supplied ax_array.

    Returns:
        ax: Array of :class:`Axes` objects
        figure: matplotlib figure
        plt: matplotlib pyplot module.
    """
    import matplotlib.pyplot as plt

    if ax_array is None:
        fig, ax_array = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=sharex,
            sharey=sharey,
            squeeze=squeeze,
            subplot_kw=subplot_kw,
            gridspec_kw=gridspec_kw,
            **fig_kw,
        )
    else:
        ax_array = np.reshape(np.array(ax_array), (nrows, ncols))
        # Take the figure from the axes themselves, not whichever is active.
        fig = ax_array.flat[0].figure if ax_array.size else plt.gcf()
        if squeeze:
            if ax_array.size == 1:
                # Mimic plt.subplots: return the Axes itself. ax_array[0]
                # would be a length-1 array here.
                ax_array = ax_array[0, 0]
            elif any(s == 1 for s in ax_array.shape):
                ax_array = ax_array.ravel()

    return ax_array, fig, plt
580
581
def add_fig_kwargs(func):
    """
    Decorator that adds keyword arguments for functions returning matplotlib
    figures.

    The function should return either a matplotlib figure or None to signal
    some sort of error/unexpected event.
    See doc string below for the list of supported options.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # pop the kwds used by the decorator.
        title = kwargs.pop("title", None)
        size_kwargs = kwargs.pop("size_kwargs", None)
        show = kwargs.pop("show", True)
        savefig = kwargs.pop("savefig", None)
        tight_layout = kwargs.pop("tight_layout", False)
        ax_grid = kwargs.pop("ax_grid", None)
        ax_annotate = kwargs.pop("ax_annotate", None)
        fig_close = kwargs.pop("fig_close", False)

        # Call func and return immediately if None is returned.
        fig = func(*args, **kwargs)
        if fig is None:
            return fig

        # Operate on matplotlib figure.
        if title is not None:
            fig.suptitle(title)

        if size_kwargs is not None:
            # Copy so the caller's dict still has "w"/"h" if it is reused.
            size_kwargs = dict(size_kwargs)
            fig.set_size_inches(size_kwargs.pop("w"), size_kwargs.pop("h"), **size_kwargs)

        if ax_grid is not None:
            for ax in fig.axes:
                ax.grid(bool(ax_grid))

        if ax_annotate:
            from string import ascii_letters

            # Repeat the alphabet enough times that every axis gets a tag.
            tags = ascii_letters * (1 + len(fig.axes) // len(ascii_letters))
            for ax, tag in zip(fig.axes, tags):
                ax.annotate("(%s)" % tag, xy=(0.05, 0.95), xycoords="axes fraction")

        if tight_layout:
            try:
                fig.tight_layout()
            except Exception as exc:
                # For some unknown reason, this problem shows up only on travis.
                # https://stackoverflow.com/questions/22708888/valueerror-when-using-matplotlib-tight-layout
                print("Ignoring Exception raised by fig.tight_layout\n", str(exc))

        if savefig:
            fig.savefig(savefig)

        if show or fig_close:
            # pyplot is only needed to show or close the figure.
            import matplotlib.pyplot as plt

            if show:
                plt.show()
            if fig_close:
                plt.close(fig=fig)

        return fig

    # Add docstring to the decorated method.
    s = (
        "\n\n"
        + """\
        Keyword arguments controlling the display of the figure:

        ================  ====================================================
        kwargs            Meaning
        ================  ====================================================
        title             Title of the plot (Default: None).
        show              True to show the figure (default: True).
        savefig           "abc.png" or "abc.eps" to save the figure to a file.
        size_kwargs       Dictionary with options passed to fig.set_size_inches
                          e.g. size_kwargs=dict(w=3, h=4)
        tight_layout      True to call fig.tight_layout (default: False)
        ax_grid           True (False) to add (remove) grid from all axes in fig.
                          Default: None i.e. fig is left unchanged.
        ax_annotate       Add labels to  subplots e.g. (a), (b).
                          Default: False
        fig_close         Close figure. Default: False.
        ================  ====================================================

"""
    )

    if wrapper.__doc__ is not None:
        # Add s at the end of the docstring.
        wrapper.__doc__ += "\n" + s
    else:
        # Use s
        wrapper.__doc__ = s

    return wrapper
683