1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4"""
This module implements plotters for densities of states (DOS) and band structures.
6"""
7
8import copy
9import itertools
10import logging
11import math
12import warnings
13from collections import Counter, OrderedDict
14
15import matplotlib.lines as mlines
16import numpy as np
17import scipy.interpolate as scint
18from monty.dev import requires
19from monty.json import jsanitize
20
21try:
22    from mayavi import mlab
23except ImportError:
24    mlab = None
25
26from pymatgen.core.periodic_table import Element
27from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
28from pymatgen.electronic_structure.boltztrap import BoltztrapError
29from pymatgen.electronic_structure.core import OrbitalType, Spin
30from pymatgen.util.plotting import add_fig_kwargs, get_ax3d_fig_plt, pretty_plot
31
32__author__ = "Shyue Ping Ong, Geoffroy Hautier, Anubhav Jain"
33__copyright__ = "Copyright 2012, The Materials Project"
34__version__ = "0.1"
35__maintainer__ = "Shyue Ping Ong"
36__email__ = "shyuep@gmail.com"
37__date__ = "May 1, 2012"
38
39logger = logging.getLogger(__name__)
40
41
class DosPlotter:
    """
    Class for plotting DOSs. Note that the interface is extremely flexible
    given that there are many different ways in which people want to view
    DOS. The typical usage is::

        # Initializes plotter with some optional args. Defaults are usually
        # fine,
        plotter = DosPlotter()

        # Adds a DOS with a label.
        plotter.add_dos("Total DOS", dos)

        # Alternatively, you can add a dict of DOSs. This is the typical
        # form returned by CompleteDos.get_spd/element/others_dos().
        plotter.add_dos_dict({"dos1": dos1, "dos2": dos2})
        plotter.add_dos_dict(complete_dos.get_spd_dos())
    """

    def __init__(self, zero_at_efermi=True, stack=False, sigma=None):
        """
        Args:
            zero_at_efermi: Whether to shift all Dos to have zero energy at the
                fermi energy. Defaults to True.
            stack: Whether to plot the DOS as a stacked area graph.
            sigma: A float specifying a standard deviation for Gaussian smearing
                the DOS for nicer looking plots. Defaults to None for no
                smearing.
        """
        self.zero_at_efermi = zero_at_efermi
        self.stack = stack
        self.sigma = sigma
        # label -> {"energies", "densities", "efermi"}, kept in insertion order
        self._doses = OrderedDict()

    def add_dos(self, label, dos):
        """
        Adds a dos for plotting.

        Energies are shifted by the Fermi energy when ``zero_at_efermi`` is
        set, and densities are smeared when ``sigma`` is set.

        Args:
            label:
                label for the DOS. Must be unique.
            dos:
                Dos object
        """
        energies = dos.energies - dos.efermi if self.zero_at_efermi else dos.energies
        densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
        efermi = dos.efermi
        self._doses[label] = {
            "energies": energies,
            "densities": densities,
            "efermi": efermi,
        }

    def add_dos_dict(self, dos_dict, key_sort_func=None):
        """
        Add a dictionary of doses, with an optional sorting function for the
        keys.

        Args:
            dos_dict: dict of {label: Dos}
            key_sort_func: function used to sort the dos_dict keys. When None,
                the dict's own iteration order is used.
        """
        if key_sort_func:
            keys = sorted(dos_dict.keys(), key=key_sort_func)
        else:
            keys = dos_dict.keys()
        for label in keys:
            self.add_dos(label, dos_dict[label])

    def get_dos_dict(self):
        """
        Returns the added doses as a json-serializable dict. Note that if you
        have specified smearing for the DOS plot, the densities returned will
        be the smeared densities, not the original densities.

        Returns:
            dict: Dict of dos data. Generally of the form
            {label: {'energies':..., 'densities': {'up':...}, 'efermi':efermi}}
        """
        return jsanitize(self._doses)

    def get_plot(self, xlim=None, ylim=None):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. Set to None to scale the y-axis
                to the densities that fall inside the x-axis range.

        Returns:
            The matplotlib pyplot module with the DOS drawn on it.
        """

        ncolors = max(3, len(self._doses))
        ncolors = min(9, ncolors)

        import palettable

        # pylint: disable=E1101
        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        y = None
        alldensities = []
        allenergies = []
        plt = pretty_plot(12, 8)

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        for dos in self._doses.values():
            energies = dos["energies"]
            densities = dos["densities"]
            if y is None:
                # Running per-spin totals used for stacking, sized from the
                # first DOS added.
                y = {
                    Spin.up: np.zeros(energies.shape),
                    Spin.down: np.zeros(energies.shape),
                }
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        y[spin] += densities[spin]
                        newdens[spin] = y[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for i, key in enumerate(keys):
            x = []
            y = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    # int(spin) sets the sign of the curve; assumes Spin.down
                    # converts to a negative int so it is drawn below zero.
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        # Traverse spin-down backwards so up + down form one
                        # continuous outline (presumably for plt.fill).
                        energies.reverse()
                        densities.reverse()
                    x.extend(energies)
                    y.extend(densities)
            allpts.extend(list(zip(x, y)))
            if self.stack:
                plt.fill(x, y, color=colors[i % ncolors], label=str(key))
            else:
                plt.plot(x, y, color=colors[i % ncolors], label=str(key), linewidth=3)
            if not self.zero_at_efermi:
                # Use a separate local: overwriting the ``ylim`` argument here
                # would discard the caller's limits and disable auto-scaling.
                cur_ylim = plt.ylim()
                plt.plot(
                    [self._doses[key]["efermi"], self._doses[key]["efermi"]],
                    cur_ylim,
                    color=colors[i % ncolors],
                    linestyle="--",
                    linewidth=2,
                )

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            # Nothing visible in the x-range (or nothing added): keep the
            # matplotlib default instead of crashing on min()/max() of [].
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        if self.zero_at_efermi:
            ylim = plt.ylim()
            plt.plot([0, 0], ylim, "k--", linewidth=2)

        plt.xlabel("Energies (eV)")
        plt.ylabel("Density of states")

        plt.axhline(y=0, color="k", linestyle="--", linewidth=2)
        plt.legend()
        leg = plt.gca().get_legend()
        if leg is not None:
            ltext = leg.get_texts()  # all the text.Text instance in the legend
            plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot(xlim, ylim)
        plt.savefig(filename, format=img_format)

    def show(self, xlim=None, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot(xlim, ylim)
        plt.show()
249
250
251class BSPlotter:
252    """
253    Class to plot or get data to facilitate the plot of band structure objects.
254    """
255
256    def __init__(self, bs):
257        """
258        Args:
259            bs: A BandStructureSymmLine object.
260        """
261
262        self._bs = []
263        self._nb_bands = []
264
265        self.add_bs(bs)
266
267    def _check_bs_kpath(self, bs_list):
268        """
269        Helper method that chack the all the band objs in bs_list are
270        BandStructureSymmLine objs and they all have the same kpath.
271        """
272
273        # check obj type
274        for bs in bs_list:
275            if not isinstance(bs, BandStructureSymmLine):
276                raise ValueError(
277                    "BSPlotter only works with BandStructureSymmLine objects. "
278                    "A BandStructure object (on a uniform grid for instance and "
279                    "not along symmetry lines won't work)"
280                )
281
282        # check the kpath
283        if len(bs_list) == 1 and self._bs == []:
284            return True
285
286        if self._bs == []:
287            kpath_ref = [br["name"] for br in bs_list[0].branches]
288        else:
289            kpath_ref = [br["name"] for br in self._bs[0].branches]
290
291        for bs in bs_list:
292            if kpath_ref != [br["name"] for br in bs.branches]:
293                msg = (
294                    f"BSPlotter only works with BandStructureSymmLine "
295                    f"which have the same kpath. \n{bs} has a different kpath!"
296                )
297                raise ValueError(msg)
298
299        return True
300
301    def add_bs(self, bs):
302        """
303        Method to add bands objects to the BSPlotter
304        """
305        if not isinstance(bs, list):
306            bs = [bs]
307
308        if self._check_bs_kpath(bs):
309            self._bs.extend(bs)
310            # TODO: come with an intelligent way to cut the highest unconverged
311            # bands
312            self._nb_bands.extend([b.nb_bands for b in bs])
313
314    def _maketicks(self, plt):
315        """
316        utility private method to add ticks to a band structure
317        """
318        ticks = self.get_ticks()
319        # Sanitize only plot the uniq values
320        uniq_d = []
321        uniq_l = []
322        temp_ticks = list(zip(ticks["distance"], ticks["label"]))
323        for i, t in enumerate(temp_ticks):
324            if i == 0:
325                uniq_d.append(t[0])
326                uniq_l.append(t[1])
327                logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1]))
328            else:
329                if t[1] == temp_ticks[i - 1][1]:
330                    logger.debug("Skipping label {i}".format(i=t[1]))
331                else:
332                    logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1]))
333                    uniq_d.append(t[0])
334                    uniq_l.append(t[1])
335
336        logger.debug("Unique labels are %s" % list(zip(uniq_d, uniq_l)))
337        plt.gca().set_xticks(uniq_d)
338        plt.gca().set_xticklabels(uniq_l)
339
340        for i in range(len(ticks["label"])):
341            if ticks["label"][i] is not None:
342                # don't print the same label twice
343                if i != 0:
344                    if ticks["label"][i] == ticks["label"][i - 1]:
345                        logger.debug("already print label... " "skipping label {i}".format(i=ticks["label"][i]))
346                    else:
347                        logger.debug(
348                            "Adding a line at {d}" " for label {l}".format(d=ticks["distance"][i], l=ticks["label"][i])
349                        )
350                        plt.axvline(ticks["distance"][i], color="k")
351                else:
352                    logger.debug(
353                        "Adding a line at {d} for label {l}".format(d=ticks["distance"][i], l=ticks["label"][i])
354                    )
355                    plt.axvline(ticks["distance"][i], color="k")
356        return plt
357
358    @staticmethod
359    def _get_branch_steps(branches):
360        """
361        Method to find discontinuous branches
362        """
363        steps = [0]
364        for b1, b2 in zip(branches[:-1], branches[1:]):
365            if b2["name"].split("-")[0] != b1["name"].split("-")[-1]:
366                steps.append(b2["start_index"])
367        steps.append(branches[-1]["end_index"] + 1)
368        return steps
369
370    @staticmethod
371    def _rescale_distances(bs_ref, bs):
372        """
373        Method to rescale distances of bs to distances in bs_ref.
374        This is used for plotting two bandstructures (same k-path)
375        of different materials.
376        """
377        scaled_distances = []
378
379        for br, br2 in zip(bs_ref.branches, bs.branches):
380            s = br["start_index"]
381            e = br["end_index"]
382            max_d = bs_ref.distance[e]
383            min_d = bs_ref.distance[s]
384            s2 = br2["start_index"]
385            e2 = br2["end_index"]
386            np = e2 - s2
387            if np == 0:
388                # it deals with single point branches
389                scaled_distances.extend([min_d])
390            else:
391                scaled_distances.extend([(max_d - min_d) / np * i + min_d for i in range(np + 1)])
392
393        return scaled_distances
394
395    def bs_plot_data(self, zero_to_efermi=True, bs=None, bs_ref=None, split_branches=True):
396        """
397        Get the data nicely formatted for a plot
398
399        Args:
400            zero_to_efermi: Automatically subtract off the Fermi energy from the
401                eigenvalues and plot.
402            bs: the bandstructure to get the data from. If not provided, the first
403                one in the self._bs list will be used.
404            bs_ref: is the bandstructure of reference when a rescale of the distances
405                is need to plot multiple bands
406            split_branches: if True distances and energies are split according to the
407                branches. If False distances and energies are split only where branches
408                are discontinuous (reducing the number of lines to plot).
409
410        Returns:
411            dict: A dictionary of the following format:
412            ticks: A dict with the 'distances' at which there is a kpoint (the
413            x axis) and the labels (None if no label).
414            energy: A dict storing bands for spin up and spin down data
415            {Spin:[np.array(nb_bands,kpoints),...]} as a list of discontinuous kpath
416            of energies. The energy of multiple continuous branches are stored together.
417            vbm: A list of tuples (distance,energy) marking the vbms. The
418            energies are shifted with respect to the fermi level is the
419            option has been selected.
420            cbm: A list of tuples (distance,energy) marking the cbms. The
421            energies are shifted with respect to the fermi level is the
422            option has been selected.
423            lattice: The reciprocal lattice.
424            zero_energy: This is the energy used as zero for the plot.
425            band_gap:A string indicating the band gap and its nature (empty if
426            it's a metal).
427            is_metal: True if the band structure is metallic (i.e., there is at
428            least one band crossing the fermi level).
429        """
430
431        if bs is None:
432            if isinstance(self._bs, list):
433                # if BSPlotter
434                bs = self._bs[0]
435            else:
436                # if BSPlotterProjected
437                bs = self._bs
438
439        energies = {str(sp): [] for sp in bs.bands.keys()}
440
441        bs_is_metal = bs.is_metal()
442
443        if not bs_is_metal:
444            vbm = bs.get_vbm()
445            cbm = bs.get_cbm()
446
447        zero_energy = 0.0
448        if zero_to_efermi:
449            if bs_is_metal:
450                zero_energy = bs.efermi
451            else:
452                zero_energy = vbm["energy"]
453
454        # rescale distances when a bs_ref is given as reference,
455        # and when bs and bs_ref have different points in branches.
456        # Usually bs_ref is the first one in self._bs list is bs_ref
457        distances = bs.distance
458        if bs_ref is not None:
459            if bs_ref.branches != bs.branches:
460                distances = self._rescale_distances(bs_ref, bs)
461
462        if split_branches:
463            steps = [br["end_index"] + 1 for br in bs.branches][:-1]
464        else:
465            # join all the continuous branches
466            # to reduce the total number of branches to plot
467            steps = self._get_branch_steps(bs.branches)[1:-1]
468
469        distances = np.split(distances, steps)
470        for sp in bs.bands.keys():
471            energies[str(sp)] = np.hsplit(bs.bands[sp] - zero_energy, steps)
472
473        ticks = self.get_ticks()
474
475        vbm_plot = []
476        cbm_plot = []
477        bg_str = ""
478
479        if not bs_is_metal:
480            for index in cbm["kpoint_index"]:
481                cbm_plot.append(
482                    (
483                        bs.distance[index],
484                        cbm["energy"] - zero_energy if zero_to_efermi else cbm["energy"],
485                    )
486                )
487
488            for index in vbm["kpoint_index"]:
489                vbm_plot.append(
490                    (
491                        bs.distance[index],
492                        vbm["energy"] - zero_energy if zero_to_efermi else vbm["energy"],
493                    )
494                )
495
496            bg = bs.get_band_gap()
497            direct = "Indirect"
498            if bg["direct"]:
499                direct = "Direct"
500
501            bg_str = "{} {} bandgap = {}".format(direct, bg["transition"], bg["energy"])
502
503        return {
504            "ticks": ticks,
505            "distances": distances,
506            "energy": energies,
507            "vbm": vbm_plot,
508            "cbm": cbm_plot,
509            "lattice": bs.lattice_rec.as_dict(),
510            "zero_energy": zero_energy,
511            "is_metal": bs_is_metal,
512            "band_gap": bg_str,
513        }
514
515    @staticmethod
516    def _interpolate_bands(distances, energies, smooth_tol=0, smooth_k=3, smooth_np=100):
517        """
518        Method that interpolates the provided energies using B-splines as
519        implemented in scipy.interpolate. Distances and energies has to provided
520        already split into pieces (branches work good, for longer segments
521        the interpolation may fail).
522
523        Interpolation failure can be caused by trying to fit an entire
524        band with one spline rather than fitting with piecewise splines
525        (splines are ill-suited to fit discontinuities).
526
527        The number of splines used to fit a band is determined by the
528        number of branches (high symmetry lines) defined in the
529        BandStructureSymmLine object (see BandStructureSymmLine._branches).
530        """
531
532        int_energies, int_distances = [], []
533        smooth_k_orig = smooth_k
534
535        for dist, ene in zip(distances, energies):
536            br_en = []
537            warning_nan = (
538                f"WARNING! Distance / branch, band cannot be "
539                f"interpolated. See full warning in source. "
540                f"If this is not a mistake, try increasing "
541                f"smooth_tol. Current smooth_tol is {smooth_tol}."
542            )
543
544            warning_m_fewer_k = (
545                f"The number of points (m) has to be higher then "
546                f"the order (k) of the splines. In this branch {len(dist)} "
547                f"points are found, while k is set to {smooth_k}. "
548                f"Smooth_k will be reduced to {smooth_k - 1} for this branch."
549            )
550
551            # skip single point branches
552            if len(dist) in (2, 3):
553                # reducing smooth_k when the number
554                # of points are fewer then k
555                smooth_k = len(dist) - 1
556                warnings.warn(warning_m_fewer_k)
557            elif len(dist) == 1:
558                warnings.warn("Skipping single point branch")
559                continue
560
561            int_distances.append(np.linspace(dist[0], dist[-1], smooth_np))
562
563            for ien in ene:
564                tck = scint.splrep(dist, ien, s=smooth_tol, k=smooth_k)
565
566                br_en.append(scint.splev(int_distances[-1], tck))
567
568            smooth_k = smooth_k_orig
569
570            int_energies.append(np.vstack(br_en))
571
572            if np.any(np.isnan(int_energies[-1])):
573                warnings.warn(warning_nan)
574
575        return int_distances, int_energies
576
577    def get_plot(
578        self,
579        zero_to_efermi=True,
580        ylim=None,
581        smooth=False,
582        vbm_cbm_marker=False,
583        smooth_tol=0,
584        smooth_k=3,
585        smooth_np=100,
586        bs_labels=[],
587    ):
588        """
589        Get a matplotlib object for the bandstructures plot.
590        Multiple bandstructure objs are plotted together if they have the
591        same high symm path.
592
593        Args:
594            zero_to_efermi: Automatically subtract off the Fermi energy from
595                the eigenvalues and plot (E-Ef).
596            ylim: Specify the y-axis (energy) limits; by default None let
597                the code choose. It is vbm-4 and cbm+4 if insulator
598                efermi-10 and efermi+10 if metal
599            smooth (bool or list(bools)): interpolates the bands by a spline cubic.
600                A single bool values means to interpolate all the bandstructure objs.
601                A list of bools allows to select the bandstructure obs to interpolate.
602            smooth_tol (float) : tolerance for fitting spline to band data.
603                Default is None such that no tolerance will be used.
604            smooth_k (int): degree of splines 1<k<5
605            smooth_np (int): number of interpolated points per each branch.
606            bs_labels: labels for each band for the plot legend.
607        """
608        plt = pretty_plot(12, 8)
609
610        if isinstance(smooth, bool):
611            smooth = [smooth] * len(self._bs)
612
613        handles = []
614        vbm_min, cbm_max = [], []
615
616        colors = list(plt.rcParams["axes.prop_cycle"].by_key().values())[0]
617        for ibs, bs in enumerate(self._bs):
618
619            # set first bs in the list as ref for rescaling the distances of the other bands
620            bs_ref = self._bs[0] if len(self._bs) > 1 and ibs > 0 else None
621
622            if smooth[ibs]:
623                # interpolation works good on short segments like branches
624                data = self.bs_plot_data(zero_to_efermi, bs, bs_ref, split_branches=True)
625            else:
626                data = self.bs_plot_data(zero_to_efermi, bs, bs_ref, split_branches=False)
627
628            # remember if one bs is a metal for setting the ylim later
629            one_is_metal = False
630            if not one_is_metal and data["is_metal"]:
631                one_is_metal = data["is_metal"]
632
633            # remember all the cbm and vbm for setting the ylim later
634            if not data["is_metal"]:
635                cbm_max.append(data["cbm"][0][1])
636                vbm_min.append(data["vbm"][0][1])
637
638            for sp in bs.bands.keys():
639                ls = "-" if str(sp) == "1" else "--"
640
641                if bs_labels != []:
642                    bs_label = f"{bs_labels[ibs]} {sp.name}"
643                else:
644                    bs_label = f"Band {ibs} {sp.name}"
645
646                handles.append(mlines.Line2D([], [], lw=2, ls=ls, color=colors[ibs], label=bs_label))
647
648                distances, energies = data["distances"], data["energy"][str(sp)]
649
650                if smooth[ibs]:
651                    distances, energies = self._interpolate_bands(
652                        distances,
653                        energies,
654                        smooth_tol=smooth_tol,
655                        smooth_k=smooth_k,
656                        smooth_np=smooth_np,
657                    )
658                    # join all branches together
659                    distances = np.hstack(distances)
660                    energies = np.hstack(energies)
661                    # split only discontinuous branches
662                    steps = self._get_branch_steps(bs.branches)[1:-1]
663                    distances = np.split(distances, steps)
664                    energies = np.hsplit(energies, steps)
665
666                for dist, ene in zip(distances, energies):
667                    plt.plot(dist, ene.T, c=colors[ibs], ls=ls)
668
669            # plot markers for vbm and cbm
670            if vbm_cbm_marker:
671                for cbm in data["cbm"]:
672                    plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
673                for vbm in data["vbm"]:
674                    plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)
675
676            # Draw Fermi energy, only if not the zero
677            if not zero_to_efermi:
678                ef = bs.efermi
679                plt.axhline(ef, lw=2, ls="-.", color=colors[ibs])
680
681        # defaults for ylim
682        e_min = -4
683        e_max = 4
684        if one_is_metal:
685            e_min = -10
686            e_max = 10
687
688        if ylim is None:
689            if zero_to_efermi:
690                if one_is_metal:
691                    # Plot A Metal
692                    plt.ylim(e_min, e_max)
693                else:
694                    plt.ylim(e_min, max(cbm_max) + e_max)
695            else:
696                all_efermi = [b.efermi for b in self._bs]
697                ll = min([min(vbm_min), min(all_efermi)])
698                hh = max([max(cbm_max), max(all_efermi)])
699                plt.ylim(ll + e_min, hh + e_max)
700        else:
701            plt.ylim(ylim)
702
703        self._maketicks(plt)
704
705        # Main X and Y Labels
706        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
707        ylabel = r"$\mathrm{E\ -\ E_f\ (eV)}$" if zero_to_efermi else r"$\mathrm{Energy\ (eV)}$"
708        plt.ylabel(ylabel, fontsize=30)
709
710        # X range (K)
711        # last distance point
712        x_max = data["distances"][-1][-1]
713        plt.xlim(0, x_max)
714
715        plt.legend(handles=handles)
716
717        plt.tight_layout()
718
719        # auto tight_layout when resizing or pressing t
720        def fix_layout(event):
721            if (event.name == "key_press_event" and event.key == "t") or event.name == "resize_event":
722                plt.gcf().tight_layout()
723                plt.gcf().canvas.draw()
724
725        plt.gcf().canvas.mpl_connect("key_press_event", fix_layout)
726        plt.gcf().canvas.mpl_connect("resize_event", fix_layout)
727
728        return plt
729
730    def show(self, zero_to_efermi=True, ylim=None, smooth=False, smooth_tol=None):
731        """
732        Show the plot using matplotlib.
733
734        Args:
735            zero_to_efermi: Automatically subtract off the Fermi energy from
736                the eigenvalues and plot (E-Ef).
737            ylim: Specify the y-axis (energy) limits; by default None let
738                the code choose. It is vbm-4 and cbm+4 if insulator
739                efermi-10 and efermi+10 if metal
740            smooth: interpolates the bands by a spline cubic
741            smooth_tol (float) : tolerance for fitting spline to band data.
742                Default is None such that no tolerance will be used.
743        """
744        plt = self.get_plot(zero_to_efermi, ylim, smooth)
745        plt.show()
746
747    def save_plot(self, filename, img_format="eps", ylim=None, zero_to_efermi=True, smooth=False):
748        """
749        Save matplotlib plot to a file.
750
751        Args:
752            filename: Filename to write to.
753            img_format: Image format to use. Defaults to EPS.
754            ylim: Specifies the y-axis limits.
755        """
756        plt = self.get_plot(ylim=ylim, zero_to_efermi=zero_to_efermi, smooth=smooth)
757        plt.savefig(filename, format=img_format)
758        plt.close()
759
760    def get_ticks(self):
761        """
762        Get all ticks and labels for a band structure plot.
763
764        Returns:
765            dict: A dictionary with 'distance': a list of distance at which
766            ticks should be set and 'label': a list of label for each of those
767            ticks.
768        """
769        bs = self._bs[0] if isinstance(self._bs, list) else self._bs
770        ticks, distance = [], []
771        for br in bs.branches:
772            s, e = br["start_index"], br["end_index"]
773
774            labels = br["name"].split("-")
775
776            # skip those branches with only one point
777            if labels[0] == labels[1]:
778                continue
779
780            # add latex $$
781            for i, l in enumerate(labels):
782                if l.startswith("\\") or "_" in l:
783                    labels[i] = "$" + l + "$"
784
785            # If next branch is not continuous,
786            # join the firts lbl to the previous tick label
787            # and add the second lbl to ticks list
788            # otherwise add to ticks list both new labels.
789            # Similar for distances.
790            if ticks != [] and labels[0] != ticks[-1]:
791                ticks[-1] += "$\\mid$" + labels[0]
792                ticks.append(labels[1])
793                distance.append(bs.distance[e])
794            else:
795                ticks.extend(labels)
796                distance.extend([bs.distance[s], bs.distance[e]])
797
798        return {"distance": distance, "label": ticks}
799
800    def get_ticks_old(self):
801        """
802        Get all ticks and labels for a band structure plot.
803
804        Returns:
805            dict: A dictionary with 'distance': a list of distance at which
806            ticks should be set and 'label': a list of label for each of those
807            ticks.
808        """
809        bs = self._bs[0]
810        tick_distance = []
811        tick_labels = []
812        previous_label = bs.kpoints[0].label
813        previous_branch = bs.branches[0]["name"]
814        for i, c in enumerate(bs.kpoints):
815            if c.label is not None:
816                tick_distance.append(bs.distance[i])
817                this_branch = None
818                for b in bs.branches:
819                    if b["start_index"] <= i <= b["end_index"]:
820                        this_branch = b["name"]
821                        break
822                if c.label != previous_label and previous_branch != this_branch:
823                    label1 = c.label
824                    if label1.startswith("\\") or label1.find("_") != -1:
825                        label1 = "$" + label1 + "$"
826                    label0 = previous_label
827                    if label0.startswith("\\") or label0.find("_") != -1:
828                        label0 = "$" + label0 + "$"
829                    tick_labels.pop()
830                    tick_distance.pop()
831                    tick_labels.append(label0 + "$\\mid$" + label1)
832                else:
833                    if c.label.startswith("\\") or c.label.find("_") != -1:
834                        tick_labels.append("$" + c.label + "$")
835                    else:
836                        tick_labels.append(c.label)
837                previous_label = c.label
838                previous_branch = this_branch
839        return {"distance": tick_distance, "label": tick_labels}
840
841    def plot_compare(self, other_plotter, legend=True):
842        """
843        plot two band structure for comparison. One is in red the other in blue
844        (no difference in spins). The two band structures need to be defined
845        on the same symmetry lines! and the distance between symmetry lines is
846        the one of the band structure used to build the BSPlotter
847
848        Args:
849            another band structure object defined along the same symmetry lines
850
851        Returns:
852            a matplotlib object with both band structures
853
854        """
855        warnings.warn("Deprecated method. " "Use BSPlotter([sbs1,sbs2,...]).get_plot() instead.")
856
857        # TODO: add exception if the band structures are not compatible
858        import matplotlib.lines as mlines
859
860        plt = self.get_plot()
861        data_orig = self.bs_plot_data()
862        data = other_plotter.bs_plot_data()
863        band_linewidth = 1
864        for i in range(other_plotter._nb_bands):
865            for d in range(len(data_orig["distances"])):
866                plt.plot(
867                    data_orig["distances"][d],
868                    [e[str(Spin.up)][i] for e in data["energy"]][d],
869                    "c-",
870                    linewidth=band_linewidth,
871                )
872                if other_plotter._bs.is_spin_polarized:
873                    plt.plot(
874                        data_orig["distances"][d],
875                        [e[str(Spin.down)][i] for e in data["energy"]][d],
876                        "m--",
877                        linewidth=band_linewidth,
878                    )
879        if legend:
880            handles = [
881                mlines.Line2D([], [], linewidth=2, color="b", label="bs 1 up"),
882                mlines.Line2D([], [], linewidth=2, color="r", label="bs 1 down", linestyle="--"),
883                mlines.Line2D([], [], linewidth=2, color="c", label="bs 2 up"),
884                mlines.Line2D([], [], linewidth=2, color="m", linestyle="--", label="bs 2 down"),
885            ]
886
887            plt.legend(handles=handles)
888        return plt
889
890    def plot_brillouin(self):
891        """
892        plot the Brillouin zone
893        """
894
895        # get labels and lines
896        labels = {}
897        for k in self._bs[0].kpoints:
898            if k.label:
899                labels[k.label] = k.frac_coords
900
901        lines = []
902        for b in self._bs[0].branches:
903            lines.append(
904                [
905                    self._bs[0].kpoints[b["start_index"]].frac_coords,
906                    self._bs[0].kpoints[b["end_index"]].frac_coords,
907                ]
908            )
909
910        plot_brillouin_zone(self._bs[0].lattice_rec, lines=lines, labels=labels)
911
912
913class BSPlotterProjected(BSPlotter):
914    """
915    Class to plot or get data to facilitate the plot of band structure objects
916    projected along orbitals, elements or sites.
917    """
918
919    def __init__(self, bs):
920        """
921        Args:
922            bs: A BandStructureSymmLine object with projections.
923        """
924        if isinstance(bs, list):
925            warnings.warn(
926                "Multiple bands are not handled by BSPlotterProjected." "The first band in the list will be considered"
927            )
928            bs = bs[0]
929
930        if len(bs.projections) == 0:
931            raise ValueError("try to plot projections on a band structure without any")
932
933        self._bs = bs
934        self._nb_bands = bs.nb_bands
935
936    def _get_projections_by_branches(self, dictio):
937        proj = self._bs.get_projections_on_elements_and_orbitals(dictio)
938        proj_br = []
939        for b in self._bs.branches:
940            if self._bs.is_spin_polarized:
941                proj_br.append(
942                    {
943                        str(Spin.up): [[] for l in range(self._nb_bands)],
944                        str(Spin.down): [[] for l in range(self._nb_bands)],
945                    }
946                )
947            else:
948                proj_br.append({str(Spin.up): [[] for l in range(self._nb_bands)]})
949
950            for i in range(self._nb_bands):
951                for j in range(b["start_index"], b["end_index"] + 1):
952                    proj_br[-1][str(Spin.up)][i].append(
953                        {e: {o: proj[Spin.up][i][j][e][o] for o in proj[Spin.up][i][j][e]} for e in proj[Spin.up][i][j]}
954                    )
955            if self._bs.is_spin_polarized:
956                for b in self._bs.branches:
957                    for i in range(self._nb_bands):
958                        for j in range(b["start_index"], b["end_index"] + 1):
959                            proj_br[-1][str(Spin.down)][i].append(
960                                {
961                                    e: {o: proj[Spin.down][i][j][e][o] for o in proj[Spin.down][i][j][e]}
962                                    for e in proj[Spin.down][i][j]
963                                }
964                            )
965        return proj_br
966
967    def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False):
968        """
969        Method returning a plot composed of subplots along different elements
970        and orbitals.
971
972        Args:
973            dictio: The element and orbitals you want a projection on. The
974                format is {Element:[Orbitals]} for instance
975                {'Cu':['d','s'],'O':['p']} will give projections for Cu on
976                d and s orbitals and on oxygen p.
977                If you use this class to plot LobsterBandStructureSymmLine,
978                the orbitals are named as in the FATBAND filename, e.g.
979                "2p" or "2p_x"
980
981        Returns:
982            a pylab object with different subfigures for each projection
983            The blue and red colors are for spin up and spin down.
984            The bigger the red or blue dot in the band structure the higher
985            character for the corresponding element and orbital.
986        """
987        band_linewidth = 1.0
988        fig_cols = len(dictio) * 100
989        fig_rows = max([len(v) for v in dictio.values()]) * 10
990        proj = self._get_projections_by_branches(dictio)
991        data = self.bs_plot_data(zero_to_efermi)
992        plt = pretty_plot(12, 8)
993        e_min = -4
994        e_max = 4
995        if self._bs.is_metal():
996            e_min = -10
997            e_max = 10
998        count = 1
999
1000        for el in dictio:
1001            for o in dictio[el]:
1002                plt.subplot(fig_rows + fig_cols + count)
1003                self._maketicks(plt)
1004                for b in range(len(data["distances"])):
1005                    for i in range(self._nb_bands):
1006                        plt.plot(
1007                            data["distances"][b],
1008                            data["energy"][str(Spin.up)][b][i],
1009                            "b-",
1010                            linewidth=band_linewidth,
1011                        )
1012                        if self._bs.is_spin_polarized:
1013                            plt.plot(
1014                                data["distances"][b],
1015                                data["energy"][str(Spin.down)][b][i],
1016                                "r--",
1017                                linewidth=band_linewidth,
1018                            )
1019                            for j in range(len(data["energy"][str(Spin.up)][b][i])):
1020                                plt.plot(
1021                                    data["distances"][b][j],
1022                                    data["energy"][str(Spin.down)][b][i][j],
1023                                    "ro",
1024                                    markersize=proj[b][str(Spin.down)][i][j][str(el)][o] * 15.0,
1025                                )
1026                        for j in range(len(data["energy"][str(Spin.up)][b][i])):
1027                            plt.plot(
1028                                data["distances"][b][j],
1029                                data["energy"][str(Spin.up)][b][i][j],
1030                                "bo",
1031                                markersize=proj[b][str(Spin.up)][i][j][str(el)][o] * 15.0,
1032                            )
1033                if ylim is None:
1034                    if self._bs.is_metal():
1035                        if zero_to_efermi:
1036                            plt.ylim(e_min, e_max)
1037                        else:
1038                            plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
1039                    else:
1040                        if vbm_cbm_marker:
1041                            for cbm in data["cbm"]:
1042                                plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
1043
1044                            for vbm in data["vbm"]:
1045                                plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)
1046
1047                        plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
1048                else:
1049                    plt.ylim(ylim)
1050                plt.title(str(el) + " " + str(o))
1051                count += 1
1052        return plt
1053
1054    def get_elt_projected_plots(self, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False):
1055        """
1056        Method returning a plot composed of subplots along different elements
1057
1058        Returns:
1059            a pylab object with different subfigures for each projection
1060            The blue and red colors are for spin up and spin down
1061            The bigger the red or blue dot in the band structure the higher
1062            character for the corresponding element and orbital
1063        """
1064        band_linewidth = 1.0
1065        proj = self._get_projections_by_branches(
1066            {e.symbol: ["s", "p", "d"] for e in self._bs.structure.composition.elements}
1067        )
1068        data = self.bs_plot_data(zero_to_efermi)
1069        plt = pretty_plot(12, 8)
1070        e_min = -4
1071        e_max = 4
1072        if self._bs.is_metal():
1073            e_min = -10
1074            e_max = 10
1075        count = 1
1076        for el in self._bs.structure.composition.elements:
1077            plt.subplot(220 + count)
1078            self._maketicks(plt)
1079            for b in range(len(data["distances"])):
1080                for i in range(self._nb_bands):
1081                    plt.plot(
1082                        data["distances"][b],
1083                        data["energy"][str(Spin.up)][b][i],
1084                        "-",
1085                        color=[192 / 255, 192 / 255, 192 / 255],
1086                        linewidth=band_linewidth,
1087                    )
1088                    if self._bs.is_spin_polarized:
1089                        plt.plot(
1090                            data["distances"][b],
1091                            data["energy"][str(Spin.down)][b][i],
1092                            "--",
1093                            color=[128 / 255, 128 / 255, 128 / 255],
1094                            linewidth=band_linewidth,
1095                        )
1096                        for j in range(len(data["energy"][str(Spin.up)][b][i])):
1097                            markerscale = sum(
1098                                [
1099                                    proj[b][str(Spin.down)][i][j][str(el)][o]
1100                                    for o in proj[b][str(Spin.down)][i][j][str(el)]
1101                                ]
1102                            )
1103                            plt.plot(
1104                                data["distances"][b][j],
1105                                data["energy"][str(Spin.down)][b][i][j],
1106                                "bo",
1107                                markersize=markerscale * 15.0,
1108                                color=[
1109                                    markerscale,
1110                                    0.3 * markerscale,
1111                                    0.4 * markerscale,
1112                                ],
1113                            )
1114                    for j in range(len(data["energy"][str(Spin.up)][b][i])):
1115                        markerscale = sum(
1116                            [proj[b][str(Spin.up)][i][j][str(el)][o] for o in proj[b][str(Spin.up)][i][j][str(el)]]
1117                        )
1118                        plt.plot(
1119                            data["distances"][b][j],
1120                            data["energy"][str(Spin.up)][b][i][j],
1121                            "o",
1122                            markersize=markerscale * 15.0,
1123                            color=[markerscale, 0.3 * markerscale, 0.4 * markerscale],
1124                        )
1125            if ylim is None:
1126                if self._bs.is_metal():
1127                    if zero_to_efermi:
1128                        plt.ylim(e_min, e_max)
1129                    else:
1130                        plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
1131                else:
1132                    if vbm_cbm_marker:
1133                        for cbm in data["cbm"]:
1134                            plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
1135
1136                        for vbm in data["vbm"]:
1137                            plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)
1138
1139                    plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
1140            else:
1141                plt.ylim(ylim)
1142            plt.title(str(el))
1143            count += 1
1144
1145        return plt
1146
1147    def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None):
1148        """
1149        returns a pylab plot object with one plot where the band structure
1150        line color depends on the character of the band (along different
1151        elements). Each element is associated with red, green or blue
1152        and the corresponding rgb color depending on the character of the band
1153        is used. The method can only deal with binary and ternary compounds
1154
1155        spin up and spin down are differientiated by a '-' and a '--' line
1156
1157        Args:
1158            elt_ordered: A list of Element ordered. The first one is red,
1159                second green, last blue
1160
1161        Returns:
1162            a pylab object
1163
1164        """
1165        band_linewidth = 3.0
1166        if len(self._bs.structure.composition.elements) > 3:
1167            raise ValueError
1168        if elt_ordered is None:
1169            elt_ordered = self._bs.structure.composition.elements
1170        proj = self._get_projections_by_branches(
1171            {e.symbol: ["s", "p", "d"] for e in self._bs.structure.composition.elements}
1172        )
1173        data = self.bs_plot_data(zero_to_efermi)
1174        plt = pretty_plot(12, 8)
1175
1176        spins = [Spin.up]
1177        if self._bs.is_spin_polarized:
1178            spins = [Spin.up, Spin.down]
1179        self._maketicks(plt)
1180        for s in spins:
1181            for b in range(len(data["distances"])):
1182                for i in range(self._nb_bands):
1183                    for j in range(len(data["energy"][str(s)][b][i]) - 1):
1184                        sum_e = 0.0
1185                        for el in elt_ordered:
1186                            sum_e = sum_e + sum(
1187                                [proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)]]
1188                            )
1189                        if sum_e == 0.0:
1190                            color = [0.0] * len(elt_ordered)
1191                        else:
1192                            color = [
1193                                sum([proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)]]) / sum_e
1194                                for el in elt_ordered
1195                            ]
1196                        if len(color) == 2:
1197                            color.append(0.0)
1198                            color[2] = color[1]
1199                            color[1] = 0.0
1200                        sign = "-"
1201                        if s == Spin.down:
1202                            sign = "--"
1203                        plt.plot(
1204                            [data["distances"][b][j], data["distances"][b][j + 1]],
1205                            [
1206                                data["energy"][str(s)][b][i][j],
1207                                data["energy"][str(s)][b][i][j + 1],
1208                            ],
1209                            sign,
1210                            color=color,
1211                            linewidth=band_linewidth,
1212                        )
1213
1214        if self._bs.is_metal():
1215            if zero_to_efermi:
1216                e_min = -10
1217                e_max = 10
1218                plt.ylim(e_min, e_max)
1219                plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
1220        else:
1221            plt.ylim(data["vbm"][0][1] - 4.0, data["cbm"][0][1] + 2.0)
1222        return plt
1223
1224    def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, sum_morbs, selected_branches):
1225        import copy
1226
1227        setos = {
1228            "s": 0,
1229            "py": 1,
1230            "pz": 2,
1231            "px": 3,
1232            "dxy": 4,
1233            "dyz": 5,
1234            "dz2": 6,
1235            "dxz": 7,
1236            "dx2": 8,
1237            "f_3": 9,
1238            "f_2": 10,
1239            "f_1": 11,
1240            "f0": 12,
1241            "f1": 13,
1242            "f2": 14,
1243            "f3": 15,
1244        }
1245
1246        num_branches = len(self._bs.branches)
1247        if selected_branches is not None:
1248            indices = []
1249            if not isinstance(selected_branches, list):
1250                raise TypeError("You do not give a correct type of 'selected_branches'. It should be 'list' type.")
1251            if len(selected_branches) == 0:
1252                raise ValueError("The 'selected_branches' is empty. We cannot do anything.")
1253            for index in selected_branches:
1254                if not isinstance(index, int):
1255                    raise ValueError(
1256                        "You do not give a correct type of index of symmetry lines. It should be " "'int' type"
1257                    )
1258                if index > num_branches or index < 1:
1259                    raise ValueError(
1260                        "You give a incorrect index of symmetry lines: %s. The index should be in "
1261                        "range of [1, %s]." % (str(index), str(num_branches))
1262                    )
1263                indices.append(index - 1)
1264        else:
1265            indices = range(0, num_branches)
1266
1267        proj = self._bs.projections
1268        proj_br = []
1269        for index in indices:
1270            b = self._bs.branches[index]
1271            print(b)
1272            if self._bs.is_spin_polarized:
1273                proj_br.append(
1274                    {
1275                        str(Spin.up): [[] for l in range(self._nb_bands)],
1276                        str(Spin.down): [[] for l in range(self._nb_bands)],
1277                    }
1278                )
1279            else:
1280                proj_br.append({str(Spin.up): [[] for l in range(self._nb_bands)]})
1281
1282            for i in range(self._nb_bands):
1283                for j in range(b["start_index"], b["end_index"] + 1):
1284                    edict = {}
1285                    for elt in dictpa:
1286                        for anum in dictpa[elt]:
1287                            edict[elt + str(anum)] = {}
1288                            for morb in dictio[elt]:
1289                                edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1]
1290                    proj_br[-1][str(Spin.up)][i].append(edict)
1291
1292            if self._bs.is_spin_polarized:
1293                for i in range(self._nb_bands):
1294                    for j in range(b["start_index"], b["end_index"] + 1):
1295                        edict = {}
1296                        for elt in dictpa:
1297                            for anum in dictpa[elt]:
1298                                edict[elt + str(anum)] = {}
1299                                for morb in dictio[elt]:
1300                                    edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1]
1301                        proj_br[-1][str(Spin.down)][i].append(edict)
1302
1303        # Adjusting  projections for plot
1304        dictio_d, dictpa_d = self._summarize_keys_for_plot(dictio, dictpa, sum_atoms, sum_morbs)
1305        print("dictio_d: %s" % str(dictio_d))
1306        print("dictpa_d: %s" % str(dictpa_d))
1307
1308        if (sum_atoms is None) and (sum_morbs is None):
1309            proj_br_d = copy.deepcopy(proj_br)
1310        else:
1311            proj_br_d = []
1312            branch = -1
1313            for index in indices:
1314                branch += 1
1315                br = self._bs.branches[index]
1316                if self._bs.is_spin_polarized:
1317                    proj_br_d.append(
1318                        {
1319                            str(Spin.up): [[] for l in range(self._nb_bands)],
1320                            str(Spin.down): [[] for l in range(self._nb_bands)],
1321                        }
1322                    )
1323                else:
1324                    proj_br_d.append({str(Spin.up): [[] for l in range(self._nb_bands)]})
1325
1326                if (sum_atoms is not None) and (sum_morbs is None):
1327                    for i in range(self._nb_bands):
1328                        for j in range(br["end_index"] - br["start_index"] + 1):
1329                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
1330                            edict = {}
1331                            for elt in dictpa:
1332                                if elt in sum_atoms:
1333                                    for anum in dictpa_d[elt][:-1]:
1334                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1335                                    edict[elt + dictpa_d[elt][-1]] = {}
1336                                    for morb in dictio[elt]:
1337                                        sprojection = 0.0
1338                                        for anum in sum_atoms[elt]:
1339                                            sprojection += atoms_morbs[elt + str(anum)][morb]
1340                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1341                                else:
1342                                    for anum in dictpa_d[elt]:
1343                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1344                            proj_br_d[-1][str(Spin.up)][i].append(edict)
1345                    if self._bs.is_spin_polarized:
1346                        for i in range(self._nb_bands):
1347                            for j in range(br["end_index"] - br["start_index"] + 1):
1348                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
1349                                edict = {}
1350                                for elt in dictpa:
1351                                    if elt in sum_atoms:
1352                                        for anum in dictpa_d[elt][:-1]:
1353                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1354                                        edict[elt + dictpa_d[elt][-1]] = {}
1355                                        for morb in dictio[elt]:
1356                                            sprojection = 0.0
1357                                            for anum in sum_atoms[elt]:
1358                                                sprojection += atoms_morbs[elt + str(anum)][morb]
1359                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1360                                    else:
1361                                        for anum in dictpa_d[elt]:
1362                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1363                                proj_br_d[-1][str(Spin.down)][i].append(edict)
1364
1365                elif (sum_atoms is None) and (sum_morbs is not None):
1366                    for i in range(self._nb_bands):
1367                        for j in range(br["end_index"] - br["start_index"] + 1):
1368                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
1369                            edict = {}
1370                            for elt in dictpa:
1371                                if elt in sum_morbs:
1372                                    for anum in dictpa_d[elt]:
1373                                        edict[elt + anum] = {}
1374                                        for morb in dictio_d[elt][:-1]:
1375                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1376                                        sprojection = 0.0
1377                                        for morb in sum_morbs[elt]:
1378                                            sprojection += atoms_morbs[elt + anum][morb]
1379                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
1380                                else:
1381                                    for anum in dictpa_d[elt]:
1382                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1383                            proj_br_d[-1][str(Spin.up)][i].append(edict)
1384                    if self._bs.is_spin_polarized:
1385                        for i in range(self._nb_bands):
1386                            for j in range(br["end_index"] - br["start_index"] + 1):
1387                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
1388                                edict = {}
1389                                for elt in dictpa:
1390                                    if elt in sum_morbs:
1391                                        for anum in dictpa_d[elt]:
1392                                            edict[elt + anum] = {}
1393                                            for morb in dictio_d[elt][:-1]:
1394                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1395                                            sprojection = 0.0
1396                                            for morb in sum_morbs[elt]:
1397                                                sprojection += atoms_morbs[elt + anum][morb]
1398                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
1399                                    else:
1400                                        for anum in dictpa_d[elt]:
1401                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1402                                proj_br_d[-1][str(Spin.down)][i].append(edict)
1403
1404                else:
1405                    for i in range(self._nb_bands):
1406                        for j in range(br["end_index"] - br["start_index"] + 1):
1407                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
1408                            edict = {}
1409                            for elt in dictpa:
1410                                if (elt in sum_atoms) and (elt in sum_morbs):
1411                                    for anum in dictpa_d[elt][:-1]:
1412                                        edict[elt + anum] = {}
1413                                        for morb in dictio_d[elt][:-1]:
1414                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1415                                        sprojection = 0.0
1416                                        for morb in sum_morbs[elt]:
1417                                            sprojection += atoms_morbs[elt + anum][morb]
1418                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
1419
1420                                    edict[elt + dictpa_d[elt][-1]] = {}
1421                                    for morb in dictio_d[elt][:-1]:
1422                                        sprojection = 0.0
1423                                        for anum in sum_atoms[elt]:
1424                                            sprojection += atoms_morbs[elt + str(anum)][morb]
1425                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1426
1427                                    sprojection = 0.0
1428                                    for anum in sum_atoms[elt]:
1429                                        for morb in sum_morbs[elt]:
1430                                            sprojection += atoms_morbs[elt + str(anum)][morb]
1431                                    edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection
1432
1433                                elif (elt in sum_atoms) and (elt not in sum_morbs):
1434                                    for anum in dictpa_d[elt][:-1]:
1435                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1436                                    edict[elt + dictpa_d[elt][-1]] = {}
1437                                    for morb in dictio[elt]:
1438                                        sprojection = 0.0
1439                                        for anum in sum_atoms[elt]:
1440                                            sprojection += atoms_morbs[elt + str(anum)][morb]
1441                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1442
1443                                elif (elt not in sum_atoms) and (elt in sum_morbs):
1444                                    for anum in dictpa_d[elt]:
1445                                        edict[elt + anum] = {}
1446                                        for morb in dictio_d[elt][:-1]:
1447                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1448                                        sprojection = 0.0
1449                                        for morb in sum_morbs[elt]:
1450                                            sprojection += atoms_morbs[elt + anum][morb]
1451                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
1452
1453                                else:
1454                                    for anum in dictpa_d[elt]:
1455                                        edict[elt + anum] = {}
1456                                        for morb in dictio_d[elt]:
1457                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1458                            proj_br_d[-1][str(Spin.up)][i].append(edict)
1459
1460                    if self._bs.is_spin_polarized:
1461                        for i in range(self._nb_bands):
1462                            for j in range(br["end_index"] - br["start_index"] + 1):
1463                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
1464                                edict = {}
1465                                for elt in dictpa:
1466                                    if (elt in sum_atoms) and (elt in sum_morbs):
1467                                        for anum in dictpa_d[elt][:-1]:
1468                                            edict[elt + anum] = {}
1469                                            for morb in dictio_d[elt][:-1]:
1470                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1471                                            sprojection = 0.0
1472                                            for morb in sum_morbs[elt]:
1473                                                sprojection += atoms_morbs[elt + anum][morb]
1474                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
1475
1476                                        edict[elt + dictpa_d[elt][-1]] = {}
1477                                        for morb in dictio_d[elt][:-1]:
1478                                            sprojection = 0.0
1479                                            for anum in sum_atoms[elt]:
1480                                                sprojection += atoms_morbs[elt + str(anum)][morb]
1481                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1482
1483                                        sprojection = 0.0
1484                                        for anum in sum_atoms[elt]:
1485                                            for morb in sum_morbs[elt]:
1486                                                sprojection += atoms_morbs[elt + str(anum)][morb]
1487                                        edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection
1488
1489                                    elif (elt in sum_atoms) and (elt not in sum_morbs):
1490                                        for anum in dictpa_d[elt][:-1]:
1491                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
1492                                        edict[elt + dictpa_d[elt][-1]] = {}
1493                                        for morb in dictio[elt]:
1494                                            sprojection = 0.0
1495                                            for anum in sum_atoms[elt]:
1496                                                sprojection += atoms_morbs[elt + str(anum)][morb]
1497                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
1498
1499                                    elif (elt not in sum_atoms) and (elt in sum_morbs):
1500                                        for anum in dictpa_d[elt]:
1501                                            edict[elt + anum] = {}
1502                                            for morb in dictio_d[elt][:-1]:
1503                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1504                                            sprojection = 0.0
1505                                            for morb in sum_morbs[elt]:
1506                                                sprojection += atoms_morbs[elt + anum][morb]
1507                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
1508
1509                                    else:
1510                                        for anum in dictpa_d[elt]:
1511                                            edict[elt + anum] = {}
1512                                            for morb in dictio_d[elt]:
1513                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
1514                                proj_br_d[-1][str(Spin.down)][i].append(edict)
1515
1516        return proj_br_d, dictio_d, dictpa_d, indices
1517
1518    def get_projected_plots_dots_patom_pmorb(
1519        self,
1520        dictio,
1521        dictpa,
1522        sum_atoms=None,
1523        sum_morbs=None,
1524        zero_to_efermi=True,
1525        ylim=None,
1526        vbm_cbm_marker=False,
1527        selected_branches=None,
1528        w_h_size=(12, 8),
1529        num_column=None,
1530    ):
1531        """
1532        Method returns a plot composed of subplots for different atoms and
1533        orbitals (subshell orbitals such as 's', 'p', 'd' and 'f' defined by
1534        azimuthal quantum numbers l = 0, 1, 2 and 3, respectively or
1535        individual orbitals like 'px', 'py' and 'pz' defined by magnetic
1536        quantum numbers m = -1, 1 and 0, respectively).
1537        This is an extension of "get_projected_plots_dots" method.
1538
1539        Args:
1540            dictio: The elements and the orbitals you need to project on. The
1541                format is {Element:[Orbitals]}, for instance:
1542                {'Cu':['dxy','s','px'],'O':['px','py','pz']} will give
1543                projections for Cu on orbitals dxy, s, px and
1544                for O on orbitals px, py, pz. If you want to sum over all
1545                individual orbitals of subshell orbitals,
1546                for example, 'px', 'py' and 'pz' of O, just simply set
1547                {'Cu':['dxy','s','px'],'O':['p']} and set sum_morbs (see
1548                explanations below) as {'O':[p],...}.
1549                Otherwise, you will get an error.
1550            dictpa: The elements and their sites (defined by site numbers) you
1551                need to project on. The format is
1552                {Element: [Site numbers]}, for instance: {'Cu':[1,5],'O':[3,4]}
1553                will give projections for Cu on site-1
1554                and on site-5, O on site-3 and on site-4 in the cell.
1555                Attention:
1556                The correct site numbers of atoms are consistent with
1557                themselves in the structure computed. Normally,
1558                the structure should be totally similar with POSCAR file,
1559                however, sometimes VASP can rotate or
1560                translate the cell. Thus, it would be safe if using Vasprun
1561                class to get the final_structure and as a
1562                result, correct index numbers of atoms.
1563            sum_atoms: Sum projection of the similar atoms together (e.g.: Cu
1564                on site-1 and Cu on site-5). The format is
1565                {Element: [Site numbers]}, for instance:
1566                 {'Cu': [1,5], 'O': [3,4]} means summing projections over Cu on
1567                 site-1 and Cu on site-5 and O on site-3
1568                 and on site-4. If you do not want to use this functional, just
1569                 turn it off by setting sum_atoms = None.
1570            sum_morbs: Sum projections of individual orbitals of similar atoms
1571                together (e.g.: 'dxy' and 'dxz'). The
1572                format is {Element: [individual orbitals]}, for instance:
1573                {'Cu': ['dxy', 'dxz'], 'O': ['px', 'py']} means summing
1574                projections over 'dxy' and 'dxz' of Cu and 'px'
1575                and 'py' of O. If you do not want to use this functional, just
1576                turn it off by setting sum_morbs = None.
1577            selected_branches: The index of symmetry lines you chose for
1578                plotting. This can be useful when the number of
1579                symmetry lines (in KPOINTS file) are manny while you only want
1580                to show for certain ones. The format is
1581                [index of line], for instance:
1582                [1, 3, 4] means you just need to do projection along lines
1583                number 1, 3 and 4 while neglecting lines
1584                number 2 and so on. By default, this is None type and all
1585                symmetry lines will be plotted.
1586            w_h_size: This variable help you to control the width and height
1587                of figure. By default, width = 12 and
1588                height = 8 (inches). The width/height ratio is kept the same
1589                for subfigures and the size of each depends
1590                on how many number of subfigures are plotted.
1591            num_column: This variable help you to manage how the subfigures are
1592                arranged in the figure by setting
1593                up the number of columns of subfigures. The value should be an
1594                int number. For example, num_column = 3
1595                means you want to plot subfigures in 3 columns. By default,
1596                num_column = None and subfigures are
1597                aligned in 2 columns.
1598
1599        Returns:
1600            A pylab object with different subfigures for different projections.
1601            The blue and red colors lines are bands
1602            for spin up and spin down. The green and cyan dots are projections
1603            for spin up and spin down. The bigger
1604            the green or cyan dots in the projected band structures, the higher
1605            character for the corresponding elements
1606            and orbitals. List of individual orbitals and their numbers (set up
1607            by VASP and no special meaning):
1608            s = 0; py = 1 pz = 2 px = 3; dxy = 4 dyz = 5 dz2 = 6 dxz = 7 dx2 = 8;
1609            f_3 = 9 f_2 = 10 f_1 = 11 f0 = 12 f1 = 13 f2 = 14 f3 = 15
1610        """
1611        dictio, sum_morbs = self._Orbitals_SumOrbitals(dictio, sum_morbs)
1612        dictpa, sum_atoms, number_figs = self._number_of_subfigures(dictio, dictpa, sum_atoms, sum_morbs)
1613        print("Number of subfigures: %s" % str(number_figs))
1614        if number_figs > 9:
1615            print(
1616                "The number of sub-figures %s might be too manny and the implementation might take a long time.\n"
1617                "A smaller number or a plot with selected symmetry lines (selected_branches) might be better.\n"
1618                % str(number_figs)
1619            )
1620        from pymatgen.util.plotting import pretty_plot
1621
1622        band_linewidth = 0.5
1623        plt = pretty_plot(w_h_size[0], w_h_size[1])
1624        (
1625            proj_br_d,
1626            dictio_d,
1627            dictpa_d,
1628            branches,
1629        ) = self._get_projections_by_branches_patom_pmorb(dictio, dictpa, sum_atoms, sum_morbs, selected_branches)
1630        data = self.bs_plot_data(zero_to_efermi)
1631        e_min = -4
1632        e_max = 4
1633        if self._bs.is_metal():
1634            e_min = -10
1635            e_max = 10
1636
1637        count = 0
1638        for elt in dictpa_d:
1639            for numa in dictpa_d[elt]:
1640                for o in dictio_d[elt]:
1641
1642                    count += 1
1643                    if num_column is None:
1644                        if number_figs == 1:
1645                            plt.subplot(1, 1, 1)
1646                        else:
1647                            row = number_figs / 2
1648                            if number_figs % 2 == 0:
1649                                plt.subplot(row, 2, count)
1650                            else:
1651                                plt.subplot(row + 1, 2, count)
1652                    elif isinstance(num_column, int):
1653                        row = number_figs / num_column
1654                        if number_figs % num_column == 0:
1655                            plt.subplot(row, num_column, count)
1656                        else:
1657                            plt.subplot(row + 1, num_column, count)
1658                    else:
1659                        raise ValueError("The invalid 'num_column' is assigned. It should be an integer.")
1660
1661                    plt, shift = self._maketicks_selected(plt, branches)
1662                    br = -1
1663                    for b in branches:
1664                        br += 1
1665                        for i in range(self._nb_bands):
1666                            plt.plot(
1667                                list(map(lambda x: x - shift[br], data["distances"][b])),
1668                                [data["energy"][str(Spin.up)][b][i][j] for j in range(len(data["distances"][b]))],
1669                                "b-",
1670                                linewidth=band_linewidth,
1671                            )
1672
1673                            if self._bs.is_spin_polarized:
1674                                plt.plot(
1675                                    list(
1676                                        map(
1677                                            lambda x: x - shift[br],
1678                                            data["distances"][b],
1679                                        )
1680                                    ),
1681                                    [data["energy"][str(Spin.down)][b][i][j] for j in range(len(data["distances"][b]))],
1682                                    "r--",
1683                                    linewidth=band_linewidth,
1684                                )
1685                                for j in range(len(data["energy"][str(Spin.up)][b][i])):
1686                                    plt.plot(
1687                                        data["distances"][b][j] - shift[br],
1688                                        data["energy"][str(Spin.down)][b][i][j],
1689                                        "co",
1690                                        markersize=proj_br_d[br][str(Spin.down)][i][j][elt + numa][o] * 15.0,
1691                                    )
1692
1693                            for j in range(len(data["energy"][str(Spin.up)][b][i])):
1694                                plt.plot(
1695                                    data["distances"][b][j] - shift[br],
1696                                    data["energy"][str(Spin.up)][b][i][j],
1697                                    "go",
1698                                    markersize=proj_br_d[br][str(Spin.up)][i][j][elt + numa][o] * 15.0,
1699                                )
1700
1701                    if ylim is None:
1702                        if self._bs.is_metal():
1703                            if zero_to_efermi:
1704                                plt.ylim(e_min, e_max)
1705                            else:
1706                                plt.ylim(self._bs.efermi + e_min, self._bs._efermi + e_max)
1707                        else:
1708                            if vbm_cbm_marker:
1709                                for cbm in data["cbm"]:
1710                                    plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
1711
1712                                for vbm in data["vbm"]:
1713                                    plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)
1714
1715                            plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
1716                    else:
1717                        plt.ylim(ylim)
1718                    plt.title(elt + " " + numa + " " + str(o))
1719
1720        return plt
1721
1722    @classmethod
1723    def _Orbitals_SumOrbitals(cls, dictio, sum_morbs):
1724        all_orbitals = [
1725            "s",
1726            "p",
1727            "d",
1728            "f",
1729            "px",
1730            "py",
1731            "pz",
1732            "dxy",
1733            "dyz",
1734            "dxz",
1735            "dx2",
1736            "dz2",
1737            "f_3",
1738            "f_2",
1739            "f_1",
1740            "f0",
1741            "f1",
1742            "f2",
1743            "f3",
1744        ]
1745        individual_orbs = {
1746            "p": ["px", "py", "pz"],
1747            "d": ["dxy", "dyz", "dxz", "dx2", "dz2"],
1748            "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"],
1749        }
1750
1751        if not isinstance(dictio, dict):
1752            raise TypeError("The invalid type of 'dictio' was bound. It should be dict type.")
1753        if len(dictio.keys()) == 0:
1754            raise KeyError("The 'dictio' is empty. We cannot do anything.")
1755
1756        for elt in dictio:
1757            if Element.is_valid_symbol(elt):
1758                if isinstance(dictio[elt], list):
1759                    if len(dictio[elt]) == 0:
1760                        raise ValueError("The dictio[%s] is empty. We cannot do anything" % elt)
1761                    for orb in dictio[elt]:
1762                        if not isinstance(orb, str):
1763                            raise ValueError(
1764                                "The invalid format of orbitals is in 'dictio[%s]': %s. "
1765                                "They should be string." % (elt, str(orb))
1766                            )
1767                        if orb not in all_orbitals:
1768                            raise ValueError("The invalid name of orbital is given in 'dictio[%s]'." % elt)
1769                        if orb in individual_orbs.keys():
1770                            if len(set(dictio[elt]).intersection(individual_orbs[orb])) != 0:
1771                                raise ValueError("The 'dictio[%s]' contains orbitals repeated." % elt)
1772                    nelems = Counter(dictio[elt]).values()
1773                    if sum(nelems) > len(nelems):
1774                        raise ValueError("You put in at least two similar orbitals in dictio[%s]." % elt)
1775                else:
1776                    raise TypeError(
1777                        "The invalid type of value was put into 'dictio[%s]'. It should be list " "type." % elt
1778                    )
1779            else:
1780                raise KeyError("The invalid element was put into 'dictio' as a key: %s" % elt)
1781
1782        if sum_morbs is None:
1783            print("You do not want to sum projection over orbitals.")
1784        elif not isinstance(sum_morbs, dict):
1785            raise TypeError("The invalid type of 'sum_orbs' was bound. It should be dict or 'None' type.")
1786        elif len(sum_morbs.keys()) == 0:
1787            raise KeyError("The 'sum_morbs' is empty. We cannot do anything")
1788        else:
1789            for elt in sum_morbs:
1790                if Element.is_valid_symbol(elt):
1791                    if isinstance(sum_morbs[elt], list):
1792                        for orb in sum_morbs[elt]:
1793                            if not isinstance(orb, str):
1794                                raise TypeError(
1795                                    "The invalid format of orbitals is in 'sum_morbs[%s]': %s. "
1796                                    "They should be string." % (elt, str(orb))
1797                                )
1798                            if orb not in all_orbitals:
1799                                raise ValueError("The invalid name of orbital in 'sum_morbs[%s]' is given." % elt)
1800                            if orb in individual_orbs.keys():
1801                                if len(set(sum_morbs[elt]).intersection(individual_orbs[orb])) != 0:
1802                                    raise ValueError("The 'sum_morbs[%s]' contains orbitals repeated." % elt)
1803                        nelems = Counter(sum_morbs[elt]).values()
1804                        if sum(nelems) > len(nelems):
1805                            raise ValueError("You put in at least two similar orbitals in sum_morbs[%s]." % elt)
1806                    else:
1807                        raise TypeError(
1808                            "The invalid type of value was put into 'sum_morbs[%s]'. It should be list " "type." % elt
1809                        )
1810                    if elt not in dictio.keys():
1811                        raise ValueError(
1812                            "You cannot sum projection over orbitals of atoms '%s' because they are not "
1813                            "mentioned in 'dictio'." % elt
1814                        )
1815                else:
1816                    raise KeyError("The invalid element was put into 'sum_morbs' as a key: %s" % elt)
1817
1818        for elt in dictio:
1819            if len(dictio[elt]) == 1:
1820                if len(dictio[elt][0]) > 1:
1821                    if elt in sum_morbs.keys():
1822                        raise ValueError(
1823                            "You cannot sum projection over one individual orbital '%s' of '%s'."
1824                            % (dictio[elt][0], elt)
1825                        )
1826                else:
1827                    if sum_morbs is None:
1828                        pass
1829                    elif elt not in sum_morbs.keys():
1830                        print("You do not want to sum projection over orbitals of element: %s" % elt)
1831                    else:
1832                        if len(sum_morbs[elt]) == 0:
1833                            raise ValueError("The empty list is an invalid value for sum_morbs[%s]." % elt)
1834                        if len(sum_morbs[elt]) > 1:
1835                            for orb in sum_morbs[elt]:
1836                                if dictio[elt][0] not in orb:
1837                                    raise ValueError(
1838                                        "The invalid orbital '%s' was put into 'sum_morbs[%s]'." % (orb, elt)
1839                                    )
1840                        else:
1841                            if orb == "s" or len(orb) > 1:
1842                                raise ValueError("The invalid orbital '%s' was put into sum_orbs['%s']." % (orb, elt))
1843                            sum_morbs[elt] = individual_orbs[dictio[elt][0]]
1844                            dictio[elt] = individual_orbs[dictio[elt][0]]
1845            else:
1846                duplicate = copy.deepcopy(dictio[elt])
1847                for orb in dictio[elt]:
1848                    if orb in individual_orbs.keys():
1849                        duplicate.remove(orb)
1850                        for o in individual_orbs[orb]:
1851                            duplicate.append(o)
1852                dictio[elt] = copy.deepcopy(duplicate)
1853
1854                if sum_morbs is None:
1855                    pass
1856                elif elt not in sum_morbs.keys():
1857                    print("You do not want to sum projection over orbitals of element: %s" % elt)
1858                else:
1859                    if len(sum_morbs[elt]) == 0:
1860                        raise ValueError("The empty list is an invalid value for sum_morbs[%s]." % elt)
1861                    if len(sum_morbs[elt]) == 1:
1862                        orb = sum_morbs[elt][0]
1863                        if orb == "s":
1864                            raise ValueError(
1865                                "We do not sum projection over only 's' orbital of the same " "type of element."
1866                            )
1867                        if orb in individual_orbs.keys():
1868                            sum_morbs[elt].pop(0)
1869                            for o in individual_orbs[orb]:
1870                                sum_morbs[elt].append(o)
1871                        else:
1872                            raise ValueError("You never sum projection over one orbital in sum_morbs[%s]" % elt)
1873                    else:
1874                        duplicate = copy.deepcopy(sum_morbs[elt])
1875                        for orb in sum_morbs[elt]:
1876                            if orb in individual_orbs.keys():
1877                                duplicate.remove(orb)
1878                                for o in individual_orbs[orb]:
1879                                    duplicate.append(o)
1880                        sum_morbs[elt] = copy.deepcopy(duplicate)
1881
1882                    for orb in sum_morbs[elt]:
1883                        if orb not in dictio[elt]:
1884                            raise ValueError(
1885                                "The orbitals of sum_morbs[%s] conflict with those of dictio[%s]." % (elt, elt)
1886                            )
1887
1888        return dictio, sum_morbs
1889
1890    def _number_of_subfigures(self, dictio, dictpa, sum_atoms, sum_morbs):
1891        from collections import Counter
1892
1893        from pymatgen.core.periodic_table import Element
1894
1895        if not isinstance(dictpa, dict):
1896            raise TypeError("The invalid type of 'dictpa' was bound. It should be dict type.")
1897        if len(dictpa.keys()) == 0:
1898            raise KeyError("The 'dictpa' is empty. We cannot do anything.")
1899        for elt in dictpa:
1900            if Element.is_valid_symbol(elt):
1901                if isinstance(dictpa[elt], list):
1902                    if len(dictpa[elt]) == 0:
1903                        raise ValueError("The dictpa[%s] is empty. We cannot do anything" % elt)
1904                    _sites = self._bs.structure.sites
1905                    indices = []
1906                    for i in range(0, len(_sites)):  # pylint: disable=C0200
1907                        if list(_sites[i]._species.keys())[0].__eq__(Element(elt)):
1908                            indices.append(i + 1)
1909                    for number in dictpa[elt]:
1910                        if isinstance(number, str):
1911                            if number.lower() == "all":
1912                                dictpa[elt] = indices
1913                                print("You want to consider all '%s' atoms." % elt)
1914                                break
1915
1916                            raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number)))
1917                        if isinstance(number, int):
1918                            if number not in indices:
1919                                raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number)))
1920                        else:
1921                            raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number)))
1922                    nelems = Counter(dictpa[elt]).values()
1923                    if sum(nelems) > len(nelems):
1924                        raise ValueError("You put at least two similar site numbers into 'dictpa[%s]'." % elt)
1925                else:
1926                    raise TypeError(
1927                        "The invalid type of value was put into 'dictpa[%s]'. It should be list " "type." % elt
1928                    )
1929            else:
1930                raise KeyError("The invalid element was put into 'dictpa' as a key: %s" % elt)
1931
1932        if len(list(dictio.keys())) != len(list(dictpa.keys())):
1933            raise KeyError("The number of keys in 'dictio' and 'dictpa' are not the same.")
1934        for elt in dictio.keys():
1935            if elt not in dictpa.keys():
1936                raise KeyError("The element '%s' is not in both dictpa and dictio." % elt)
1937        for elt in dictpa.keys():
1938            if elt not in dictio.keys():
1939                raise KeyError("The element '%s' in not in both dictpa and dictio." % elt)
1940
1941        if sum_atoms is None:
1942            print("You do not want to sum projection over atoms.")
1943        elif not isinstance(sum_atoms, dict):
1944            raise TypeError("The invalid type of 'sum_atoms' was bound. It should be dict type.")
1945        elif len(sum_atoms.keys()) == 0:
1946            raise KeyError("The 'sum_atoms' is empty. We cannot do anything.")
1947        else:
1948            for elt in sum_atoms:
1949                if Element.is_valid_symbol(elt):
1950                    if isinstance(sum_atoms[elt], list):
1951                        if len(sum_atoms[elt]) == 0:
1952                            raise ValueError("The sum_atoms[%s] is empty. We cannot do anything" % elt)
1953                        _sites = self._bs.structure.sites
1954                        indices = []
1955                        for i in range(0, len(_sites)):  # pylint: disable=C0200
1956                            if list(_sites[i]._species.keys())[0].__eq__(Element(elt)):
1957                                indices.append(i + 1)
1958                        for number in sum_atoms[elt]:
1959                            if isinstance(number, str):
1960                                if number.lower() == "all":
1961                                    sum_atoms[elt] = indices
1962                                    print("You want to sum projection over all '%s' atoms." % elt)
1963                                    break
1964                                raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt)
1965                            if isinstance(number, int):
1966                                if number not in indices:
1967                                    raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt)
1968                                if number not in dictpa[elt]:
1969                                    raise ValueError(
1970                                        "You cannot sum projection with atom number '%s' because it is not "
1971                                        "metioned in dicpta[%s]" % (str(number), elt)
1972                                    )
1973                            else:
1974                                raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt)
1975                        nelems = Counter(sum_atoms[elt]).values()
1976                        if sum(nelems) > len(nelems):
1977                            raise ValueError("You put at least two similar site numbers into 'sum_atoms[%s]'." % elt)
1978                    else:
1979                        raise TypeError(
1980                            "The invalid type of value was put into 'sum_atoms[%s]'. It should be list " "type." % elt
1981                        )
1982                    if elt not in dictpa.keys():
1983                        raise ValueError(
1984                            "You cannot sum projection over atoms '%s' because it is not "
1985                            "mentioned in 'dictio'." % elt
1986                        )
1987                else:
1988                    raise KeyError("The invalid element was put into 'sum_atoms' as a key: %s" % elt)
1989                if len(sum_atoms[elt]) == 1:
1990                    raise ValueError("We do not sum projection over only one atom: %s" % elt)
1991
1992        max_number_figs = 0
1993        decrease = 0
1994        for elt in dictio:
1995            max_number_figs += len(dictio[elt]) * len(dictpa[elt])
1996
1997        if (sum_atoms is None) and (sum_morbs is None):
1998            number_figs = max_number_figs
1999        elif (sum_atoms is not None) and (sum_morbs is None):
2000            for elt in sum_atoms:
2001                decrease += (len(sum_atoms[elt]) - 1) * len(dictio[elt])
2002            number_figs = max_number_figs - decrease
2003        elif (sum_atoms is None) and (sum_morbs is not None):
2004            for elt in sum_morbs:
2005                decrease += (len(sum_morbs[elt]) - 1) * len(dictpa[elt])
2006            number_figs = max_number_figs - decrease
2007        elif (sum_atoms is not None) and (sum_morbs is not None):
2008            for elt in sum_atoms:
2009                decrease += (len(sum_atoms[elt]) - 1) * len(dictio[elt])
2010            for elt in sum_morbs:
2011                if elt in sum_atoms:
2012                    decrease += (len(sum_morbs[elt]) - 1) * (len(dictpa[elt]) - len(sum_atoms[elt]) + 1)
2013                else:
2014                    decrease += (len(sum_morbs[elt]) - 1) * len(dictpa[elt])
2015            number_figs = max_number_figs - decrease
2016        else:
2017            raise ValueError("Invalid format of 'sum_atoms' and 'sum_morbs'.")
2018
2019        return dictpa, sum_atoms, number_figs
2020
2021    def _summarize_keys_for_plot(self, dictio, dictpa, sum_atoms, sum_morbs):
2022        from pymatgen.core.periodic_table import Element
2023
2024        individual_orbs = {
2025            "p": ["px", "py", "pz"],
2026            "d": ["dxy", "dyz", "dxz", "dx2", "dz2"],
2027            "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"],
2028        }
2029
2030        def number_label(list_numbers):
2031            list_numbers = sorted(list_numbers)
2032            divide = [[]]
2033            divide[0].append(list_numbers[0])
2034            group = 0
2035            for i in range(1, len(list_numbers)):
2036                if list_numbers[i] == list_numbers[i - 1] + 1:
2037                    divide[group].append(list_numbers[i])
2038                else:
2039                    group += 1
2040                    divide.append([list_numbers[i]])
2041            label = ""
2042            for elem in divide:
2043                if len(elem) > 1:
2044                    label += str(elem[0]) + "-" + str(elem[-1]) + ","
2045                else:
2046                    label += str(elem[0]) + ","
2047            return label[:-1]
2048
2049        def orbital_label(list_orbitals):
2050            divide = {}
2051            for orb in list_orbitals:
2052                if orb[0] in divide:
2053                    divide[orb[0]].append(orb)
2054                else:
2055                    divide[orb[0]] = []
2056                    divide[orb[0]].append(orb)
2057            label = ""
2058            for elem, v in divide.items():
2059                if elem == "s":
2060                    label += "s" + ","
2061                else:
2062                    if len(v) == len(individual_orbs[elem]):
2063                        label += elem + ","
2064                    else:
2065                        l = [o[1:] for o in v]
2066                        label += elem + str(l).replace("['", "").replace("']", "").replace("', '", "-") + ","
2067            return label[:-1]
2068
2069        if (sum_atoms is None) and (sum_morbs is None):
2070            dictio_d = dictio
2071            dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa}
2072
2073        elif (sum_atoms is not None) and (sum_morbs is None):
2074            dictio_d = dictio
2075            dictpa_d = {}
2076            for elt in dictpa:
2077                dictpa_d[elt] = []
2078                if elt in sum_atoms:
2079                    _sites = self._bs.structure.sites
2080                    indices = []
2081                    for i in range(0, len(_sites)):  # pylint: disable=C0200
2082                        if list(_sites[i]._species.keys())[0].__eq__(Element(elt)):
2083                            indices.append(i + 1)
2084                    flag_1 = len(set(dictpa[elt]).intersection(indices))
2085                    flag_2 = len(set(sum_atoms[elt]).intersection(indices))
2086                    if flag_1 == len(indices) and flag_2 == len(indices):
2087                        dictpa_d[elt].append("all")
2088                    else:
2089                        for anum in dictpa[elt]:
2090                            if anum not in sum_atoms[elt]:
2091                                dictpa_d[elt].append(str(anum))
2092                        label = number_label(sum_atoms[elt])
2093                        dictpa_d[elt].append(label)
2094                else:
2095                    for anum in dictpa[elt]:
2096                        dictpa_d[elt].append(str(anum))
2097
2098        elif (sum_atoms is None) and (sum_morbs is not None):
2099            dictio_d = {}
2100            for elt in dictio:
2101                dictio_d[elt] = []
2102                if elt in sum_morbs:
2103                    for morb in dictio[elt]:
2104                        if morb not in sum_morbs[elt]:
2105                            dictio_d[elt].append(morb)
2106                    label = orbital_label(sum_morbs[elt])
2107                    dictio_d[elt].append(label)
2108                else:
2109                    dictio_d[elt] = dictio[elt]
2110            dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa}
2111
2112        else:
2113            dictio_d = {}
2114            for elt in dictio:
2115                dictio_d[elt] = []
2116                if elt in sum_morbs:
2117                    for morb in dictio[elt]:
2118                        if morb not in sum_morbs[elt]:
2119                            dictio_d[elt].append(morb)
2120                    label = orbital_label(sum_morbs[elt])
2121                    dictio_d[elt].append(label)
2122                else:
2123                    dictio_d[elt] = dictio[elt]
2124            dictpa_d = {}
2125            for elt in dictpa:
2126                dictpa_d[elt] = []
2127                if elt in sum_atoms:
2128                    _sites = self._bs.structure.sites
2129                    indices = []
2130                    for i in range(0, len(_sites)):  # pylint: disable=C0200
2131                        if list(_sites[i]._species.keys())[0].__eq__(Element(elt)):
2132                            indices.append(i + 1)
2133                    flag_1 = len(set(dictpa[elt]).intersection(indices))
2134                    flag_2 = len(set(sum_atoms[elt]).intersection(indices))
2135                    if flag_1 == len(indices) and flag_2 == len(indices):
2136                        dictpa_d[elt].append("all")
2137                    else:
2138                        for anum in dictpa[elt]:
2139                            if anum not in sum_atoms[elt]:
2140                                dictpa_d[elt].append(str(anum))
2141                        label = number_label(sum_atoms[elt])
2142                        dictpa_d[elt].append(label)
2143                else:
2144                    for anum in dictpa[elt]:
2145                        dictpa_d[elt].append(str(anum))
2146
2147        return dictio_d, dictpa_d
2148
2149    def _maketicks_selected(self, plt, branches):
2150        """
2151        utility private method to add ticks to a band structure with selected branches
2152        """
2153        ticks = self.get_ticks()
2154        distance = []
2155        label = []
2156        rm_elems = []
2157        for i in range(1, len(ticks["distance"])):
2158            if ticks["label"][i] == ticks["label"][i - 1]:
2159                rm_elems.append(i)
2160        for i in range(len(ticks["distance"])):
2161            if i not in rm_elems:
2162                distance.append(ticks["distance"][i])
2163                label.append(ticks["label"][i])
2164        l_branches = [distance[i] - distance[i - 1] for i in range(1, len(distance))]
2165        n_distance = []
2166        n_label = []
2167        for branch in branches:
2168            n_distance.append(l_branches[branch])
2169            if ("$\\mid$" not in label[branch]) and ("$\\mid$" not in label[branch + 1]):
2170                n_label.append([label[branch], label[branch + 1]])
2171            elif ("$\\mid$" in label[branch]) and ("$\\mid$" not in label[branch + 1]):
2172                n_label.append([label[branch].split("$")[-1], label[branch + 1]])
2173            elif ("$\\mid$" not in label[branch]) and ("$\\mid$" in label[branch + 1]):
2174                n_label.append([label[branch], label[branch + 1].split("$")[0]])
2175            else:
2176                n_label.append([label[branch].split("$")[-1], label[branch + 1].split("$")[0]])
2177
2178        f_distance = []
2179        rf_distance = []
2180        f_label = []
2181        f_label.append(n_label[0][0])
2182        f_label.append(n_label[0][1])
2183        f_distance.append(0.0)
2184        f_distance.append(n_distance[0])
2185        rf_distance.append(0.0)
2186        rf_distance.append(n_distance[0])
2187        length = n_distance[0]
2188        for i in range(1, len(n_distance)):
2189            if n_label[i][0] == n_label[i - 1][1]:
2190                f_distance.append(length)
2191                f_distance.append(length + n_distance[i])
2192                f_label.append(n_label[i][0])
2193                f_label.append(n_label[i][1])
2194            else:
2195                f_distance.append(length + n_distance[i])
2196                f_label[-1] = n_label[i - 1][1] + "$\\mid$" + n_label[i][0]
2197                f_label.append(n_label[i][1])
2198            rf_distance.append(length + n_distance[i])
2199            length += n_distance[i]
2200
2201        n_ticks = {"distance": f_distance, "label": f_label}
2202        uniq_d = []
2203        uniq_l = []
2204        temp_ticks = list(zip(n_ticks["distance"], n_ticks["label"]))
2205        for i, t in enumerate(temp_ticks):
2206            if i == 0:
2207                uniq_d.append(t[0])
2208                uniq_l.append(t[1])
2209                logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1]))
2210            else:
2211                if t[1] == temp_ticks[i - 1][1]:
2212                    logger.debug("Skipping label {i}".format(i=t[1]))
2213                else:
2214                    logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1]))
2215                    uniq_d.append(t[0])
2216                    uniq_l.append(t[1])
2217
2218        logger.debug("Unique labels are %s" % list(zip(uniq_d, uniq_l)))
2219        plt.gca().set_xticks(uniq_d)
2220        plt.gca().set_xticklabels(uniq_l)
2221
2222        for i in range(len(n_ticks["label"])):
2223            if n_ticks["label"][i] is not None:
2224                # don't print the same label twice
2225                if i != 0:
2226                    if n_ticks["label"][i] == n_ticks["label"][i - 1]:
2227                        logger.debug("already print label... " "skipping label {i}".format(i=n_ticks["label"][i]))
2228                    else:
2229                        logger.debug(
2230                            "Adding a line at {d}"
2231                            " for label {l}".format(d=n_ticks["distance"][i], l=n_ticks["label"][i])
2232                        )
2233                        plt.axvline(n_ticks["distance"][i], color="k")
2234                else:
2235                    logger.debug(
2236                        "Adding a line at {d} for label {l}".format(d=n_ticks["distance"][i], l=n_ticks["label"][i])
2237                    )
2238                    plt.axvline(n_ticks["distance"][i], color="k")
2239
2240        shift = []
2241        br = -1
2242        for branch in branches:
2243            br += 1
2244            shift.append(distance[branch] - rf_distance[br])
2245
2246        return plt, shift
2247
2248
2249class BSDOSPlotter:
2250    """
2251    A joint, aligned band structure and density of states plot. Contributions
2252    from Jan Pohls as well as the online example from Germain Salvato-Vallverdu:
2253    http://gvallver.perso.univ-pau.fr/?p=587
2254    """
2255
2256    def __init__(
2257        self,
2258        bs_projection="elements",
2259        dos_projection="elements",
2260        vb_energy_range=4,
2261        cb_energy_range=4,
2262        fixed_cb_energy=False,
2263        egrid_interval=1,
2264        font="Times New Roman",
2265        axis_fontsize=20,
2266        tick_fontsize=15,
2267        legend_fontsize=14,
2268        bs_legend="best",
2269        dos_legend="best",
2270        rgb_legend=True,
2271        fig_size=(11, 8.5),
2272    ):
2273        """
2274        Instantiate plotter settings.
2275
2276        Args:
2277            bs_projection (str): "elements" or None
2278            dos_projection (str): "elements", "orbitals", or None
2279            vb_energy_range (float): energy in eV to show of valence bands
2280            cb_energy_range (float): energy in eV to show of conduction bands
2281            fixed_cb_energy (bool): If true, the cb_energy_range will be interpreted
2282                as constant (i.e., no gap correction for cb energy)
2283            egrid_interval (float): interval for grid marks
2284            font (str): font family
2285            axis_fontsize (float): font size for axis
2286            tick_fontsize (float): font size for axis tick labels
2287            legend_fontsize (float): font size for legends
2288            bs_legend (str): matplotlib string location for legend or None
2289            dos_legend (str): matplotlib string location for legend or None
2290            rgb_legend (bool): (T/F) whether to draw RGB triangle/bar for element proj.
2291            fig_size(tuple): dimensions of figure size (width, height)
2292        """
2293        self.bs_projection = bs_projection
2294        self.dos_projection = dos_projection
2295        self.vb_energy_range = vb_energy_range
2296        self.cb_energy_range = cb_energy_range
2297        self.fixed_cb_energy = fixed_cb_energy
2298        self.egrid_interval = egrid_interval
2299        self.font = font
2300        self.axis_fontsize = axis_fontsize
2301        self.tick_fontsize = tick_fontsize
2302        self.legend_fontsize = legend_fontsize
2303        self.bs_legend = bs_legend
2304        self.dos_legend = dos_legend
2305        self.rgb_legend = rgb_legend
2306        self.fig_size = fig_size
2307
2308    def get_plot(self, bs, dos=None):
2309        """
2310        Get a matplotlib plot object.
2311        Args:
2312            bs (BandStructureSymmLine): the bandstructure to plot. Projection
2313                data must exist for projected plots.
2314            dos (Dos): the Dos to plot. Projection data must exist (i.e.,
2315                CompleteDos) for projected plots.
2316
2317        Returns:
2318            matplotlib.pyplot object on which you can call commands like show()
2319            and savefig()
2320        """
2321        import matplotlib.lines as mlines
2322        import matplotlib.pyplot as mplt
2323        from matplotlib.gridspec import GridSpec
2324
2325        # make sure the user-specified band structure projection is valid
2326        bs_projection = self.bs_projection
2327        if dos:
2328            elements = [e.symbol for e in dos.structure.composition.elements]
2329        elif bs_projection and bs.structure:
2330            elements = [e.symbol for e in bs.structure.composition.elements]
2331        else:
2332            elements = []
2333
2334        rgb_legend = (
2335            self.rgb_legend and bs_projection and bs_projection.lower() == "elements" and len(elements) in [2, 3]
2336        )
2337
2338        if (
2339            bs_projection
2340            and bs_projection.lower() == "elements"
2341            and (len(elements) not in [2, 3] or not bs.get_projection_on_elements())
2342        ):
2343            warnings.warn(
2344                "Cannot get element projected data; either the projection data "
2345                "doesn't exist, or you don't have a compound with exactly 2 "
2346                "or 3 unique elements."
2347            )
2348            bs_projection = None
2349
2350        # specify energy range of plot
2351        emin = -self.vb_energy_range
2352        emax = self.cb_energy_range if self.fixed_cb_energy else self.cb_energy_range + bs.get_band_gap()["energy"]
2353
2354        # initialize all the k-point labels and k-point x-distances for bs plot
2355        xlabels = []  # all symmetry point labels on x-axis
2356        xlabel_distances = []  # positions of symmetry point x-labels
2357
2358        x_distances_list = []
2359        prev_right_klabel = None  # used to determine which branches require a midline separator
2360
2361        for idx, l in enumerate(bs.branches):
2362            x_distances = []
2363
2364            # get left and right kpoint labels of this branch
2365            left_k, right_k = l["name"].split("-")
2366
2367            # add $ notation for LaTeX kpoint labels
2368            if left_k[0] == "\\" or "_" in left_k:
2369                left_k = "$" + left_k + "$"
2370            if right_k[0] == "\\" or "_" in right_k:
2371                right_k = "$" + right_k + "$"
2372
2373            # add left k label to list of labels
2374            if prev_right_klabel is None:
2375                xlabels.append(left_k)
2376                xlabel_distances.append(0)
2377            elif prev_right_klabel != left_k:  # used for pipe separator
2378                xlabels[-1] = xlabels[-1] + "$\\mid$ " + left_k
2379
2380            # add right k label to list of labels
2381            xlabels.append(right_k)
2382            prev_right_klabel = right_k
2383
2384            # add x-coordinates for labels
2385            left_kpoint = bs.kpoints[l["start_index"]].cart_coords
2386            right_kpoint = bs.kpoints[l["end_index"]].cart_coords
2387            distance = np.linalg.norm(right_kpoint - left_kpoint)
2388            xlabel_distances.append(xlabel_distances[-1] + distance)
2389
2390            # add x-coordinates for kpoint data
2391            npts = l["end_index"] - l["start_index"]
2392            distance_interval = distance / npts
2393            x_distances.append(xlabel_distances[-2])
2394            for i in range(npts):
2395                x_distances.append(x_distances[-1] + distance_interval)
2396            x_distances_list.append(x_distances)
2397
2398        # set up bs and dos plot
2399        gs = GridSpec(1, 2, width_ratios=[2, 1]) if dos else GridSpec(1, 1)
2400
2401        fig = mplt.figure(figsize=self.fig_size)
2402        fig.patch.set_facecolor("white")
2403        bs_ax = mplt.subplot(gs[0])
2404        if dos:
2405            dos_ax = mplt.subplot(gs[1])
2406
2407        # set basic axes limits for the plot
2408        bs_ax.set_xlim(0, x_distances_list[-1][-1])
2409        bs_ax.set_ylim(emin, emax)
2410        if dos:
2411            dos_ax.set_ylim(emin, emax)
2412
2413        # add BS xticks, labels, etc.
2414        bs_ax.set_xticks(xlabel_distances)
2415        bs_ax.set_xticklabels(xlabels, size=self.tick_fontsize)
2416        bs_ax.set_xlabel("Wavevector $k$", fontsize=self.axis_fontsize, family=self.font)
2417        bs_ax.set_ylabel("$E-E_F$ / eV", fontsize=self.axis_fontsize, family=self.font)
2418
2419        # add BS fermi level line at E=0 and gridlines
2420        bs_ax.hlines(y=0, xmin=0, xmax=x_distances_list[-1][-1], color="k", lw=2)
2421        bs_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval))
2422        bs_ax.set_yticklabels(np.arange(emin, emax + 1e-5, self.egrid_interval), size=self.tick_fontsize)
2423        bs_ax.set_axisbelow(True)
2424        bs_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1)
2425        if dos:
2426            dos_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval))
2427            dos_ax.set_yticklabels([])
2428            dos_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1)
2429
2430        # renormalize the band energy to the Fermi level
2431        band_energies = {}
2432        for spin in (Spin.up, Spin.down):
2433            if spin in bs.bands:
2434                band_energies[spin] = []
2435                for band in bs.bands[spin]:
2436                    band_energies[spin].append([e - bs.efermi for e in band])
2437
2438        # renormalize the DOS energies to Fermi level
2439        if dos:
2440            dos_energies = [e - dos.efermi for e in dos.energies]
2441
2442        # get the projection data to set colors for the band structure
2443        colordata = self._get_colordata(bs, elements, bs_projection)
2444
2445        # plot the colored band structure lines
2446        for spin in (Spin.up, Spin.down):
2447            if spin in band_energies:
2448                linestyles = "solid" if spin == Spin.up else "dotted"
2449                for band_idx, band in enumerate(band_energies[spin]):
2450                    current_pos = 0
2451                    for x_distances in x_distances_list:
2452                        sub_band = band[current_pos : current_pos + len(x_distances)]
2453
2454                        self._rgbline(
2455                            bs_ax,
2456                            x_distances,
2457                            sub_band,
2458                            colordata[spin][band_idx, :, 0][current_pos : current_pos + len(x_distances)],
2459                            colordata[spin][band_idx, :, 1][current_pos : current_pos + len(x_distances)],
2460                            colordata[spin][band_idx, :, 2][current_pos : current_pos + len(x_distances)],
2461                            linestyles=linestyles,
2462                        )
2463
2464                        current_pos += len(x_distances)
2465
2466        if dos:
2467            # Plot the DOS and projected DOS
2468            for spin in (Spin.up, Spin.down):
2469                if spin in dos.densities:
2470                    # plot the total DOS
2471                    dos_densities = dos.densities[spin] * int(spin)
2472                    label = "total" if spin == Spin.up else None
2473                    dos_ax.plot(dos_densities, dos_energies, color=(0.6, 0.6, 0.6), label=label)
2474                    dos_ax.fill_betweenx(
2475                        dos_energies,
2476                        0,
2477                        dos_densities,
2478                        color=(0.7, 0.7, 0.7),
2479                        facecolor=(0.7, 0.7, 0.7),
2480                    )
2481
2482                    if self.dos_projection is None:
2483                        pass
2484
2485                    elif self.dos_projection.lower() == "elements":
2486                        # plot the atom-projected DOS
2487                        colors = ["b", "r", "g", "m", "y", "c", "k", "w"]
2488                        el_dos = dos.get_element_dos()
2489                        for idx, el in enumerate(elements):
2490                            dos_densities = el_dos[Element(el)].densities[spin] * int(spin)
2491                            label = el if spin == Spin.up else None
2492                            dos_ax.plot(
2493                                dos_densities,
2494                                dos_energies,
2495                                color=colors[idx],
2496                                label=label,
2497                            )
2498
2499                    elif self.dos_projection.lower() == "orbitals":
2500                        # plot each of the atomic projected DOS
2501                        colors = ["b", "r", "g", "m"]
2502                        spd_dos = dos.get_spd_dos()
2503                        for idx, orb in enumerate([OrbitalType.s, OrbitalType.p, OrbitalType.d, OrbitalType.f]):
2504                            if orb in spd_dos:
2505                                dos_densities = spd_dos[orb].densities[spin] * int(spin)
2506                                label = orb if spin == Spin.up else None
2507                                dos_ax.plot(
2508                                    dos_densities,
2509                                    dos_energies,
2510                                    color=colors[idx],
2511                                    label=label,
2512                                )
2513
2514            # get index of lowest and highest energy being plotted, used to help auto-scale DOS x-axis
2515            emin_idx = next(x[0] for x in enumerate(dos_energies) if x[1] >= emin)
2516            emax_idx = len(dos_energies) - next(x[0] for x in enumerate(reversed(dos_energies)) if x[1] <= emax)
2517
2518            # determine DOS x-axis range
2519            dos_xmin = (
2520                0 if Spin.down not in dos.densities else -max(dos.densities[Spin.down][emin_idx : emax_idx + 1] * 1.05)
2521            )
2522            dos_xmax = max([max(dos.densities[Spin.up][emin_idx:emax_idx]) * 1.05, abs(dos_xmin)])
2523
2524            # set up the DOS x-axis and add Fermi level line
2525            dos_ax.set_xlim(dos_xmin, dos_xmax)
2526            dos_ax.set_xticklabels([])
2527            dos_ax.hlines(y=0, xmin=dos_xmin, xmax=dos_xmax, color="k", lw=2)
2528            dos_ax.set_xlabel("DOS", fontsize=self.axis_fontsize, family=self.font)
2529
2530        # add legend for band structure
2531        if self.bs_legend and not rgb_legend:
2532            handles = []
2533
2534            if bs_projection is None:
2535                handles = [
2536                    mlines.Line2D([], [], linewidth=2, color="k", label="spin up"),
2537                    mlines.Line2D(
2538                        [],
2539                        [],
2540                        linewidth=2,
2541                        color="b",
2542                        linestyle="dotted",
2543                        label="spin down",
2544                    ),
2545                ]
2546
2547            elif bs_projection.lower() == "elements":
2548                colors = ["b", "r", "g"]
2549                for idx, el in enumerate(elements):
2550                    handles.append(mlines.Line2D([], [], linewidth=2, color=colors[idx], label=el))
2551
2552            bs_ax.legend(
2553                handles=handles,
2554                fancybox=True,
2555                prop={"size": self.legend_fontsize, "family": self.font},
2556                loc=self.bs_legend,
2557            )
2558
2559        elif self.bs_legend and rgb_legend:
2560            if len(elements) == 2:
2561                self._rb_line(bs_ax, elements[1], elements[0], loc=self.bs_legend)
2562            elif len(elements) == 3:
2563                self._rgb_triangle(bs_ax, elements[1], elements[2], elements[0], loc=self.bs_legend)
2564
2565        # add legend for DOS
2566        if dos and self.dos_legend:
2567            dos_ax.legend(
2568                fancybox=True,
2569                prop={"size": self.legend_fontsize, "family": self.font},
2570                loc=self.dos_legend,
2571            )
2572
2573        mplt.subplots_adjust(wspace=0.1)
2574        return mplt
2575
2576    @staticmethod
2577    def _rgbline(ax, k, e, red, green, blue, alpha=1, linestyles="solid"):
2578        """
2579        An RGB colored line for plotting.
2580        creation of segments based on:
2581        http://nbviewer.ipython.org/urls/raw.github.com/dpsanders/matplotlib-examples/master/colorline.ipynb
2582        Args:
2583            ax: matplotlib axis
2584            k: x-axis data (k-points)
2585            e: y-axis data (energies)
2586            red: red data
2587            green: green data
2588            blue: blue data
2589            alpha: alpha values data
2590            linestyles: linestyle for plot (e.g., "solid" or "dotted")
2591        """
2592        from matplotlib.collections import LineCollection
2593
2594        pts = np.array([k, e]).T.reshape(-1, 1, 2)
2595        seg = np.concatenate([pts[:-1], pts[1:]], axis=1)
2596
2597        nseg = len(k) - 1
2598        r = [0.5 * (red[i] + red[i + 1]) for i in range(nseg)]
2599        g = [0.5 * (green[i] + green[i + 1]) for i in range(nseg)]
2600        b = [0.5 * (blue[i] + blue[i + 1]) for i in range(nseg)]
2601        a = np.ones(nseg, np.float_) * alpha
2602        lc = LineCollection(seg, colors=list(zip(r, g, b, a)), linewidth=2, linestyles=linestyles)
2603        ax.add_collection(lc)
2604
2605    @staticmethod
2606    def _get_colordata(bs, elements, bs_projection):
2607        """
2608        Get color data, including projected band structures
2609        Args:
2610            bs: Bandstructure object
2611            elements: elements (in desired order) for setting to blue, red, green
2612            bs_projection: None for no projection, "elements" for element projection
2613
2614        Returns:
2615
2616        """
2617        contribs = {}
2618        if bs_projection and bs_projection.lower() == "elements":
2619            projections = bs.get_projection_on_elements()
2620
2621        for spin in (Spin.up, Spin.down):
2622            if spin in bs.bands:
2623                contribs[spin] = []
2624                for band_idx in range(bs.nb_bands):
2625                    colors = []
2626                    for k_idx in range(len(bs.kpoints)):
2627                        if bs_projection and bs_projection.lower() == "elements":
2628                            c = [0, 0, 0]
2629                            projs = projections[spin][band_idx][k_idx]
2630                            # note: squared color interpolations are smoother
2631                            # see: https://youtu.be/LKnqECcg6Gw
2632                            projs = {k: v ** 2 for k, v in projs.items()}
2633                            total = sum(projs.values())
2634                            if total > 0:
2635                                for idx, e in enumerate(elements):
2636                                    c[idx] = math.sqrt(projs[e] / total)  # min is to handle round errors
2637
2638                            c = [c[1], c[2], c[0]]  # prefer blue, then red, then green
2639
2640                        else:
2641                            c = [0, 0, 0] if spin == Spin.up else [0, 0, 1]  # black for spin up, blue for spin down
2642
2643                        colors.append(c)
2644
2645                    contribs[spin].append(colors)
2646                contribs[spin] = np.array(contribs[spin])
2647
2648        return contribs
2649
2650    @staticmethod
2651    def _rgb_triangle(ax, r_label, g_label, b_label, loc):
2652        """
2653        Draw an RGB triangle legend on the desired axis
2654        """
2655        if loc not in range(1, 11):
2656            loc = 2
2657
2658        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
2659
2660        inset_ax = inset_axes(ax, width=1, height=1, loc=loc)
2661        mesh = 35
2662        x = []
2663        y = []
2664        color = []
2665        for r in range(0, mesh):
2666            for g in range(0, mesh):
2667                for b in range(0, mesh):
2668                    if not (r == 0 and b == 0 and g == 0):
2669                        r1 = r / (r + g + b)
2670                        g1 = g / (r + g + b)
2671                        b1 = b / (r + g + b)
2672                        x.append(0.33 * (2.0 * g1 + r1) / (r1 + b1 + g1))
2673                        y.append(0.33 * np.sqrt(3) * r1 / (r1 + b1 + g1))
2674                        rc = math.sqrt(r ** 2 / (r ** 2 + g ** 2 + b ** 2))
2675                        gc = math.sqrt(g ** 2 / (r ** 2 + g ** 2 + b ** 2))
2676                        bc = math.sqrt(b ** 2 / (r ** 2 + g ** 2 + b ** 2))
2677                        color.append([rc, gc, bc])
2678
2679        # x = [n + 0.25 for n in x]  # nudge x coordinates
2680        # y = [n + (max_y - 1) for n in y]  # shift y coordinates to top
2681        # plot the triangle
2682        inset_ax.scatter(x, y, s=7, marker=".", edgecolor=color)  # pylint: disable=E1101
2683        inset_ax.set_xlim([-0.35, 1.00])  # pylint: disable=E1101
2684        inset_ax.set_ylim([-0.35, 1.00])  # pylint: disable=E1101
2685
2686        # add the labels
2687        inset_ax.text(  # pylint: disable=E1101
2688            0.70,
2689            -0.2,
2690            g_label,
2691            fontsize=13,
2692            family="Times New Roman",
2693            color=(0, 0, 0),
2694            horizontalalignment="left",
2695        )
2696        inset_ax.text(  # pylint: disable=E1101
2697            0.325,
2698            0.70,
2699            r_label,
2700            fontsize=13,
2701            family="Times New Roman",
2702            color=(0, 0, 0),
2703            horizontalalignment="center",
2704        )
2705        inset_ax.text(  # pylint: disable=E1101
2706            -0.05,
2707            -0.2,
2708            b_label,
2709            fontsize=13,
2710            family="Times New Roman",
2711            color=(0, 0, 0),
2712            horizontalalignment="right",
2713        )
2714
2715        inset_ax.get_xaxis().set_visible(False)  # pylint: disable=E1101
2716        inset_ax.get_yaxis().set_visible(False)  # pylint: disable=E1101
2717
2718    @staticmethod
2719    def _rb_line(ax, r_label, b_label, loc):
2720        # Draw an rb bar legend on the desired axis
2721
2722        if loc not in range(1, 11):
2723            loc = 2
2724        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
2725
2726        inset_ax = inset_axes(ax, width=1.2, height=0.4, loc=loc)
2727
2728        x = []
2729        y = []
2730        color = []
2731        for i in range(0, 1000):
2732            x.append(i / 1800.0 + 0.55)
2733            y.append(0)
2734            color.append([math.sqrt(c) for c in [1 - (i / 1000) ** 2, 0, (i / 1000) ** 2]])
2735
2736        # plot the bar
2737        # pylint: disable=E1101
2738        inset_ax.scatter(x, y, s=250.0, marker="s", c=color)
2739        inset_ax.set_xlim([-0.1, 1.7])
2740        inset_ax.text(
2741            1.35,
2742            0,
2743            b_label,
2744            fontsize=13,
2745            family="Times New Roman",
2746            color=(0, 0, 0),
2747            horizontalalignment="left",
2748            verticalalignment="center",
2749        )
2750        inset_ax.text(
2751            0.30,
2752            0,
2753            r_label,
2754            fontsize=13,
2755            family="Times New Roman",
2756            color=(0, 0, 0),
2757            horizontalalignment="right",
2758            verticalalignment="center",
2759        )
2760
2761        inset_ax.get_xaxis().set_visible(False)
2762        inset_ax.get_yaxis().set_visible(False)
2763
2764
2765class BoltztrapPlotter:
2766    # TODO: We need a unittest for this. Come on folks.
2767    """
2768    class containing methods to plot the data from Boltztrap.
2769    """
2770
2771    def __init__(self, bz):
2772        """
2773        Args:
2774            bz: a BoltztrapAnalyzer object
2775        """
2776        self._bz = bz
2777
2778    def _plot_doping(self, plt, temp):
2779        if len(self._bz.doping) != 0:
2780            limit = 2.21e15
2781            plt.axvline(self._bz.mu_doping["n"][temp][0], linewidth=3.0, linestyle="--")
2782            plt.text(
2783                self._bz.mu_doping["n"][temp][0] + 0.01,
2784                limit,
2785                "$n$=10$^{" + str(math.log10(self._bz.doping["n"][0])) + "}$",
2786                color="b",
2787            )
2788            plt.axvline(self._bz.mu_doping["n"][temp][-1], linewidth=3.0, linestyle="--")
2789            plt.text(
2790                self._bz.mu_doping["n"][temp][-1] + 0.01,
2791                limit,
2792                "$n$=10$^{" + str(math.log10(self._bz.doping["n"][-1])) + "}$",
2793                color="b",
2794            )
2795            plt.axvline(self._bz.mu_doping["p"][temp][0], linewidth=3.0, linestyle="--")
2796            plt.text(
2797                self._bz.mu_doping["p"][temp][0] + 0.01,
2798                limit,
2799                "$p$=10$^{" + str(math.log10(self._bz.doping["p"][0])) + "}$",
2800                color="b",
2801            )
2802            plt.axvline(self._bz.mu_doping["p"][temp][-1], linewidth=3.0, linestyle="--")
2803            plt.text(
2804                self._bz.mu_doping["p"][temp][-1] + 0.01,
2805                limit,
2806                "$p$=10$^{" + str(math.log10(self._bz.doping["p"][-1])) + "}$",
2807                color="b",
2808            )
2809
2810    def _plot_bg_limits(self, plt):
2811        plt.axvline(0.0, color="k", linewidth=3.0)
2812        plt.axvline(self._bz.gap, color="k", linewidth=3.0)
2813
2814    def plot_seebeck_eff_mass_mu(self, temps=[300], output="average", Lambda=0.5):
2815        """
2816        Plot respect to the chemical potential of the Seebeck effective mass
2817        calculated as explained in Ref.
2818        Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor
2819        from ab initio band structure calculations.
2820        npj Computational Materials 3, 8 (2017).
2821
2822        Args:
2823            output: 'average' returns the seebeck effective mass calculated
2824                using the average of the three diagonal components of the
2825                seebeck tensor. 'tensor' returns the seebeck effective mass
2826                respect to the three diagonal components of the seebeck tensor.
2827            temps:  list of temperatures of calculated seebeck.
2828            Lambda: fitting parameter used to model the scattering (0.5 means
2829                constant relaxation time).
2830        Returns:
2831            a matplotlib object
2832        """
2833
2834        plt = pretty_plot(9, 7)
2835        for T in temps:
2836            sbk_mass = self._bz.get_seebeck_eff_mass(output=output, temp=T, Lambda=0.5)
2837            # remove noise inside the gap
2838            start = self._bz.mu_doping["p"][T][0]
2839            stop = self._bz.mu_doping["n"][T][0]
2840            mu_steps_1 = []
2841            mu_steps_2 = []
2842            sbk_mass_1 = []
2843            sbk_mass_2 = []
2844            for i, mu in enumerate(self._bz.mu_steps):
2845                if mu <= start:
2846                    mu_steps_1.append(mu)
2847                    sbk_mass_1.append(sbk_mass[i])
2848                elif mu >= stop:
2849                    mu_steps_2.append(mu)
2850                    sbk_mass_2.append(sbk_mass[i])
2851
2852            plt.plot(mu_steps_1, sbk_mass_1, label=str(T) + "K", linewidth=3.0)
2853            plt.plot(mu_steps_2, sbk_mass_2, linewidth=3.0)
2854            if output == "average":
2855                plt.gca().get_lines()[1].set_c(plt.gca().get_lines()[0].get_c())
2856            elif output == "tensor":
2857                plt.gca().get_lines()[3].set_c(plt.gca().get_lines()[0].get_c())
2858                plt.gca().get_lines()[4].set_c(plt.gca().get_lines()[1].get_c())
2859                plt.gca().get_lines()[5].set_c(plt.gca().get_lines()[2].get_c())
2860
2861        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
2862        plt.ylabel("Seebeck effective mass", fontsize=30)
2863        plt.xticks(fontsize=25)
2864        plt.yticks(fontsize=25)
2865        if output == "tensor":
2866            plt.legend(
2867                [str(i) + "_" + str(T) + "K" for T in temps for i in ("x", "y", "z")],
2868                fontsize=20,
2869            )
2870        elif output == "average":
2871            plt.legend(fontsize=20)
2872        plt.tight_layout()
2873        return plt
2874
2875    def plot_complexity_factor_mu(self, temps=[300], output="average", Lambda=0.5):
2876        """
2877        Plot respect to the chemical potential of the Fermi surface complexity
2878        factor calculated as explained in Ref.
2879        Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor
2880        from ab initio band structure calculations.
2881        npj Computational Materials 3, 8 (2017).
2882
2883        Args:
2884            output: 'average' returns the complexity factor calculated using the average
2885                    of the three diagonal components of the seebeck and conductivity tensors.
2886                    'tensor' returns the complexity factor respect to the three
2887                    diagonal components of seebeck and conductivity tensors.
2888            temps:  list of temperatures of calculated seebeck and conductivity.
2889            Lambda: fitting parameter used to model the scattering (0.5 means constant
2890                    relaxation time).
2891        Returns:
2892            a matplotlib object
2893        """
2894        plt = pretty_plot(9, 7)
2895        for T in temps:
2896            cmplx_fact = self._bz.get_complexity_factor(output=output, temp=T, Lambda=Lambda)
2897            start = self._bz.mu_doping["p"][T][0]
2898            stop = self._bz.mu_doping["n"][T][0]
2899            mu_steps_1 = []
2900            mu_steps_2 = []
2901            cmplx_fact_1 = []
2902            cmplx_fact_2 = []
2903            for i, mu in enumerate(self._bz.mu_steps):
2904                if mu <= start:
2905                    mu_steps_1.append(mu)
2906                    cmplx_fact_1.append(cmplx_fact[i])
2907                elif mu >= stop:
2908                    mu_steps_2.append(mu)
2909                    cmplx_fact_2.append(cmplx_fact[i])
2910
2911            plt.plot(mu_steps_1, cmplx_fact_1, label=str(T) + "K", linewidth=3.0)
2912            plt.plot(mu_steps_2, cmplx_fact_2, linewidth=3.0)
2913            if output == "average":
2914                plt.gca().get_lines()[1].set_c(plt.gca().get_lines()[0].get_c())
2915            elif output == "tensor":
2916                plt.gca().get_lines()[3].set_c(plt.gca().get_lines()[0].get_c())
2917                plt.gca().get_lines()[4].set_c(plt.gca().get_lines()[1].get_c())
2918                plt.gca().get_lines()[5].set_c(plt.gca().get_lines()[2].get_c())
2919
2920        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
2921        plt.ylabel("Complexity Factor", fontsize=30)
2922        plt.xticks(fontsize=25)
2923        plt.yticks(fontsize=25)
2924        if output == "tensor":
2925            plt.legend(
2926                [str(i) + "_" + str(T) + "K" for T in temps for i in ("x", "y", "z")],
2927                fontsize=20,
2928            )
2929        elif output == "average":
2930            plt.legend(fontsize=20)
2931        plt.tight_layout()
2932        return plt
2933
2934    def plot_seebeck_mu(self, temp=600, output="eig", xlim=None):
2935        """
2936        Plot the seebeck coefficient in function of Fermi level
2937
2938        Args:
2939            temp:
2940                the temperature
2941            xlim:
2942                a list of min and max fermi energy by default (0, and band gap)
2943        Returns:
2944            a matplotlib object
2945        """
2946        plt = pretty_plot(9, 7)
2947        seebeck = self._bz.get_seebeck(output=output, doping_levels=False)[temp]
2948        plt.plot(self._bz.mu_steps, seebeck, linewidth=3.0)
2949
2950        self._plot_bg_limits(plt)
2951        self._plot_doping(plt, temp)
2952        if output == "eig":
2953            plt.legend(["S$_1$", "S$_2$", "S$_3$"])
2954        if xlim is None:
2955            plt.xlim(-0.5, self._bz.gap + 0.5)
2956        else:
2957            plt.xlim(xlim[0], xlim[1])
2958        plt.ylabel("Seebeck \n coefficient  ($\\mu$V/K)", fontsize=30.0)
2959        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
2960        plt.xticks(fontsize=25)
2961        plt.yticks(fontsize=25)
2962        plt.tight_layout()
2963        return plt
2964
2965    def plot_conductivity_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
2966        """
2967        Plot the conductivity in function of Fermi level. Semi-log plot
2968
2969        Args:
2970            temp: the temperature
2971            xlim: a list of min and max fermi energy by default (0, and band
2972                gap)
2973            tau: A relaxation time in s. By default none and the plot is by
2974               units of relaxation time
2975
2976        Returns:
2977            a matplotlib object
2978        """
2979        cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp]
2980        plt = pretty_plot(9, 7)
2981        plt.semilogy(self._bz.mu_steps, cond, linewidth=3.0)
2982        self._plot_bg_limits(plt)
2983        self._plot_doping(plt, temp)
2984        if output == "eig":
2985            plt.legend(["$\\Sigma_1$", "$\\Sigma_2$", "$\\Sigma_3$"])
2986        if xlim is None:
2987            plt.xlim(-0.5, self._bz.gap + 0.5)
2988        else:
2989            plt.xlim(xlim)
2990        plt.ylim([1e13 * relaxation_time, 1e20 * relaxation_time])
2991        plt.ylabel("conductivity,\n $\\Sigma$ (1/($\\Omega$ m))", fontsize=30.0)
2992        plt.xlabel("E-E$_f$ (eV)", fontsize=30.0)
2993        plt.xticks(fontsize=25)
2994        plt.yticks(fontsize=25)
2995        plt.tight_layout()
2996        return plt
2997
2998    def plot_power_factor_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
2999        """
3000        Plot the power factor in function of Fermi level. Semi-log plot
3001
3002        Args:
3003            temp: the temperature
3004            xlim: a list of min and max fermi energy by default (0, and band
3005                gap)
3006            tau: A relaxation time in s. By default none and the plot is by
3007               units of relaxation time
3008
3009        Returns:
3010            a matplotlib object
3011        """
3012        plt = pretty_plot(9, 7)
3013        pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp]
3014        plt.semilogy(self._bz.mu_steps, pf, linewidth=3.0)
3015        self._plot_bg_limits(plt)
3016        self._plot_doping(plt, temp)
3017        if output == "eig":
3018            plt.legend(["PF$_1$", "PF$_2$", "PF$_3$"])
3019        if xlim is None:
3020            plt.xlim(-0.5, self._bz.gap + 0.5)
3021        else:
3022            plt.xlim(xlim)
3023        plt.ylabel("Power factor, ($\\mu$W/(mK$^2$))", fontsize=30.0)
3024        plt.xlabel("E-E$_f$ (eV)", fontsize=30.0)
3025        plt.xticks(fontsize=25)
3026        plt.yticks(fontsize=25)
3027        plt.tight_layout()
3028        return plt
3029
3030    def plot_zt_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
3031        """
3032        Plot the ZT in function of Fermi level.
3033
3034        Args:
3035            temp: the temperature
3036            xlim: a list of min and max fermi energy by default (0, and band
3037                gap)
3038            tau: A relaxation time in s. By default none and the plot is by
3039               units of relaxation time
3040
3041        Returns:
3042            a matplotlib object
3043        """
3044        plt = pretty_plot(9, 7)
3045        zt = self._bz.get_zt(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp]
3046        plt.plot(self._bz.mu_steps, zt, linewidth=3.0)
3047        self._plot_bg_limits(plt)
3048        self._plot_doping(plt, temp)
3049        if output == "eig":
3050            plt.legend(["ZT$_1$", "ZT$_2$", "ZT$_3$"])
3051        if xlim is None:
3052            plt.xlim(-0.5, self._bz.gap + 0.5)
3053        else:
3054            plt.xlim(xlim)
3055        plt.ylabel("ZT", fontsize=30.0)
3056        plt.xlabel("E-E$_f$ (eV)", fontsize=30.0)
3057        plt.xticks(fontsize=25)
3058        plt.yticks(fontsize=25)
3059        plt.tight_layout()
3060        return plt
3061
3062    def plot_seebeck_temp(self, doping="all", output="average"):
3063        """
3064        Plot the Seebeck coefficient in function of temperature for different
3065        doping levels.
3066
3067        Args:
3068            dopings: the default 'all' plots all the doping levels in the analyzer.
3069                     Specify a list of doping levels if you want to plot only some.
3070            output: with 'average' you get an average of the three directions
3071                    with 'eigs' you get all the three directions.
3072        Returns:
3073            a matplotlib object
3074        """
3075
3076        if output == "average":
3077            sbk = self._bz.get_seebeck(output="average")
3078        elif output == "eigs":
3079            sbk = self._bz.get_seebeck(output="eigs")
3080
3081        plt = pretty_plot(22, 14)
3082        tlist = sorted(sbk["n"].keys())
3083        doping = self._bz.doping["n"] if doping == "all" else doping
3084        for i, dt in enumerate(["n", "p"]):
3085            plt.subplot(121 + i)
3086            for dop in doping:
3087                d = self._bz.doping[dt].index(dop)
3088                sbk_temp = []
3089                for temp in tlist:
3090                    sbk_temp.append(sbk[dt][temp][d])
3091                if output == "average":
3092                    plt.plot(tlist, sbk_temp, marker="s", label=str(dop) + " $cm^{-3}$")
3093                elif output == "eigs":
3094                    for xyz in range(3):
3095                        plt.plot(
3096                            tlist,
3097                            list(zip(*sbk_temp))[xyz],
3098                            marker="s",
3099                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3100                        )
3101            plt.title(dt + "-type", fontsize=20)
3102            if i == 0:
3103                plt.ylabel("Seebeck \n coefficient  ($\\mu$V/K)", fontsize=30.0)
3104            plt.xlabel("Temperature (K)", fontsize=30.0)
3105
3106            p = "lower right" if i == 0 else "best"
3107            plt.legend(loc=p, fontsize=15)
3108            plt.grid()
3109            plt.xticks(fontsize=25)
3110            plt.yticks(fontsize=25)
3111
3112        plt.tight_layout()
3113
3114        return plt
3115
3116    def plot_conductivity_temp(self, doping="all", output="average", relaxation_time=1e-14):
3117        """
3118        Plot the conductivity in function of temperature for different doping levels.
3119
3120        Args:
3121            dopings: the default 'all' plots all the doping levels in the analyzer.
3122                     Specify a list of doping levels if you want to plot only some.
3123            output: with 'average' you get an average of the three directions
3124                    with 'eigs' you get all the three directions.
3125            relaxation_time: specify a constant relaxation time value
3126
3127        Returns:
3128            a matplotlib object
3129        """
3130
3131        if output == "average":
3132            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="average")
3133        elif output == "eigs":
3134            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="eigs")
3135
3136        plt = pretty_plot(22, 14)
3137        tlist = sorted(cond["n"].keys())
3138        doping = self._bz.doping["n"] if doping == "all" else doping
3139        for i, dt in enumerate(["n", "p"]):
3140            plt.subplot(121 + i)
3141            for dop in doping:
3142                d = self._bz.doping[dt].index(dop)
3143                cond_temp = []
3144                for temp in tlist:
3145                    cond_temp.append(cond[dt][temp][d])
3146                if output == "average":
3147                    plt.plot(tlist, cond_temp, marker="s", label=str(dop) + " $cm^{-3}$")
3148                elif output == "eigs":
3149                    for xyz in range(3):
3150                        plt.plot(
3151                            tlist,
3152                            list(zip(*cond_temp))[xyz],
3153                            marker="s",
3154                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3155                        )
3156            plt.title(dt + "-type", fontsize=20)
3157            if i == 0:
3158                plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0)
3159            plt.xlabel("Temperature (K)", fontsize=30.0)
3160
3161            p = "best"  # 'lower right' if i == 0 else ''
3162            plt.legend(loc=p, fontsize=15)
3163            plt.grid()
3164            plt.xticks(fontsize=25)
3165            plt.yticks(fontsize=25)
3166            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
3167
3168        plt.tight_layout()
3169
3170        return plt
3171
3172    def plot_power_factor_temp(self, doping="all", output="average", relaxation_time=1e-14):
3173        """
3174        Plot the Power Factor in function of temperature for different doping levels.
3175
3176        Args:
3177            dopings: the default 'all' plots all the doping levels in the analyzer.
3178                     Specify a list of doping levels if you want to plot only some.
3179            output: with 'average' you get an average of the three directions
3180                    with 'eigs' you get all the three directions.
3181            relaxation_time: specify a constant relaxation time value
3182
3183        Returns:
3184            a matplotlib object
3185        """
3186
3187        if output == "average":
3188            pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
3189        elif output == "eigs":
3190            pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")
3191
3192        plt = pretty_plot(22, 14)
3193        tlist = sorted(pf["n"].keys())
3194        doping = self._bz.doping["n"] if doping == "all" else doping
3195        for i, dt in enumerate(["n", "p"]):
3196            plt.subplot(121 + i)
3197            for dop in doping:
3198                d = self._bz.doping[dt].index(dop)
3199                pf_temp = []
3200                for temp in tlist:
3201                    pf_temp.append(pf[dt][temp][d])
3202                if output == "average":
3203                    plt.plot(tlist, pf_temp, marker="s", label=str(dop) + " $cm^{-3}$")
3204                elif output == "eigs":
3205                    for xyz in range(3):
3206                        plt.plot(
3207                            tlist,
3208                            list(zip(*pf_temp))[xyz],
3209                            marker="s",
3210                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3211                        )
3212            plt.title(dt + "-type", fontsize=20)
3213            if i == 0:
3214                plt.ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0)
3215            plt.xlabel("Temperature (K)", fontsize=30.0)
3216
3217            p = "best"  # 'lower right' if i == 0 else ''
3218            plt.legend(loc=p, fontsize=15)
3219            plt.grid()
3220            plt.xticks(fontsize=25)
3221            plt.yticks(fontsize=25)
3222            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
3223
3224        plt.tight_layout()
3225        return plt
3226
3227    def plot_zt_temp(self, doping="all", output="average", relaxation_time=1e-14):
3228        """
3229        Plot the figure of merit zT in function of temperature for different doping levels.
3230
3231        Args:
3232            dopings: the default 'all' plots all the doping levels in the analyzer.
3233                     Specify a list of doping levels if you want to plot only some.
3234            output: with 'average' you get an average of the three directions
3235                    with 'eigs' you get all the three directions.
3236            relaxation_time: specify a constant relaxation time value
3237
3238        Returns:
3239            a matplotlib object
3240        """
3241
3242        if output == "average":
3243            zt = self._bz.get_zt(relaxation_time=relaxation_time, output="average")
3244        elif output == "eigs":
3245            zt = self._bz.get_zt(relaxation_time=relaxation_time, output="eigs")
3246
3247        plt = pretty_plot(22, 14)
3248        tlist = sorted(zt["n"].keys())
3249        doping = self._bz.doping["n"] if doping == "all" else doping
3250        for i, dt in enumerate(["n", "p"]):
3251            plt.subplot(121 + i)
3252            for dop in doping:
3253                d = self._bz.doping[dt].index(dop)
3254                zt_temp = []
3255                for temp in tlist:
3256                    zt_temp.append(zt[dt][temp][d])
3257                if output == "average":
3258                    plt.plot(tlist, zt_temp, marker="s", label=str(dop) + " $cm^{-3}$")
3259                elif output == "eigs":
3260                    for xyz in range(3):
3261                        plt.plot(
3262                            tlist,
3263                            list(zip(*zt_temp))[xyz],
3264                            marker="s",
3265                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3266                        )
3267            plt.title(dt + "-type", fontsize=20)
3268            if i == 0:
3269                plt.ylabel("zT", fontsize=30.0)
3270            plt.xlabel("Temperature (K)", fontsize=30.0)
3271
3272            p = "best"  # 'lower right' if i == 0 else ''
3273            plt.legend(loc=p, fontsize=15)
3274            plt.grid()
3275            plt.xticks(fontsize=25)
3276            plt.yticks(fontsize=25)
3277
3278        plt.tight_layout()
3279        return plt
3280
3281    def plot_eff_mass_temp(self, doping="all", output="average"):
3282        """
3283        Plot the average effective mass in function of temperature
3284        for different doping levels.
3285
3286        Args:
3287            dopings: the default 'all' plots all the doping levels in the analyzer.
3288                     Specify a list of doping levels if you want to plot only some.
3289            output: with 'average' you get an average of the three directions
3290                    with 'eigs' you get all the three directions.
3291
3292        Returns:
3293            a matplotlib object
3294        """
3295
3296        if output == "average":
3297            em = self._bz.get_average_eff_mass(output="average")
3298        elif output == "eigs":
3299            em = self._bz.get_average_eff_mass(output="eigs")
3300
3301        plt = pretty_plot(22, 14)
3302        tlist = sorted(em["n"].keys())
3303        doping = self._bz.doping["n"] if doping == "all" else doping
3304        for i, dt in enumerate(["n", "p"]):
3305            plt.subplot(121 + i)
3306            for dop in doping:
3307                d = self._bz.doping[dt].index(dop)
3308                em_temp = []
3309                for temp in tlist:
3310                    em_temp.append(em[dt][temp][d])
3311                if output == "average":
3312                    plt.plot(tlist, em_temp, marker="s", label=str(dop) + " $cm^{-3}$")
3313                elif output == "eigs":
3314                    for xyz in range(3):
3315                        plt.plot(
3316                            tlist,
3317                            list(zip(*em_temp))[xyz],
3318                            marker="s",
3319                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3320                        )
3321            plt.title(dt + "-type", fontsize=20)
3322            if i == 0:
3323                plt.ylabel("Effective mass (m$_e$)", fontsize=30.0)
3324            plt.xlabel("Temperature (K)", fontsize=30.0)
3325
3326            p = "best"  # 'lower right' if i == 0 else ''
3327            plt.legend(loc=p, fontsize=15)
3328            plt.grid()
3329            plt.xticks(fontsize=25)
3330            plt.yticks(fontsize=25)
3331
3332        plt.tight_layout()
3333        return plt
3334
3335    def plot_seebeck_dop(self, temps="all", output="average"):
3336        """
3337        Plot the Seebeck in function of doping levels for different temperatures.
3338
3339        Args:
3340            temps: the default 'all' plots all the temperatures in the analyzer.
3341                   Specify a list of temperatures if you want to plot only some.
3342            output: with 'average' you get an average of the three directions
3343                    with 'eigs' you get all the three directions.
3344
3345        Returns:
3346            a matplotlib object
3347        """
3348
3349        if output == "average":
3350            sbk = self._bz.get_seebeck(output="average")
3351        elif output == "eigs":
3352            sbk = self._bz.get_seebeck(output="eigs")
3353
3354        tlist = sorted(sbk["n"].keys()) if temps == "all" else temps
3355        plt = pretty_plot(22, 14)
3356        for i, dt in enumerate(["n", "p"]):
3357            plt.subplot(121 + i)
3358            for temp in tlist:
3359                if output == "eigs":
3360                    for xyz in range(3):
3361                        plt.semilogx(
3362                            self._bz.doping[dt],
3363                            list(zip(*sbk[dt][temp]))[xyz],
3364                            marker="s",
3365                            label=str(xyz) + " " + str(temp) + " K",
3366                        )
3367                elif output == "average":
3368                    plt.semilogx(
3369                        self._bz.doping[dt],
3370                        sbk[dt][temp],
3371                        marker="s",
3372                        label=str(temp) + " K",
3373                    )
3374            plt.title(dt + "-type", fontsize=20)
3375            if i == 0:
3376                plt.ylabel("Seebeck coefficient ($\\mu$V/K)", fontsize=30.0)
3377            plt.xlabel("Doping concentration (cm$^{-3}$)", fontsize=30.0)
3378
3379            p = "lower right" if i == 0 else "best"
3380            plt.legend(loc=p, fontsize=15)
3381            plt.grid()
3382            plt.xticks(fontsize=25)
3383            plt.yticks(fontsize=25)
3384
3385        plt.tight_layout()
3386
3387        return plt
3388
3389    def plot_conductivity_dop(self, temps="all", output="average", relaxation_time=1e-14):
3390        """
3391        Plot the conductivity in function of doping levels for different
3392        temperatures.
3393
3394        Args:
3395            temps: the default 'all' plots all the temperatures in the analyzer.
3396                   Specify a list of temperatures if you want to plot only some.
3397            output: with 'average' you get an average of the three directions
3398                    with 'eigs' you get all the three directions.
3399            relaxation_time: specify a constant relaxation time value
3400
3401        Returns:
3402            a matplotlib object
3403        """
3404        if output == "average":
3405            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="average")
3406        elif output == "eigs":
3407            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="eigs")
3408
3409        tlist = sorted(cond["n"].keys()) if temps == "all" else temps
3410        plt = pretty_plot(22, 14)
3411        for i, dt in enumerate(["n", "p"]):
3412            plt.subplot(121 + i)
3413            for temp in tlist:
3414                if output == "eigs":
3415                    for xyz in range(3):
3416                        plt.semilogx(
3417                            self._bz.doping[dt],
3418                            list(zip(*cond[dt][temp]))[xyz],
3419                            marker="s",
3420                            label=str(xyz) + " " + str(temp) + " K",
3421                        )
3422                elif output == "average":
3423                    plt.semilogx(
3424                        self._bz.doping[dt],
3425                        cond[dt][temp],
3426                        marker="s",
3427                        label=str(temp) + " K",
3428                    )
3429            plt.title(dt + "-type", fontsize=20)
3430            if i == 0:
3431                plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0)
3432            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
3433            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
3434            plt.legend(fontsize=15)
3435            plt.grid()
3436            plt.xticks(fontsize=25)
3437            plt.yticks(fontsize=25)
3438
3439        plt.tight_layout()
3440
3441        return plt
3442
3443    def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1e-14):
3444        """
3445        Plot the Power Factor in function of doping levels for different temperatures.
3446
3447        Args:
3448            temps: the default 'all' plots all the temperatures in the analyzer.
3449                   Specify a list of temperatures if you want to plot only some.
3450            output: with 'average' you get an average of the three directions
3451                    with 'eigs' you get all the three directions.
3452            relaxation_time: specify a constant relaxation time value
3453
3454        Returns:
3455            a matplotlib object
3456        """
3457        if output == "average":
3458            pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
3459        elif output == "eigs":
3460            pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")
3461
3462        tlist = sorted(pf["n"].keys()) if temps == "all" else temps
3463        plt = pretty_plot(22, 14)
3464        for i, dt in enumerate(["n", "p"]):
3465            plt.subplot(121 + i)
3466            for temp in tlist:
3467                if output == "eigs":
3468                    for xyz in range(3):
3469                        plt.semilogx(
3470                            self._bz.doping[dt],
3471                            list(zip(*pf[dt][temp]))[xyz],
3472                            marker="s",
3473                            label=str(xyz) + " " + str(temp) + " K",
3474                        )
3475                elif output == "average":
3476                    plt.semilogx(
3477                        self._bz.doping[dt],
3478                        pf[dt][temp],
3479                        marker="s",
3480                        label=str(temp) + " K",
3481                    )
3482            plt.title(dt + "-type", fontsize=20)
3483            if i == 0:
3484                plt.ylabel("Power Factor  ($\\mu$W/(mK$^2$))", fontsize=30.0)
3485            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
3486            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
3487            p = "best"  # 'lower right' if i == 0 else ''
3488            plt.legend(loc=p, fontsize=15)
3489            plt.grid()
3490            plt.xticks(fontsize=25)
3491            plt.yticks(fontsize=25)
3492
3493        plt.tight_layout()
3494
3495        return plt
3496
3497    def plot_zt_dop(self, temps="all", output="average", relaxation_time=1e-14):
3498        """
3499        Plot the figure of merit zT in function of doping levels for different
3500        temperatures.
3501
3502        Args:
3503            temps: the default 'all' plots all the temperatures in the analyzer.
3504                   Specify a list of temperatures if you want to plot only some.
3505            output: with 'average' you get an average of the three directions
3506                    with 'eigs' you get all the three directions.
3507            relaxation_time: specify a constant relaxation time value
3508
3509        Returns:
3510            a matplotlib object
3511        """
3512        if output == "average":
3513            zt = self._bz.get_zt(relaxation_time=relaxation_time, output="average")
3514        elif output == "eigs":
3515            zt = self._bz.get_zt(relaxation_time=relaxation_time, output="eigs")
3516
3517        tlist = sorted(zt["n"].keys()) if temps == "all" else temps
3518        plt = pretty_plot(22, 14)
3519        for i, dt in enumerate(["n", "p"]):
3520            plt.subplot(121 + i)
3521            for temp in tlist:
3522                if output == "eigs":
3523                    for xyz in range(3):
3524                        plt.semilogx(
3525                            self._bz.doping[dt],
3526                            list(zip(*zt[dt][temp]))[xyz],
3527                            marker="s",
3528                            label=str(xyz) + " " + str(temp) + " K",
3529                        )
3530                elif output == "average":
3531                    plt.semilogx(
3532                        self._bz.doping[dt],
3533                        zt[dt][temp],
3534                        marker="s",
3535                        label=str(temp) + " K",
3536                    )
3537            plt.title(dt + "-type", fontsize=20)
3538            if i == 0:
3539                plt.ylabel("zT", fontsize=30.0)
3540            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
3541
3542            p = "lower right" if i == 0 else "best"
3543            plt.legend(loc=p, fontsize=15)
3544            plt.grid()
3545            plt.xticks(fontsize=25)
3546            plt.yticks(fontsize=25)
3547
3548        plt.tight_layout()
3549
3550        return plt
3551
3552    def plot_eff_mass_dop(self, temps="all", output="average"):
3553        """
3554        Plot the average effective mass in function of doping levels
3555        for different temperatures.
3556
3557        Args:
3558            temps: the default 'all' plots all the temperatures in the analyzer.
3559                   Specify a list of temperatures if you want to plot only some.
3560            output: with 'average' you get an average of the three directions
3561                    with 'eigs' you get all the three directions.
3562            relaxation_time: specify a constant relaxation time value
3563
3564        Returns:
3565            a matplotlib object
3566        """
3567
3568        if output == "average":
3569            em = self._bz.get_average_eff_mass(output="average")
3570        elif output == "eigs":
3571            em = self._bz.get_average_eff_mass(output="eigs")
3572
3573        tlist = sorted(em["n"].keys()) if temps == "all" else temps
3574        plt = pretty_plot(22, 14)
3575        for i, dt in enumerate(["n", "p"]):
3576            plt.subplot(121 + i)
3577            for temp in tlist:
3578                if output == "eigs":
3579                    for xyz in range(3):
3580                        plt.semilogx(
3581                            self._bz.doping[dt],
3582                            list(zip(*em[dt][temp]))[xyz],
3583                            marker="s",
3584                            label=str(xyz) + " " + str(temp) + " K",
3585                        )
3586                elif output == "average":
3587                    plt.semilogx(
3588                        self._bz.doping[dt],
3589                        em[dt][temp],
3590                        marker="s",
3591                        label=str(temp) + " K",
3592                    )
3593            plt.title(dt + "-type", fontsize=20)
3594            if i == 0:
3595                plt.ylabel("Effective mass (m$_e$)", fontsize=30.0)
3596            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
3597
3598            p = "lower right" if i == 0 else "best"
3599            plt.legend(loc=p, fontsize=15)
3600            plt.grid()
3601            plt.xticks(fontsize=25)
3602            plt.yticks(fontsize=25)
3603
3604        plt.tight_layout()
3605
3606        return plt
3607
3608    def plot_dos(self, sigma=0.05):
3609        """
3610        plot dos
3611
3612        Args:
3613            sigma: a smearing
3614
3615        Returns:
3616            a matplotlib object
3617        """
3618        plotter = DosPlotter(sigma=sigma)
3619        plotter.add_dos("t", self._bz.dos)
3620        return plotter.get_plot()
3621
3622    def plot_carriers(self, temp=300):
3623        """
3624        Plot the carrier concentration in function of Fermi level
3625
3626        Args:
3627            temp: the temperature
3628
3629        Returns:
3630            a matplotlib object
3631        """
3632        plt = pretty_plot(9, 7)
3633        carriers = [abs(c / (self._bz.vol * 1e-24)) for c in self._bz._carrier_conc[temp]]
3634        plt.semilogy(self._bz.mu_steps, carriers, linewidth=3.0, color="r")
3635        self._plot_bg_limits(plt)
3636        self._plot_doping(plt, temp)
3637        plt.xlim(-0.5, self._bz.gap + 0.5)
3638        plt.ylim(1e14, 1e22)
3639        plt.ylabel("carrier concentration (cm-3)", fontsize=30.0)
3640        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
3641        plt.xticks(fontsize=25)
3642        plt.yticks(fontsize=25)
3643        plt.tight_layout()
3644        return plt
3645
3646    def plot_hall_carriers(self, temp=300):
3647        """
3648        Plot the Hall carrier concentration in function of Fermi level
3649
3650        Args:
3651            temp: the temperature
3652
3653        Returns:
3654            a matplotlib object
3655        """
3656        plt = pretty_plot(9, 7)
3657        hall_carriers = [abs(i) for i in self._bz.get_hall_carrier_concentration()[temp]]
3658        plt.semilogy(self._bz.mu_steps, hall_carriers, linewidth=3.0, color="r")
3659        self._plot_bg_limits(plt)
3660        self._plot_doping(plt, temp)
3661        plt.xlim(-0.5, self._bz.gap + 0.5)
3662        plt.ylim(1e14, 1e22)
3663        plt.ylabel("Hall carrier concentration (cm-3)", fontsize=30.0)
3664        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
3665        plt.xticks(fontsize=25)
3666        plt.yticks(fontsize=25)
3667        plt.tight_layout()
3668        return plt
3669
3670
class CohpPlotter:
    """
    Class for plotting crystal orbital Hamilton populations (COHPs),
    crystal orbital overlap populations (COOPs) or crystal orbital bond
    indices (COBIs). It is modeled after the DosPlotter object.
    """

    def __init__(self, zero_at_efermi=True, are_coops=False, are_cobis=False):
        """
        Args:
            zero_at_efermi: Whether to shift all populations to have zero
                energy at the Fermi level. Defaults to True.
            are_coops: Switch to indicate that these are COOPs, not COHPs.
                Defaults to False for COHPs.
            are_cobis: Switch to indicate that these are COBIs, not COHPs/COOPs.
                Defaults to False for COHPs
        """
        self.zero_at_efermi = zero_at_efermi
        self.are_coops = are_coops
        self.are_cobis = are_cobis
        # Insertion order is kept so plots/legends follow the order of addition.
        self._cohps = OrderedDict()

    def add_cohp(self, label, cohp):
        """
        Adds a COHP for plotting.

        Args:
            label: Label for the COHP. Must be unique; an existing entry with
                the same label is overwritten.

            cohp: COHP object. Its energies, efermi, get_cohp() and
                get_icohp() are read.
        """
        energies = cohp.energies - cohp.efermi if self.zero_at_efermi else cohp.energies
        populations = cohp.get_cohp()
        int_populations = cohp.get_icohp()
        self._cohps[label] = {
            "energies": energies,
            "COHP": populations,
            "ICOHP": int_populations,
            "efermi": cohp.efermi,
        }

    def add_cohp_dict(self, cohp_dict, key_sort_func=None):
        """
        Adds a dictionary of COHPs with an optional sorting function
        for the keys.

        Args:
            cohp_dict: dict of the form {label: Cohp}

            key_sort_func: function used to sort the cohp_dict keys. If not
                given, the dict's own iteration order is used.
        """
        if key_sort_func:
            keys = sorted(cohp_dict.keys(), key=key_sort_func)
        else:
            keys = cohp_dict.keys()
        for label in keys:
            self.add_cohp(label, cohp_dict[label])

    def get_cohp_dict(self):
        """
        Returns the added COHPs as a json-serializable dict. Note that if you
        have specified smearing for the COHP plot, the populations returned
        will be the smeared and not the original populations.

        Returns:
            dict: Dict of COHP data of the form {label: {"efermi": efermi,
            "energies": ..., "COHP": {Spin.up: ...}, "ICOHP": ...}}.
        """
        return jsanitize(self._cohps)

    def get_plot(
        self,
        xlim=None,
        ylim=None,
        plot_negative=None,
        integrated=False,
        invert_axes=True,
    ):
        """
        Get a matplotlib plot showing the COHP.

        Args:
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.

            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination (the range of the data lying inside
                the x-axis limits; matplotlib's autoscaling is kept if no
                data falls inside them).

            plot_negative: It is common to plot -COHP(E) so that the
                sign means the same for COOPs and COHPs. Defaults to None
                for automatic determination: If are_coops or are_cobis is
                True, this will be set to False, else it will be set to True.

            integrated: Switch to plot ICOHPs. Defaults to False.

            invert_axes: Put the energies onto the y-axis, which is
                common in chemistry.

        Returns:
            A matplotlib object.
        """
        if self.are_coops:
            cohp_label = "COOP"
        elif self.are_cobis:
            cohp_label = "COBI"
        else:
            cohp_label = "COHP"

        if plot_negative is None:
            plot_negative = (not self.are_coops) and (not self.are_cobis)

        if integrated:
            cohp_label = "I" + cohp_label + " (eV)"

        if plot_negative:
            cohp_label = "-" + cohp_label

        energy_label = "$E - E_f$ (eV)" if self.zero_at_efermi else "$E$ (eV)"

        # Cycle through between 3 and 9 colors of the Set1_9 palette.
        ncolors = min(9, max(3, len(self._cohps)))

        import palettable

        # pylint: disable=E1101
        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        plt = pretty_plot(12, 8)

        allpts = []
        for i, (key, cohp_data) in enumerate(self._cohps.items()):
            energies = cohp_data["energies"]
            populations = cohp_data["ICOHP"] if integrated else cohp_data["COHP"]
            color = colors[i % ncolors]
            for spin in [Spin.up, Spin.down]:
                if spin not in populations:
                    continue
                pops = -populations[spin] if plot_negative else populations[spin]
                x, y = (pops, energies) if invert_axes else (energies, pops)
                allpts.extend(zip(x, y))
                if spin == Spin.up:
                    plt.plot(x, y, color=color, linestyle="-", label=str(key), linewidth=3)
                else:
                    plt.plot(x, y, color=color, linestyle="--", linewidth=3)

        # `is not None` rather than truthiness so numpy arrays are accepted as limits.
        if xlim is not None:
            plt.xlim(xlim)
        if ylim is not None:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            # min()/max() of an empty list would raise; keep autoscaled limits then.
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        xlim = plt.xlim()
        ylim = plt.ylim()

        # Zero-population line.
        if invert_axes:
            plt.plot([0, 0], ylim, "k-", linewidth=2)
        else:
            plt.plot(xlim, [0, 0], "k-", linewidth=2)

        # Fermi level: a single black line at 0 when energies are shifted,
        # otherwise one dashed line per COHP in that COHP's color. (This used to
        # rely on loop variables leaked from the plotting loop, which drew only
        # the last COHP's Fermi level and raised NameError with no COHPs.)
        if self.zero_at_efermi:
            if invert_axes:
                plt.plot(xlim, [0, 0], "k--", linewidth=2)
            else:
                plt.plot([0, 0], ylim, "k--", linewidth=2)
        else:
            for i, cohp_data in enumerate(self._cohps.values()):
                efermi_pair = [cohp_data["efermi"], cohp_data["efermi"]]
                if invert_axes:
                    line_x, line_y = xlim, efermi_pair
                else:
                    line_x, line_y = efermi_pair, ylim
                plt.plot(
                    line_x,
                    line_y,
                    color=colors[i % ncolors],
                    linestyle="--",
                    linewidth=2,
                )

        if invert_axes:
            plt.xlabel(cohp_label)
            plt.ylabel(energy_label)
        else:
            plt.xlabel(energy_label)
            plt.ylabel(cohp_label)

        leg = plt.legend()
        plt.setp(leg.get_texts(), fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None):
        """
        Save matplotlib plot to a file.

        Args:
            filename: File name to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.
            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination.
        """
        plt = self.get_plot(xlim=xlim, ylim=ylim)
        plt.savefig(filename, format=img_format)

    def show(self, xlim=None, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.
            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination.
        """
        plt = self.get_plot(xlim=xlim, ylim=ylim)
        plt.show()
3909
3910
def _draw_bz_and_kpoints_mlab(bz, structure, kpoints_dict, fig, labels_scale_factor, points_scale_factor):
    """Draw Brillouin-zone edges and labelled k-points onto a mayavi figure."""
    # A vertex pair of a face is drawn as an edge when both vertices also belong
    # to a later face (iface < jface), i.e. the two faces share that segment.
    for iface, face in enumerate(bz):
        for line in itertools.combinations(face, 2):
            for other_face in bz[iface + 1 :]:
                if any(np.all(line[0] == x) for x in other_face) and any(np.all(line[1] == x) for x in other_face):
                    mlab.plot3d(
                        *zip(line[0], line[1]),
                        color=(0, 0, 0),
                        tube_radius=None,
                        figure=fig,
                    )

    for label, coords in kpoints_dict.items():
        # kpoints are given in fractional reciprocal coordinates.
        label_coords = structure.lattice.reciprocal_lattice.get_cartesian_coords(coords)
        mlab.points3d(
            *label_coords,
            scale_factor=points_scale_factor,
            color=(0, 0, 0),
            figure=fig,
        )
        mlab.text3d(
            *label_coords,
            text=label,
            scale=labels_scale_factor,
            color=(0, 0, 0),
            figure=fig,
        )


@requires(mlab is not None, "MayAvi mlab not imported! Please install mayavi.")
def plot_fermi_surface(
    data,
    structure,
    cbm,
    energy_levels=None,
    multiple_figure=True,
    mlab_figure=None,
    kpoints_dict=None,
    colors=None,
    transparency_factor=None,
    labels_scale_factor=0.05,
    points_scale_factor=0.02,
    interative=True,
):
    """
    Plot the Fermi surface at specific energy value using Boltztrap 1 FERMI
    mode.

    The easiest way to use this plotter is:

        1. Run boltztrap in 'FERMI' mode using BoltztrapRunner,
        2. Load BoltztrapAnalyzer using your method of choice (e.g., from_files)
        3. Pass in your BoltztrapAnalyzer's fermi_surface_data as this
            function's data argument.

    Args:
        data: energy values in a 3D grid from a CUBE file via read_cube_file
            function, or from a BoltztrapAnalyzer.fermi_surface_data
        structure: structure object of the material
        cbm (bool): Boolean value to specify if the considered band is a
            conduction band or not. When True the grid values are negated
            before contouring, so energy_levels are interpreted (and
            range-checked) against -data.
        energy_levels ([float]): Energy values for plotting the fermi surface(s)
            By default 0 eV correspond to the VBM, as in the plot of band
            structure along symmetry line.
            Default: One surface, 0.01 eV inside the range of the (possibly
            negated) data: max - 0.01 for cbm=False, min + 0.01 for cbm=True.
        multiple_figure (bool): If True a figure for each energy level will be
            shown.  If False all the surfaces will be shown in the same figure.
            In this last case, tune the transparency factor.
        mlab_figure (mayavi.mlab.figure): A previous figure to plot a new
            surface on (used when multiple_figure is False).
        kpoints_dict (dict): dictionary of kpoints to label in the plot.
            Example: {"K":[0.5,0.0,0.5]}, coords are fractional
        colors ([tuple]): Iterable of 3-tuples (r,g,b) of floats in [0, 1] to
            define the colors of each surface (one per energy level).
            Should be the same length as the number of surfaces being plotted.
            Example (3 surfaces): colors=[(1,0,0), (0,1,0), (0,0,1)]
            Example (1 surface): colors=[(0, 0.5, 0.5)]
        transparency_factor ([float]): Values in the range [0,1] to tune the
            opacity of each surface. Should be one transparency_factor per
            surface.
        labels_scale_factor (float): factor to tune size of the kpoint labels
        points_scale_factor (float): factor to tune size of the kpoint points
        interative (bool): if True an interactive figure will be shown.
            If False a non interactive figure will be shown, but it is possible
            to plot other surfaces on the same figure. To make it interactive,
            run mlab.show().

    Returns:
        ((mayavi.mlab.figure, mayavi.mlab)): The figure holding the plot
            (the last one created when multiple_figure is True) and the mlab
            module to control it.

    Raises:
        BoltztrapError: if a requested energy level lies outside the range of
            the (possibly negated) data.

    Note: Experimental.
          Please, double check the surface shown by using some
          other software and report issues.
    """
    bz = structure.lattice.reciprocal_lattice.get_wigner_seitz_cell()
    cell = structure.lattice.reciprocal_lattice.matrix

    fact = 1 if not cbm else -1
    data_1d = data.ravel()
    en_min = np.min(fact * data_1d)
    en_max = np.max(fact * data_1d)

    if energy_levels is None:
        energy_levels = [en_min + 0.01] if cbm else [en_max - 0.01]
        print("Energy level set to: " + str(energy_levels[0]) + " eV")
    else:
        for e in energy_levels:
            if e > en_max or e < en_min:
                raise BoltztrapError(
                    "energy level "
                    + str(e)
                    + " not in the range of possible energies: ["
                    + str(en_min)
                    + ", "
                    + str(en_max)
                    + "]"
                )

    n_surfaces = len(energy_levels)
    if colors is None:
        colors = [(0, 0, 1)] * n_surfaces

    if transparency_factor is None:
        transparency_factor = [1] * n_surfaces

    if kpoints_dict is None:
        kpoints_dict = {}

    # Always bound, so an empty energy_levels list no longer leaves `fig`
    # undefined at the return statement.
    fig = mlab_figure

    if mlab_figure is None and not multiple_figure:
        # One shared figure: draw the zone and k-points only once.
        fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1))
        _draw_bz_and_kpoints_mlab(bz, structure, kpoints_dict, fig, labels_scale_factor, points_scale_factor)

    for i, isolevel in enumerate(energy_levels):
        alpha = transparency_factor[i]
        color = colors[i]
        if multiple_figure:
            fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1))
            _draw_bz_and_kpoints_mlab(bz, structure, kpoints_dict, fig, labels_scale_factor, points_scale_factor)

        cp = mlab.contour3d(
            fact * data,
            contours=[isolevel],
            transparent=True,
            colormap="hot",
            color=color,
            opacity=alpha,
            figure=fig,
        )

        polydata = cp.actor.actors[0].mapper.input
        pts = np.array(polydata.points)
        # Rescale grid-index coordinates by the reciprocal cell (one row per
        # grid axis, divided by that axis' grid size) to get Cartesian points.
        polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis])

        # Center the surface on its centroid, then scale it by 2.
        cx, cy, cz = [np.mean(np.array(polydata.points)[:, axis]) for axis in range(3)]
        polydata.points = (np.array(polydata.points) - [cx, cy, cz]) * 2

        fig.scene.isometric_view()

    if interative:
        mlab.show()

    return fig, mlab
4108
4109
def plot_wigner_seitz(lattice, ax=None, **kwargs):
    """
    Adds the skeleton of the Wigner-Seitz cell of the lattice to a matplotlib Axes

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black
            and linewidth to 1.

    Returns:
        matplotlib figure and matplotlib ax
    """
    # A single call is enough; the original fetched the axes a second time
    # right before plotting.
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "k")
    kwargs.setdefault("linewidth", 1)

    bz = lattice.get_wigner_seitz_cell()
    # A vertex pair of a face is drawn when both vertices also belong to a later
    # face (iface < jface); it is drawn once per such later face.
    for iface, face in enumerate(bz):
        for line in itertools.combinations(face, 2):
            for other_face in bz[iface + 1 :]:
                if any(np.all(line[0] == x) for x in other_face) and any(np.all(line[1] == x) for x in other_face):
                    ax.plot(*zip(line[0], line[1]), **kwargs)

    return fig, ax
4143
4144
def plot_lattice_vectors(lattice, ax=None, **kwargs):
    """
    Draw the three basis vectors of a lattice, starting at the origin, on a
    matplotlib Axes.

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to green
            and linewidth to 3.

    Returns:
        matplotlib figure and matplotlib ax
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "g")
    kwargs.setdefault("linewidth", 3)

    origin = lattice.get_cartesian_coords([0.0, 0.0, 0.0])
    # One segment from the origin to the tip of each fractional unit vector.
    for frac_tip in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
        tip = lattice.get_cartesian_coords(frac_tip)
        ax.plot(*zip(origin, tip), **kwargs)

    return fig, ax
4174
4175
def plot_path(line, lattice=None, coords_are_cartesian=False, ax=None, **kwargs):
    """
    Draw a polyline through the points listed in 'line' on a matplotlib Axes.

    Args:
        line: list of coordinates.
        lattice: Lattice object used to convert from reciprocal to cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.
            Requires lattice if False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to red
            and linewidth to 3.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if coords_are_cartesian is False, lattice is None and the
            path has at least one segment.
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "r")
    kwargs.setdefault("linewidth", 3)

    # Walk consecutive (start, end) pairs; a path with fewer than two points
    # produces no segments.
    for start, end in zip(line[:-1], line[1:]):
        if not coords_are_cartesian:
            if lattice is None:
                raise ValueError("coords_are_cartesian False requires the lattice")
            start = lattice.get_cartesian_coords(start)
            end = lattice.get_cartesian_coords(end)
        ax.plot(*zip(start, end), **kwargs)

    return fig, ax
4212
4213
def plot_labels(labels, lattice=None, coords_are_cartesian=False, ax=None, **kwargs):
    """
    Write text labels at given positions on a matplotlib Axes.

    Args:
        labels: dict containing the label as a key and the coordinates as value.
        lattice: Lattice object used to convert from reciprocal to cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.
            Requires lattice if False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'text'. Color defaults to blue
            and size to 25.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if coords_are_cartesian is False and lattice is None
            (raised at the first label).
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "b")
    kwargs.setdefault("size", 25)

    offset = 0.01  # small shift so the text does not sit exactly on the point
    for name, coords in labels.items():
        # TeX-looking names (leading backslash or a subscript underscore) are
        # wrapped in $...$ so matplotlib renders them as mathtext.
        text = "$" + name + "$" if name.startswith("\\") or "_" in name else name
        if not coords_are_cartesian:
            if lattice is None:
                raise ValueError("coords_are_cartesian False requires the lattice")
            position = lattice.get_cartesian_coords(coords)
        else:
            position = np.array(coords)
        ax.text(*(position + offset), s=text, **kwargs)

    return fig, ax
4252
4253
def fold_point(p, lattice, coords_are_cartesian=False):
    """
    Folds a point inside the first Brillouin zone of the lattice.

    The point is first wrapped into the central unit cell in fractional
    coordinates. It is then shifted by the closest of the 27 lattice vectors
    with integer coefficients in {-1, 0, 1}.

    Args:
        p: coordinates of one point
        lattice: Lattice object used to convert from reciprocal to cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.

    Returns:
        The cartesian coordinates folded inside the first Brillouin zone
    """
    frac = lattice.get_fractional_coords(p) if coords_are_cartesian else np.array(p)

    # Wrap every fractional component into about [-0.5, 0.5]. Because of the
    # 1e-10 shifts, a component of exactly +/-0.5 ends up at about +0.5.
    frac = np.mod(frac + 0.5 - 1e-10, 1) - 0.5 + 1e-10
    cart = lattice.get_cartesian_coords(frac)

    # Candidate lattice points, enumerated in the same (i, j, k) order as a
    # triple nested loop over (-1, 0, 1).
    candidates = [np.dot(shift, lattice.matrix) for shift in itertools.product((-1, 0, 1), repeat=3)]
    distances = [np.linalg.norm(cart - candidate) for candidate in candidates]
    # min() keeps the first minimum encountered, so ties go to the earliest candidate.
    nearest = candidates[min(range(len(distances)), key=distances.__getitem__)]

    if np.allclose(nearest, (0, 0, 0)):
        return cart
    return cart - nearest
4291
4292
def plot_points(points, lattice=None, coords_are_cartesian=False, fold=False, ax=None, **kwargs):
    """
    Adds points to a 3D matplotlib Axes.

    Args:
        points: list of coordinates
        lattice: Lattice object used to convert from reciprocal to cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.
            Requires lattice if False.
        fold: whether the points should be folded inside the first Brillouin Zone.
            Defaults to False. Requires lattice if True.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'scatter'. Color defaults to blue.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if a lattice is needed (fractional input or folding) but
            none is given.
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "b")

    lattice_needed = fold or not coords_are_cartesian
    if lattice_needed and lattice is None:
        raise ValueError("coords_are_cartesian False or fold True require the lattice")

    for point in points:
        if fold:
            # fold_point also converts to cartesian coordinates.
            xyz = fold_point(point, lattice, coords_are_cartesian=coords_are_cartesian)
        elif coords_are_cartesian:
            xyz = point
        else:
            xyz = lattice.get_cartesian_coords(point)

        # Each point is added with its own scatter call.
        ax.scatter(*xyz, **kwargs)

    return fig, ax
4330
4331
@add_fig_kwargs
def plot_brillouin_zone_from_kpath(kpath, ax=None, **kwargs):
    """
    Gives the plot (as a matplotlib object) of the symmetry line path in
    the Brillouin Zone.

    Args:
        kpath (HighSymmKpath): a HighSymmKPath object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        **kwargs: provided by add_fig_kwargs decorator (figure options such as
            show/savefig). They are applied once, to the figure returned here.

    Returns:
        matplotlib figure
    """
    kpoints = kpath.kpath["kpoints"]
    # Each path is a sequence of label names; turn it into their coordinates.
    lines = [[kpoints[label] for label in path] for path in kpath.kpath["path"]]

    # plot_brillouin_zone is itself wrapped by add_fig_kwargs. This function's
    # own decorator has already consumed the caller's figure options, so
    # without an explicit show=False the inner wrapper would use its own
    # default. That default would display the figure even when the caller
    # asked for show=False, and display it twice otherwise.
    # NOTE(review): relies on add_fig_kwargs popping "show" (default True)
    # before calling the wrapped function; confirm in pymatgen.util.plotting.
    kwargs["show"] = False
    return plot_brillouin_zone(
        bz_lattice=kpath.prim_rec,
        lines=lines,
        ax=ax,
        labels=kpoints,
        **kwargs,
    )
4355
4356
@add_fig_kwargs
def plot_brillouin_zone(
    bz_lattice,
    lines=None,
    labels=None,
    kpoints=None,
    fold=False,
    coords_are_cartesian=False,
    ax=None,
    **kwargs,
):
    """
    Plots a 3D representation of the Brillouin zone of the structure.
    Optionally adds paths, labels and kpoints to the plot.

    Args:
        bz_lattice: Lattice object of the Brillouin zone (always required).
        lines: list of lists of coordinates. Each list represents a different path.
        labels: dict containing the label as a key and the coordinates as value.
            A marker is also drawn at each label position (never folded).
        kpoints: list of coordinates
        fold: whether the kpoints should be folded inside the first Brillouin Zone
            (using bz_lattice). Defaults to False.
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: provided by add_fig_kwargs decorator

    Returns:
        matplotlib figure
    """
    fig, ax = plot_lattice_vectors(bz_lattice, ax=ax)
    plot_wigner_seitz(bz_lattice, ax=ax)

    # Options shared by every overlay drawn on the same axes.
    common = {"coords_are_cartesian": coords_are_cartesian, "ax": ax}

    if lines is not None:
        for path in lines:
            plot_path(path, bz_lattice, **common)

    if labels is not None:
        plot_labels(labels, bz_lattice, **common)
        plot_points(labels.values(), bz_lattice, fold=False, **common)

    if kpoints is not None:
        plot_points(kpoints, bz_lattice, fold=fold, **common)

    # Fixed symmetric view box around the origin.
    for set_limits in (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d):
        set_limits(-1, 1)

    ax.axis("off")

    return fig
4421
4422
4423def plot_ellipsoid(
4424    hessian,
4425    center,
4426    lattice=None,
4427    rescale=1.0,
4428    ax=None,
4429    coords_are_cartesian=False,
4430    arrows=False,
4431    **kwargs,
4432):
4433    """
4434    Plots a 3D ellipsoid rappresenting the Hessian matrix in input.
4435    Useful to get a graphical visualization of the effective mass
4436    of a band in a single k-point.
4437
4438    Args:
4439        hessian: the Hessian matrix
4440        center: the center of the ellipsoid in reciprocal coords (Default)
4441        lattice: Lattice object of the Brillouin zone
4442        rescale: factor for size scaling of the ellipsoid
4443        ax: matplotlib :class:`Axes` or None if a new figure should be created.
4444        coords_are_cartesian: Set to True if you are providing a center in
4445                              cartesian coordinates. Defaults to False.
4446        kwargs: kwargs passed to the matplotlib function 'plot_wireframe'.
4447                Color defaults to blue, rstride and cstride
4448                default to 4, alpha defaults to 0.2.
4449    Returns:
4450        matplotlib figure and matplotlib ax
4451    Example of use:
4452        fig,ax=plot_wigner_seitz(struct.reciprocal_lattice)
4453        plot_ellipsoid(hessian,[0.0,0.0,0.0], struct.reciprocal_lattice,ax=ax)
4454    """
4455
4456    if (not coords_are_cartesian) and lattice is None:
4457        raise ValueError("coords_are_cartesian False or fold True require the lattice")
4458
4459    if not coords_are_cartesian:
4460        center = lattice.get_cartesian_coords(center)
4461
4462    if "color" not in kwargs:
4463        kwargs["color"] = "b"
4464    if "rstride" not in kwargs:
4465        kwargs["rstride"] = 4
4466    if "cstride" not in kwargs:
4467        kwargs["cstride"] = 4
4468    if "alpha" not in kwargs:
4469        kwargs["alpha"] = 0.2
4470
4471    # calculate the ellipsoid
4472    # find the rotation matrix and radii of the axes
4473    U, s, rotation = np.linalg.svd(hessian)
4474    radii = 1.0 / np.sqrt(s)
4475
4476    # from polar coordinates
4477    u = np.linspace(0.0, 2.0 * np.pi, 100)
4478    v = np.linspace(0.0, np.pi, 100)
4479    x = radii[0] * np.outer(np.cos(u), np.sin(v))
4480    y = radii[1] * np.outer(np.sin(u), np.sin(v))
4481    z = radii[2] * np.outer(np.ones_like(u), np.cos(v))
4482    for i in range(len(x)):
4483        for j in range(len(x)):
4484            [x[i, j], y[i, j], z[i, j]] = np.dot([x[i, j], y[i, j], z[i, j]], rotation) * rescale + center
4485
4486    # add the ellipsoid to the current axes
4487    ax, fig, plt = get_ax3d_fig_plt(ax)
4488    ax.plot_wireframe(x, y, z, **kwargs)
4489
4490    if arrows:
4491        color = ("b", "g", "r")
4492        em = np.zeros((3, 3))
4493        for i in range(3):
4494            em[i, :] = rotation[i, :] / np.linalg.norm(rotation[i, :])
4495        for i in range(3):
4496            ax.quiver3D(
4497                center[0],
4498                center[1],
4499                center[2],
4500                em[i, 0],
4501                em[i, 1],
4502                em[i, 2],
4503                pivot="tail",
4504                arrow_length_ratio=0.2,
4505                length=radii[i] * rescale,
4506                color=color[i],
4507            )
4508
4509    return fig, ax
4510