1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
6This module provides classes to define everything related to band structures.
7"""
8
9import collections
10import itertools
11import math
12import re
13import warnings
14
15import numpy as np
16from monty.json import MSONable
17
18from pymatgen.core.lattice import Lattice
19from pymatgen.core.periodic_table import Element, get_el_sp
20from pymatgen.core.structure import Structure
21from pymatgen.electronic_structure.core import Orbital, Spin
22from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
23from pymatgen.util.coord import pbc_diff
24
25__author__ = "Geoffroy Hautier, Shyue Ping Ong, Michael Kocher"
26__copyright__ = "Copyright 2012, The Materials Project"
27__version__ = "1.0"
28__maintainer__ = "Geoffroy Hautier"
29__email__ = "geoffroy@uclouvain.be"
30__status__ = "Development"
31__date__ = "March 14, 2012"
32
33
class Kpoint(MSONable):
    """
    Class to store kpoint objects. A kpoint is defined by a lattice and
    fractional or cartesian coordinates, with a syntax similar to the site
    object in pymatgen.core.structure.
    """

    def __init__(
        self,
        coords,
        lattice,
        to_unit_cell=False,
        coords_are_cartesian=False,
        label=None,
    ):
        """
        Args:
            coords: coordinate of the kpoint as a numpy array (or any
                array-like). The input is copied, so it is never modified
                and later changes to it do not affect the kpoint.
            lattice: A pymatgen.core.lattice.Lattice lattice object representing
                the reciprocal lattice of the kpoint
            to_unit_cell: Translates fractional coordinate to the basic unit
                cell, i.e., all fractional coordinates satisfy 0 <= a < 1.
                Defaults to False.
            coords_are_cartesian: Boolean indicating if the coordinates given are
                in cartesian or fractional coordinates (by default fractional)
            label: the label of the kpoint if any (None by default)
        """
        self._lattice = lattice
        if coords_are_cartesian:
            fcoords = lattice.get_fractional_coords(coords)
        else:
            fcoords = coords
        # Keep a private copy: the wrapping below works in place and must not
        # write through to the array object owned by the caller.
        self._fcoords = np.array(fcoords)
        self._label = label

        if to_unit_cell:
            for i, fc in enumerate(self._fcoords):
                self._fcoords[i] -= math.floor(fc)

        self._ccoords = lattice.get_cartesian_coords(self._fcoords)

    @property
    def lattice(self):
        """
        The lattice associated with the kpoint. It's a
        pymatgen.core.lattice.Lattice object
        """
        return self._lattice

    @property
    def label(self):
        """
        The label associated with the kpoint
        """
        return self._label

    @property
    def frac_coords(self):
        """
        The fractional coordinates of the kpoint as a numpy array (a copy)
        """
        return np.copy(self._fcoords)

    @property
    def cart_coords(self):
        """
        The cartesian coordinates of the kpoint as a numpy array (a copy)
        """
        return np.copy(self._ccoords)

    @property
    def a(self):
        """
        Fractional a coordinate of the kpoint
        """
        return self._fcoords[0]

    @property
    def b(self):
        """
        Fractional b coordinate of the kpoint
        """
        return self._fcoords[1]

    @property
    def c(self):
        """
        Fractional c coordinate of the kpoint
        """
        return self._fcoords[2]

    def __str__(self):
        """
        Returns a string with fractional, cartesian coordinates and label
        """
        return "{} {} {}".format(self.frac_coords, self.cart_coords, self.label)

    def as_dict(self):
        """
        Json-serializable dict representation of a kpoint
        """
        return {
            "lattice": self.lattice.as_dict(),
            "fcoords": self.frac_coords.tolist(),
            "ccoords": self.cart_coords.tolist(),
            "label": self.label,
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
        }

    @classmethod
    def from_dict(cls, d):
        """
        Create from dict.

        Args:
            d: A dict with all data for a kpoint object. The keys read are
                "fcoords", "lattice" and "label"; the kpoint is rebuilt
                from the fractional coordinates.

        Returns:
            A Kpoint object
        """

        return cls(
            coords=d["fcoords"],
            lattice=Lattice.from_dict(d["lattice"]),
            coords_are_cartesian=False,
            label=d["label"],
        )
158
159
160class BandStructure:
161    """
162    This is the most generic band structure data possible
163    it's defined by a list of kpoints + energies for each of them
164
165    .. attribute:: kpoints:
166        the list of kpoints (as Kpoint objects) in the band structure
167
168    .. attribute:: lattice_rec
169
170        the reciprocal lattice of the band structure.
171
172    .. attribute:: efermi
173
174        the fermi energy
175
176    .. attribute::  is_spin_polarized
177
178        True if the band structure is spin-polarized, False otherwise
179
180    .. attribute:: bands
181
182        The energy eigenvalues as a {spin: ndarray}. Note that the use of an
183        ndarray is necessary for computational as well as memory efficiency
184        due to the large amount of numerical data. The indices of the ndarray
185        are [band_index, kpoint_index].
186
187    .. attribute:: nb_bands
188
189        returns the number of bands in the band structure
190
191    .. attribute:: structure
192
193        returns the structure
194
195    .. attribute:: projections
196
197        The projections as a {spin: ndarray}. Note that the use of an
198        ndarray is necessary for computational as well as memory efficiency
199        due to the large amount of numerical data. The indices of the ndarray
200        are [band_index, kpoint_index, orbital_index, ion_index].
201    """
202
203    def __init__(
204        self,
205        kpoints,
206        eigenvals,
207        lattice,
208        efermi,
209        labels_dict=None,
210        coords_are_cartesian=False,
211        structure=None,
212        projections=None,
213    ):
214        """
215        Args:
216            kpoints: list of kpoint as numpy arrays, in frac_coords of the
217                given lattice by default
218            eigenvals: dict of energies for spin up and spin down
219                {Spin.up:[][],Spin.down:[][]}, the first index of the array
220                [][] refers to the band and the second to the index of the
221                kpoint. The kpoints are ordered according to the order of the
222                kpoints array. If the band structure is not spin polarized, we
223                only store one data set under Spin.up
224            lattice: The reciprocal lattice as a pymatgen Lattice object.
225                Pymatgen uses the physics convention of reciprocal lattice vectors
226                WITH a 2*pi coefficient
227            efermi: fermi energy
228            labels_dict: (dict) of {} this links a kpoint (in frac coords or
229                cartesian coordinates depending on the coords) to a label.
230            coords_are_cartesian: Whether coordinates are cartesian.
231            structure: The crystal structure (as a pymatgen Structure object)
232                associated with the band structure. This is needed if we
233                provide projections to the band structure
234            projections: dict of orbital projections as {spin: ndarray}. The
235                indices of the ndarrayare [band_index, kpoint_index, orbital_index,
236                ion_index].If the band structure is not spin polarized, we only
237                store one data set under Spin.up.
238        """
239        self.efermi = efermi
240        self.lattice_rec = lattice
241        self.kpoints = []
242        self.labels_dict = {}
243        self.structure = structure
244        self.projections = projections or {}
245        self.projections = {k: np.array(v) for k, v in self.projections.items()}
246
247        if labels_dict is None:
248            labels_dict = {}
249
250        if len(self.projections) != 0 and self.structure is None:
251            raise Exception("if projections are provided a structure object" " needs also to be given")
252
253        for k in kpoints:
254            # let see if this kpoint has been assigned a label
255            label = None
256            for c in labels_dict:
257                if np.linalg.norm(k - np.array(labels_dict[c])) < 0.0001:
258                    label = c
259                    self.labels_dict[label] = Kpoint(
260                        k,
261                        lattice,
262                        label=label,
263                        coords_are_cartesian=coords_are_cartesian,
264                    )
265            self.kpoints.append(Kpoint(k, lattice, label=label, coords_are_cartesian=coords_are_cartesian))
266        self.bands = {spin: np.array(v) for spin, v in eigenvals.items()}
267        self.nb_bands = len(eigenvals[Spin.up])
268        self.is_spin_polarized = len(self.bands) == 2
269
270    def get_projection_on_elements(self):
271        """
272        Method returning a dictionary of projections on elements.
273
274        Returns:
275            a dictionary in the {Spin.up:[][{Element:values}],
276            Spin.down:[][{Element:values}]} format
277            if there is no projections in the band structure
278            returns an empty dict
279        """
280        result = {}
281        structure = self.structure
282        for spin, v in self.projections.items():
283            result[spin] = [
284                [collections.defaultdict(float) for i in range(len(self.kpoints))] for j in range(self.nb_bands)
285            ]
286            for i, j, k in itertools.product(
287                range(self.nb_bands),
288                range(len(self.kpoints)),
289                range(structure.num_sites),
290            ):
291                result[spin][i][j][str(structure[k].specie)] += np.sum(v[i, j, :, k])
292        return result
293
294    def get_projections_on_elements_and_orbitals(self, el_orb_spec):
295        """
296        Method returning a dictionary of projections on elements and specific
297        orbitals
298
299        Args:
300            el_orb_spec: A dictionary of Elements and Orbitals for which we want
301                to have projections on. It is given as: {Element:[orbitals]},
302                e.g., {'Cu':['d','s']}
303
304        Returns:
305            A dictionary of projections on elements in the
306            {Spin.up:[][{Element:{orb:values}}],
307            Spin.down:[][{Element:{orb:values}}]} format
308            if there is no projections in the band structure returns an empty
309            dict.
310        """
311        result = {}
312        structure = self.structure
313        el_orb_spec = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()}
314        for spin, v in self.projections.items():
315            result[spin] = [
316                [{str(e): collections.defaultdict(float) for e in el_orb_spec} for i in range(len(self.kpoints))]
317                for j in range(self.nb_bands)
318            ]
319
320            for i, j, k in itertools.product(
321                range(self.nb_bands),
322                range(len(self.kpoints)),
323                range(structure.num_sites),
324            ):
325                sp = structure[k].specie
326                for orb_i in range(len(v[i][j])):
327                    o = Orbital(orb_i).name[0]
328                    if sp in el_orb_spec:
329                        if o in el_orb_spec[sp]:
330                            result[spin][i][j][str(sp)][o] += v[i][j][orb_i][k]
331        return result
332
333    def is_metal(self, efermi_tol=1e-4):
334        """
335        Check if the band structure indicates a metal by looking if the fermi
336        level crosses a band.
337
338        Returns:
339            True if a metal, False if not
340        """
341        for spin, values in self.bands.items():
342            for i in range(self.nb_bands):
343                if np.any(values[i, :] - self.efermi < -efermi_tol) and np.any(values[i, :] - self.efermi > efermi_tol):
344                    return True
345        return False
346
347    def get_vbm(self):
348        """
349        Returns data about the VBM.
350
351        Returns:
352            dict as {"band_index","kpoint_index","kpoint","energy"}
353            - "band_index": A dict with spin keys pointing to a list of the
354            indices of the band containing the VBM (please note that you
355            can have several bands sharing the VBM) {Spin.up:[],
356            Spin.down:[]}
357            - "kpoint_index": The list of indices in self.kpoints for the
358            kpoint VBM. Please note that there can be several
359            kpoint_indices relating to the same kpoint (e.g., Gamma can
360            occur at different spots in the band structure line plot)
361            - "kpoint": The kpoint (as a kpoint object)
362            - "energy": The energy of the VBM
363            - "projections": The projections along sites and orbitals of the
364            VBM if any projection data is available (else it is an empty
365            dictionnary). The format is similar to the projections field in
366            BandStructure: {spin:{'Orbital': [proj]}} where the array
367            [proj] is ordered according to the sites in structure
368        """
369        if self.is_metal():
370            return {
371                "band_index": [],
372                "kpoint_index": [],
373                "kpoint": [],
374                "energy": None,
375                "projections": {},
376            }
377        max_tmp = -float("inf")
378        index = None
379        kpointvbm = None
380        for spin, v in self.bands.items():
381            for i, j in zip(*np.where(v < self.efermi)):
382                if v[i, j] > max_tmp:
383                    max_tmp = float(v[i, j])
384                    index = j
385                    kpointvbm = self.kpoints[j]
386
387        list_ind_kpts = []
388        if kpointvbm.label is not None:
389            for i, kpt in enumerate(self.kpoints):
390                if kpt.label == kpointvbm.label:
391                    list_ind_kpts.append(i)
392        else:
393            list_ind_kpts.append(index)
394        # get all other bands sharing the vbm
395        list_ind_band = collections.defaultdict(list)
396        for spin in self.bands:
397            for i in range(self.nb_bands):
398                if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001:
399                    list_ind_band[spin].append(i)
400        proj = {}
401        for spin, v in self.projections.items():
402            if len(list_ind_band[spin]) == 0:
403                continue
404            proj[spin] = v[list_ind_band[spin][0]][list_ind_kpts[0]]
405        return {
406            "band_index": list_ind_band,
407            "kpoint_index": list_ind_kpts,
408            "kpoint": kpointvbm,
409            "energy": max_tmp,
410            "projections": proj,
411        }
412
413    def get_cbm(self):
414        """
415        Returns data about the CBM.
416
417        Returns:
418            {"band_index","kpoint_index","kpoint","energy"}
419            - "band_index": A dict with spin keys pointing to a list of the
420            indices of the band containing the CBM (please note that you
421            can have several bands sharing the CBM) {Spin.up:[],
422            Spin.down:[]}
423            - "kpoint_index": The list of indices in self.kpoints for the
424            kpoint CBM. Please note that there can be several
425            kpoint_indices relating to the same kpoint (e.g., Gamma can
426            occur at different spots in the band structure line plot)
427            - "kpoint": The kpoint (as a kpoint object)
428            - "energy": The energy of the CBM
429            - "projections": The projections along sites and orbitals of the
430            CBM if any projection data is available (else it is an empty
431            dictionnary). The format is similar to the projections field in
432            BandStructure: {spin:{'Orbital': [proj]}} where the array
433            [proj] is ordered according to the sites in structure
434        """
435        if self.is_metal():
436            return {
437                "band_index": [],
438                "kpoint_index": [],
439                "kpoint": [],
440                "energy": None,
441                "projections": {},
442            }
443        max_tmp = float("inf")
444
445        index = None
446        kpointcbm = None
447        for spin, v in self.bands.items():
448            for i, j in zip(*np.where(v >= self.efermi)):
449                if v[i, j] < max_tmp:
450                    max_tmp = float(v[i, j])
451                    index = j
452                    kpointcbm = self.kpoints[j]
453
454        list_index_kpoints = []
455        if kpointcbm.label is not None:
456            for i, kpt in enumerate(self.kpoints):
457                if kpt.label == kpointcbm.label:
458                    list_index_kpoints.append(i)
459        else:
460            list_index_kpoints.append(index)
461
462        # get all other bands sharing the cbm
463        list_index_band = collections.defaultdict(list)
464        for spin in self.bands:
465            for i in range(self.nb_bands):
466                if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001:
467                    list_index_band[spin].append(i)
468        proj = {}
469        for spin, v in self.projections.items():
470            if len(list_index_band[spin]) == 0:
471                continue
472            proj[spin] = v[list_index_band[spin][0]][list_index_kpoints[0]]
473
474        return {
475            "band_index": list_index_band,
476            "kpoint_index": list_index_kpoints,
477            "kpoint": kpointcbm,
478            "energy": max_tmp,
479            "projections": proj,
480        }
481
482    def get_band_gap(self):
483        r"""
484        Returns band gap data.
485
486        Returns:
487            A dict {"energy","direct","transition"}:
488            "energy": band gap energy
489            "direct": A boolean telling if the gap is direct or not
490            "transition": kpoint labels of the transition (e.g., "\\Gamma-X")
491        """
492        if self.is_metal():
493            return {"energy": 0.0, "direct": False, "transition": None}
494        cbm = self.get_cbm()
495        vbm = self.get_vbm()
496        result = dict(direct=False, energy=0.0, transition=None)
497
498        result["energy"] = cbm["energy"] - vbm["energy"]
499
500        if (cbm["kpoint"].label is not None and cbm["kpoint"].label == vbm["kpoint"].label) or np.linalg.norm(
501            cbm["kpoint"].cart_coords - vbm["kpoint"].cart_coords
502        ) < 0.01:
503            result["direct"] = True
504
505        result["transition"] = "-".join(
506            [
507                str(c.label)
508                if c.label is not None
509                else str("(") + ",".join(["{0:.3f}".format(c.frac_coords[i]) for i in range(3)]) + str(")")
510                for c in [vbm["kpoint"], cbm["kpoint"]]
511            ]
512        )
513
514        return result
515
516    def get_direct_band_gap_dict(self):
517        """
518        Returns a dictionary of information about the direct
519        band gap
520
521        Returns:
522            a dictionary of the band gaps indexed by spin
523            along with their band indices and k-point index
524        """
525        if self.is_metal():
526            raise ValueError("get_direct_band_gap_dict should only be used with non-metals")
527        direct_gap_dict = {}
528        for spin, v in self.bands.items():
529            above = v[np.all(v > self.efermi, axis=1)]
530            min_above = np.min(above, axis=0)
531            below = v[np.all(v < self.efermi, axis=1)]
532            max_below = np.max(below, axis=0)
533            diff = min_above - max_below
534            kpoint_index = np.argmin(diff)
535            band_indices = [
536                np.argmax(below[:, kpoint_index]),
537                np.argmin(above[:, kpoint_index]) + len(below),
538            ]
539            direct_gap_dict[spin] = {
540                "value": diff[kpoint_index],
541                "kpoint_index": kpoint_index,
542                "band_indices": band_indices,
543            }
544        return direct_gap_dict
545
546    def get_direct_band_gap(self):
547        """
548        Returns the direct band gap.
549
550        Returns:
551             the value of the direct band gap
552        """
553        if self.is_metal():
554            return 0.0
555        dg = self.get_direct_band_gap_dict()
556        return min(v["value"] for v in dg.values())
557
558    def get_sym_eq_kpoints(self, kpoint, cartesian=False, tol=1e-2):
559        """
560        Returns a list of unique symmetrically equivalent k-points.
561
562        Args:
563            kpoint (1x3 array): coordinate of the k-point
564            cartesian (bool): kpoint is in cartesian or fractional coordinates
565            tol (float): tolerance below which coordinates are considered equal
566
567        Returns:
568            ([1x3 array] or None): if structure is not available returns None
569        """
570        if not self.structure:
571            return None
572        sg = SpacegroupAnalyzer(self.structure)
573        symmops = sg.get_point_group_operations(cartesian=cartesian)
574        points = np.dot(kpoint, [m.rotation_matrix for m in symmops])
575        rm_list = []
576        # identify and remove duplicates from the list of equivalent k-points:
577        for i in range(len(points) - 1):
578            for j in range(i + 1, len(points)):
579                if np.allclose(pbc_diff(points[i], points[j]), [0, 0, 0], tol):
580                    rm_list.append(i)
581                    break
582        return np.delete(points, rm_list, axis=0)
583
584    def get_kpoint_degeneracy(self, kpoint, cartesian=False, tol=1e-2):
585        """
586        Returns degeneracy of a given k-point based on structure symmetry
587        Args:
588            kpoint (1x3 array): coordinate of the k-point
589            cartesian (bool): kpoint is in cartesian or fractional coordinates
590            tol (float): tolerance below which coordinates are considered equal
591
592        Returns:
593            (int or None): degeneracy or None if structure is not available
594        """
595        all_kpts = self.get_sym_eq_kpoints(kpoint, cartesian, tol=tol)
596        if all_kpts is not None:
597            return len(all_kpts)
598        return None
599
600    def as_dict(self):
601        """
602        Json-serializable dict representation of BandStructure.
603        """
604        d = {
605            "@module": self.__class__.__module__,
606            "@class": self.__class__.__name__,
607            "lattice_rec": self.lattice_rec.as_dict(),
608            "efermi": self.efermi,
609            "kpoints": [],
610        }
611        # kpoints are not kpoint objects dicts but are frac coords (this makes
612        # the dict smaller and avoids the repetition of the lattice
613        for k in self.kpoints:
614            d["kpoints"].append(k.as_dict()["fcoords"])
615
616        d["bands"] = {str(int(spin)): self.bands[spin].tolist() for spin in self.bands}
617        d["is_metal"] = self.is_metal()
618        vbm = self.get_vbm()
619        d["vbm"] = {
620            "energy": vbm["energy"],
621            "kpoint_index": vbm["kpoint_index"],
622            "band_index": {str(int(spin)): vbm["band_index"][spin] for spin in vbm["band_index"]},
623            "projections": {str(spin): v.tolist() for spin, v in vbm["projections"].items()},
624        }
625        cbm = self.get_cbm()
626        d["cbm"] = {
627            "energy": cbm["energy"],
628            "kpoint_index": cbm["kpoint_index"],
629            "band_index": {str(int(spin)): cbm["band_index"][spin] for spin in cbm["band_index"]},
630            "projections": {str(spin): v.tolist() for spin, v in cbm["projections"].items()},
631        }
632        d["band_gap"] = self.get_band_gap()
633        d["labels_dict"] = {}
634        d["is_spin_polarized"] = self.is_spin_polarized
635
636        # MongoDB does not accept keys starting with $. Add a blanck space to fix the problem
637        for c, label in self.labels_dict.items():
638            mongo_key = c if not c.startswith("$") else " " + c
639            d["labels_dict"][mongo_key] = label.as_dict()["fcoords"]
640        d["projections"] = {}
641        if len(self.projections) != 0:
642            d["structure"] = self.structure.as_dict()
643            d["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()}
644        return d
645
646    @classmethod
647    def from_dict(cls, d):
648        """
649        Create from dict.
650
651        Args:
652            A dict with all data for a band structure object.
653
654        Returns:
655            A BandStructure object
656        """
657        # Strip the label to recover initial string
658        # (see trick used in as_dict to handle $ chars)
659        labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()}
660        projections = {}
661        structure = None
662        if isinstance(list(d["bands"].values())[0], dict):
663            eigenvals = {Spin(int(k)): np.array(d["bands"][k]["data"]) for k in d["bands"]}
664        else:
665            eigenvals = {Spin(int(k)): d["bands"][k] for k in d["bands"]}
666
667        if "structure" in d:
668            structure = Structure.from_dict(d["structure"])
669
670        try:
671            if d.get("projections"):
672                if isinstance(d["projections"]["1"][0][0], dict):
673                    raise ValueError("Old band structure dict format detected!")
674                projections = {Spin(int(spin)): np.array(v) for spin, v in d["projections"].items()}
675
676            return cls(
677                d["kpoints"],
678                eigenvals,
679                Lattice(d["lattice_rec"]["matrix"]),
680                d["efermi"],
681                labels_dict,
682                structure=structure,
683                projections=projections,
684            )
685
686        except Exception:
687            warnings.warn(
688                "Trying from_dict failed. Now we are trying the old "
689                "format. Please convert your BS dicts to the new "
690                "format. The old format will be retired in pymatgen "
691                "5.0."
692            )
693            return cls.from_old_dict(d)
694
695    @classmethod
696    def from_old_dict(cls, d):
697        """
698        Args:
699            d (dict): A dict with all data for a band structure symm line
700                object.
701        Returns:
702            A BandStructureSymmLine object
703        """
704        # Strip the label to recover initial string (see trick used in as_dict to handle $ chars)
705        labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()}
706        projections = {}
707        structure = None
708        if "projections" in d and len(d["projections"]) != 0:
709            structure = Structure.from_dict(d["structure"])
710            projections = {}
711            for spin in d["projections"]:
712                dd = []
713                for i in range(len(d["projections"][spin])):
714                    ddd = []
715                    for j in range(len(d["projections"][spin][i])):
716                        dddd = []
717                        for k in range(len(d["projections"][spin][i][j])):
718                            ddddd = []
719                            orb = Orbital(k).name
720                            for l in range(len(d["projections"][spin][i][j][orb])):
721                                ddddd.append(d["projections"][spin][i][j][orb][l])
722                            dddd.append(np.array(ddddd))
723                        ddd.append(np.array(dddd))
724                    dd.append(np.array(ddd))
725                projections[Spin(int(spin))] = np.array(dd)
726
727        return BandStructure(
728            d["kpoints"],
729            {Spin(int(k)): d["bands"][k] for k in d["bands"]},
730            Lattice(d["lattice_rec"]["matrix"]),
731            d["efermi"],
732            labels_dict,
733            structure=structure,
734            projections=projections,
735        )
736
737
738class BandStructureSymmLine(BandStructure, MSONable):
739    r"""
740    This object stores band structures along selected (symmetry) lines in the
741    Brillouin zone. We call the different symmetry lines (ex: \\Gamma to Z)
742    "branches".
743    """
744
745    def __init__(
746        self,
747        kpoints,
748        eigenvals,
749        lattice,
750        efermi,
751        labels_dict,
752        coords_are_cartesian=False,
753        structure=None,
754        projections=None,
755    ):
756        """
757        Args:
758            kpoints: list of kpoint as numpy arrays, in frac_coords of the
759                given lattice by default
760            eigenvals: dict of energies for spin up and spin down
761                {Spin.up:[][],Spin.down:[][]}, the first index of the array
762                [][] refers to the band and the second to the index of the
763                kpoint. The kpoints are ordered according to the order of the
764                kpoints array. If the band structure is not spin polarized, we
765                only store one data set under Spin.up.
766            lattice: The reciprocal lattice.
767                Pymatgen uses the physics convention of reciprocal lattice vectors
768                WITH a 2*pi coefficient
769            efermi: fermi energy
770            label_dict: (dict) of {} this link a kpoint (in frac coords or
771                cartesian coordinates depending on the coords).
772            coords_are_cartesian: Whether coordinates are cartesian.
773            structure: The crystal structure (as a pymatgen Structure object)
774                associated with the band structure. This is needed if we
775                provide projections to the band structure.
776            projections: dict of orbital projections as {spin: ndarray}. The
777                indices of the ndarrayare [band_index, kpoint_index, orbital_index,
778                ion_index].If the band structure is not spin polarized, we only
779                store one data set under Spin.up.
780        """
781        super().__init__(
782            kpoints,
783            eigenvals,
784            lattice,
785            efermi,
786            labels_dict,
787            coords_are_cartesian,
788            structure,
789            projections,
790        )
791        self.distance = []
792        self.branches = []
793        one_group = []
794        branches_tmp = []
795        # get labels and distance for each kpoint
796        previous_kpoint = self.kpoints[0]
797        previous_distance = 0.0
798
799        previous_label = self.kpoints[0].label
800        for i, kpt in enumerate(self.kpoints):
801            label = kpt.label
802            if label is not None and previous_label is not None:
803                self.distance.append(previous_distance)
804            else:
805                self.distance.append(np.linalg.norm(kpt.cart_coords - previous_kpoint.cart_coords) + previous_distance)
806            previous_kpoint = kpt
807            previous_distance = self.distance[i]
808            if label:
809                if previous_label:
810                    if len(one_group) != 0:
811                        branches_tmp.append(one_group)
812                    one_group = []
813            previous_label = label
814            one_group.append(i)
815
816        if len(one_group) != 0:
817            branches_tmp.append(one_group)
818        for b in branches_tmp:
819            self.branches.append(
820                {
821                    "start_index": b[0],
822                    "end_index": b[-1],
823                    "name": str(self.kpoints[b[0]].label) + "-" + str(self.kpoints[b[-1]].label),
824                }
825            )
826
827        self.is_spin_polarized = False
828        if len(self.bands) == 2:
829            self.is_spin_polarized = True
830
831    def get_equivalent_kpoints(self, index):
832        """
833        Returns the list of kpoint indices equivalent (meaning they are the
834        same frac coords) to the given one.
835
836        Args:
837            index: the kpoint index
838
839        Returns:
840            a list of equivalent indices
841
842        TODO: now it uses the label we might want to use coordinates instead
843        (in case there was a mislabel)
844        """
845        # if the kpoint has no label it can"t have a repetition along the band
846        # structure line object
847
848        if self.kpoints[index].label is None:
849            return [index]
850
851        list_index_kpoints = []
852        for i, kpt in enumerate(self.kpoints):
853            if kpt.label == self.kpoints[index].label:
854                list_index_kpoints.append(i)
855
856        return list_index_kpoints
857
858    def get_branch(self, index):
859        r"""
860        Returns in what branch(es) is the kpoint. There can be several
861        branches.
862
863        Args:
864            index: the kpoint index
865
866        Returns:
867            A list of dictionaries [{"name","start_index","end_index","index"}]
868            indicating all branches in which the k_point is. It takes into
869            account the fact that one kpoint (e.g., \\Gamma) can be in several
870            branches
871        """
872        to_return = []
873        for i in self.get_equivalent_kpoints(index):
874            for b in self.branches:
875                if b["start_index"] <= i <= b["end_index"]:
876                    to_return.append(
877                        {
878                            "name": b["name"],
879                            "start_index": b["start_index"],
880                            "end_index": b["end_index"],
881                            "index": i,
882                        }
883                    )
884        return to_return
885
886    def apply_scissor(self, new_band_gap):
887        """
888        Apply a scissor operator (shift of the CBM) to fit the given band gap.
889        If it's a metal. We look for the band crossing the fermi level
890        and shift this one up. This will not work all the time for metals!
891
892        Args:
893            new_band_gap: the band gap the scissor band structure need to have.
894
895        Returns:
896            a BandStructureSymmLine object with the applied scissor shift
897        """
898        if self.is_metal():
899            # moves then the highest index band crossing the fermi level
900            # find this band...
901            max_index = -1000
902            # spin_index = None
903            for i in range(self.nb_bands):
904                below = False
905                above = False
906                for j in range(len(self.kpoints)):
907                    if self.bands[Spin.up][i][j] < self.efermi:
908                        below = True
909                    if self.bands[Spin.up][i][j] > self.efermi:
910                        above = True
911                if above and below:
912                    if i > max_index:
913                        max_index = i
914                        # spin_index = Spin.up
915                if self.is_spin_polarized:
916                    below = False
917                    above = False
918                    for j in range(len(self.kpoints)):
919                        if self.bands[Spin.down][i][j] < self.efermi:
920                            below = True
921                        if self.bands[Spin.down][i][j] > self.efermi:
922                            above = True
923                    if above and below:
924                        if i > max_index:
925                            max_index = i
926                            # spin_index = Spin.down
927            old_dict = self.as_dict()
928            shift = new_band_gap
929            for spin in old_dict["bands"]:
930                for k in range(len(old_dict["bands"][spin])):
931                    for v in range(len(old_dict["bands"][spin][k])):
932                        if k >= max_index:
933                            old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift
934        else:
935
936            shift = new_band_gap - self.get_band_gap()["energy"]
937            old_dict = self.as_dict()
938            for spin in old_dict["bands"]:
939                for k in range(len(old_dict["bands"][spin])):
940                    for v in range(len(old_dict["bands"][spin][k])):
941                        if old_dict["bands"][spin][k][v] >= old_dict["cbm"]["energy"]:
942                            old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift
943            old_dict["efermi"] = old_dict["efermi"] + shift
944        return self.from_dict(old_dict)
945
946    def as_dict(self):
947        """
948        Json-serializable dict representation of BandStructureSymmLine.
949        """
950        d = super().as_dict()
951        d["branches"] = self.branches
952        return d
953
954
class LobsterBandStructureSymmLine(BandStructureSymmLine):
    """
    Lobster subclass of BandStructureSymmLine with its own dict
    (de)serialization and projection accessors.
    """

    def as_dict(self):
        """
        Json-serializable dict representation of LobsterBandStructureSymmLine.
        """

        def band_edge_as_dict(edge):
            # vbm and cbm share the same layout; spins become string keys.
            return {
                "energy": edge["energy"],
                "kpoint_index": [int(x) for x in edge["kpoint_index"]],
                "band_index": {str(int(spin)): edge["band_index"][spin] for spin in edge["band_index"]},
                "projections": {str(spin): v for spin, v in edge["projections"].items()},
            }

        d = {
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "lattice_rec": self.lattice_rec.as_dict(),
            "efermi": self.efermi,
            # kpoints are not kpoint objects dicts but are frac coords (this
            # makes the dict smaller and avoids the repetition of the lattice)
            "kpoints": [kpoint.as_dict()["fcoords"] for kpoint in self.kpoints],
            "branches": self.branches,
            "bands": {str(int(spin)): self.bands[spin].tolist() for spin in self.bands},
            "is_metal": self.is_metal(),
            "vbm": band_edge_as_dict(self.get_vbm()),
            "cbm": band_edge_as_dict(self.get_cbm()),
            "band_gap": self.get_band_gap(),
            # MongoDB does not accept keys starting with $. A blank space is
            # prepended to such labels to work around the problem.
            "labels_dict": {
                (" " + name if name.startswith("$") else name): kpoint.as_dict()["fcoords"]
                for name, kpoint in self.labels_dict.items()
            },
            "is_spin_polarized": self.is_spin_polarized,
        }
        # The structure is only stored together with projections.
        if len(self.projections) > 0:
            d["structure"] = self.structure.as_dict()
            d["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()}
        return d

    @classmethod
    def from_dict(cls, d):
        """
        Build the object from a dict in the current format; if anything goes
        wrong a warning is issued and the old dict format is tried instead.

        Args:
            d (dict): A dict with all data for a band structure symm line
                object.

        Returns:
            A LobsterBandStructureSymmLine object
        """
        try:
            # Strip the label to recover initial string (see trick used in as_dict to handle $ chars)
            stripped_labels = {name.strip(): coords for name, coords in d["labels_dict"].items()}
            structure = None
            projections = {}
            raw_projections = d.get("projections")
            if raw_projections:
                # Raising here hands control to the old-format fallback below.
                if isinstance(raw_projections["1"][0][0], dict):
                    raise ValueError("Old band structure dict format detected!")
                structure = Structure.from_dict(d["structure"])
                projections = {Spin(int(spin)): np.array(v) for spin, v in raw_projections.items()}

            eigenvals = {Spin(int(spin)): d["bands"][spin] for spin in d["bands"]}
            return LobsterBandStructureSymmLine(
                d["kpoints"],
                eigenvals,
                Lattice(d["lattice_rec"]["matrix"]),
                d["efermi"],
                stripped_labels,
                structure=structure,
                projections=projections,
            )
        except Exception:
            warnings.warn(
                "Trying from_dict failed. Now we are trying the old "
                "format. Please convert your BS dicts to the new "
                "format. The old format will be retired in pymatgen "
                "5.0."
            )
            return LobsterBandStructureSymmLine.from_old_dict(d)

    @classmethod
    def from_old_dict(cls, d):
        """
        Build the object from a dict in the old format.

        Args:
            d (dict): A dict with all data for a band structure symm line
                object.
        Returns:
            A LobsterBandStructureSymmLine object
        """
        # Strip the label to recover initial string (see trick used in as_dict to handle $ chars)
        stripped_labels = {name.strip(): coords for name, coords in d["labels_dict"].items()}
        structure = None
        projections = {}
        if "projections" in d and len(d["projections"]) != 0:
            structure = Structure.from_dict(d["structure"])
            for spin, per_band in d["projections"].items():
                # Each band becomes its own array before the per-spin array
                # is assembled.
                projections[Spin(int(spin))] = np.array(
                    [
                        np.array([per_band[band][entry] for entry in range(len(per_band[band]))])
                        for band in range(len(per_band))
                    ]
                )

        eigenvals = {Spin(int(spin)): d["bands"][spin] for spin in d["bands"]}
        return LobsterBandStructureSymmLine(
            d["kpoints"],
            eigenvals,
            Lattice(d["lattice_rec"]["matrix"]),
            d["efermi"],
            stripped_labels,
            structure=structure,
            projections=projections,
        )

    def get_projection_on_elements(self):
        """
        Method returning a dictionary of projections on elements.
        It sums over all available orbitals for each element.

        Returns:
            a dictionary in the {Spin.up:[][{Element:values}],
            Spin.down:[][{Element:values}]} format
            if there is no projections in the band structure
            returns an empty dict
        """
        result = {}
        for spin, spin_projections in self.projections.items():
            per_band = [
                [collections.defaultdict(float) for _ in range(len(self.kpoints))] for _ in range(self.nb_bands)
            ]
            result[spin] = per_band
            for band in range(self.nb_bands):
                for kpt in range(len(self.kpoints)):
                    for site_label, orbital_weights in spin_projections[band][kpt].items():
                        for weight in orbital_weights.values():
                            # The element symbol is the part of the key in
                            # front of the first run of digits.
                            specie = str(Element(re.split(r"[0-9]+", site_label)[0]))
                            per_band[band][kpt][specie] += weight
        return result

    def get_projections_on_elements_and_orbitals(self, el_orb_spec):
        """
        Method returning a dictionary of projections on elements and specific
        orbitals

        Args:
            el_orb_spec: A dictionary of Elements and Orbitals for which we want
                to have projections on. It is given as: {Element:[orbitals]},
                e.g., {'Si':['3s','3p']} or {'Si':['3s','3p_x', '3p_y', '3p_z']} depending on input files

        Returns:
            A dictionary of projections on elements in the
            {Spin.up:[][{Element:{orb:values}}],
            Spin.down:[][{Element:{orb:values}}]} format
            if there is no projections in the band structure returns an empty
            dict.
        """
        result = {}
        requested = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()}
        for spin, spin_projections in self.projections.items():
            per_band = [
                [{str(e): collections.defaultdict(float) for e in requested} for _ in range(len(self.kpoints))]
                for _ in range(self.nb_bands)
            ]
            result[spin] = per_band

            for band in range(self.nb_bands):
                for kpt in range(len(self.kpoints)):
                    for site_label, orbital_weights in spin_projections[band][kpt].items():
                        for orbital, weight in orbital_weights.items():
                            specie = str(Element(re.split(r"[0-9]+", site_label)[0]))
                            el = get_el_sp(specie)
                            # Only requested element/orbital pairs are summed.
                            if el in requested and orbital in requested[el]:
                                per_band[band][kpt][specie][orbital] += weight
        return result
1135
1136
def get_reconstructed_band_structure(list_bs, efermi=None):
    """
    Merge a list of band structures into a single band structure object by
    concatenating their kpoints, eigenvalues and projections.

    This is typically very useful when you split non self consistent
    band structure runs in several independent jobs and want to merge back
    the results

    The reciprocal lattice, the structure, the spin polarization and the
    presence of projections are taken from the first object of the list.
    Only the first n bands of each object are kept, n being the smallest
    number of bands found in the list.

    Args:
        list_bs: A list of BandStructure or BandStructureSymmLine objects.
        efermi: The Fermi energy of the reconstructed band structure. If
            None is assigned an average of all the Fermi energy in each
            object in the list_bs is used.

    Returns:
        A BandStructure or BandStructureSymmLine object (depending on
        the type of the list_bs objects)
    """
    if efermi is None:
        efermi = sum(bs.efermi for bs in list_bs) / len(list_bs)

    first = list_bs[0]
    rec_lattice = first.lattice_rec
    # Truncate to the number of bands common to all band structures.
    nb_bands = min(bs.nb_bands for bs in list_bs)

    kpoints = np.concatenate([[kpoint.frac_coords for kpoint in bs.kpoints] for bs in list_bs])
    # On duplicate labels the later band structure wins.
    labels_dict = {
        label: kpoint.frac_coords
        for per_bs_labels in [bs.labels_dict for bs in list_bs]
        for label, kpoint in per_bs_labels.items()
    }

    spins = [Spin.up, Spin.down] if first.is_spin_polarized else [Spin.up]

    # Bands are stacked along the kpoint axis (axis 1).
    eigenvals = {spin: np.concatenate([bs.bands[spin][:nb_bands] for bs in list_bs], axis=1) for spin in spins}

    projections = {}
    if len(first.projections) != 0:
        for spin in spins:
            projections[spin] = np.concatenate([bs.projections[spin][:nb_bands] for bs in list_bs], axis=1)

    bs_class = BandStructureSymmLine if isinstance(first, BandStructureSymmLine) else BandStructure
    return bs_class(
        kpoints,
        eigenvals,
        rec_lattice,
        efermi,
        labels_dict,
        structure=first.structure,
        projections=projections,
    )
1202