# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.

"""
This module provides classes to define everything related to band structures.
"""

import collections
import itertools
import math
import re
import warnings

import numpy as np
from monty.json import MSONable

from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import Element, get_el_sp
from pymatgen.core.structure import Structure
from pymatgen.electronic_structure.core import Orbital, Spin
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.coord import pbc_diff

__author__ = "Geoffroy Hautier, Shyue Ping Ong, Michael Kocher"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "1.0"
__maintainer__ = "Geoffroy Hautier"
__email__ = "geoffroy@uclouvain.be"
__status__ = "Development"
__date__ = "March 14, 2012"


class Kpoint(MSONable):
    """
    Class to store kpoint objects. A kpoint is defined with a lattice and
    fractional or cartesian coordinates, with a syntax similar to the site
    object in pymatgen.core.structure.
    """

    def __init__(
        self,
        coords,
        lattice,
        to_unit_cell=False,
        coords_are_cartesian=False,
        label=None,
    ):
        """
        Args:
            coords: coordinate of the kpoint as a numpy array (any sequence of
                three numbers is accepted). The values are copied, so the
                caller's object is never modified.
            lattice: A pymatgen.core.lattice.Lattice lattice object representing
                the reciprocal lattice of the kpoint
            to_unit_cell: Translates fractional coordinate to the basic unit
                cell, i.e., all fractional coordinates satisfy 0 <= a < 1.
                Defaults to False.
            coords_are_cartesian: Boolean indicating if the coordinates given are
                in cartesian or fractional coordinates (by default fractional)
            label: the label of the kpoint if any (None by default)
        """
        self._lattice = lattice
        if coords_are_cartesian:
            fcoords = lattice.get_fractional_coords(coords)
        else:
            fcoords = coords
        # Keep a private float copy. Storing the caller's object directly meant
        # that to_unit_cell=True mutated the caller's array in place (and failed
        # for tuples), and that later external mutation could desynchronize the
        # fractional coordinates from the cartesian ones computed below.
        self._fcoords = np.array(fcoords, dtype=float)
        self._label = label

        if to_unit_cell:
            # Remove the integer part of every component so it lies in [0, 1).
            self._fcoords -= np.floor(self._fcoords)

        self._ccoords = lattice.get_cartesian_coords(self._fcoords)

    @property
    def lattice(self):
        """
        The lattice associated with the kpoint. It's a
        pymatgen.core.lattice.Lattice object
        """
        return self._lattice

    @property
    def label(self):
        """
        The label associated with the kpoint
        """
        return self._label

    @property
    def frac_coords(self):
        """
        The fractional coordinates of the kpoint as a numpy array (a copy)
        """
        return np.copy(self._fcoords)

    @property
    def cart_coords(self):
        """
        The cartesian coordinates of the kpoint as a numpy array (a copy)
        """
        return np.copy(self._ccoords)

    @property
    def a(self):
        """
        Fractional a coordinate of the kpoint
        """
        return self._fcoords[0]

    @property
    def b(self):
        """
        Fractional b coordinate of the kpoint
        """
        return self._fcoords[1]

    @property
    def c(self):
        """
        Fractional c coordinate of the kpoint
        """
        return self._fcoords[2]

    def __str__(self):
        """
        Returns a string with fractional, cartesian coordinates and label
        """
        return "{} {} {}".format(self.frac_coords, self.cart_coords, self.label)

    def as_dict(self):
        """
        Json-serializable dict representation of a kpoint.

        Returns:
            dict with the lattice dict, fractional and cartesian coordinates
            (as lists), the label and the MSON module/class markers.
        """
        return {
            "lattice": self.lattice.as_dict(),
            "fcoords": self.frac_coords.tolist(),
            "ccoords": self.cart_coords.tolist(),
            "label": self.label,
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
        }
@classmethod 141 def from_dict(cls, d): 142 """ 143 Create from dict. 144 145 Args: 146 A dict with all data for a kpoint object. 147 148 Returns: 149 A Kpoint object 150 """ 151 152 return cls( 153 coords=d["fcoords"], 154 lattice=Lattice.from_dict(d["lattice"]), 155 coords_are_cartesian=False, 156 label=d["label"], 157 ) 158 159 160class BandStructure: 161 """ 162 This is the most generic band structure data possible 163 it's defined by a list of kpoints + energies for each of them 164 165 .. attribute:: kpoints: 166 the list of kpoints (as Kpoint objects) in the band structure 167 168 .. attribute:: lattice_rec 169 170 the reciprocal lattice of the band structure. 171 172 .. attribute:: efermi 173 174 the fermi energy 175 176 .. attribute:: is_spin_polarized 177 178 True if the band structure is spin-polarized, False otherwise 179 180 .. attribute:: bands 181 182 The energy eigenvalues as a {spin: ndarray}. Note that the use of an 183 ndarray is necessary for computational as well as memory efficiency 184 due to the large amount of numerical data. The indices of the ndarray 185 are [band_index, kpoint_index]. 186 187 .. attribute:: nb_bands 188 189 returns the number of bands in the band structure 190 191 .. attribute:: structure 192 193 returns the structure 194 195 .. attribute:: projections 196 197 The projections as a {spin: ndarray}. Note that the use of an 198 ndarray is necessary for computational as well as memory efficiency 199 due to the large amount of numerical data. The indices of the ndarray 200 are [band_index, kpoint_index, orbital_index, ion_index]. 
201 """ 202 203 def __init__( 204 self, 205 kpoints, 206 eigenvals, 207 lattice, 208 efermi, 209 labels_dict=None, 210 coords_are_cartesian=False, 211 structure=None, 212 projections=None, 213 ): 214 """ 215 Args: 216 kpoints: list of kpoint as numpy arrays, in frac_coords of the 217 given lattice by default 218 eigenvals: dict of energies for spin up and spin down 219 {Spin.up:[][],Spin.down:[][]}, the first index of the array 220 [][] refers to the band and the second to the index of the 221 kpoint. The kpoints are ordered according to the order of the 222 kpoints array. If the band structure is not spin polarized, we 223 only store one data set under Spin.up 224 lattice: The reciprocal lattice as a pymatgen Lattice object. 225 Pymatgen uses the physics convention of reciprocal lattice vectors 226 WITH a 2*pi coefficient 227 efermi: fermi energy 228 labels_dict: (dict) of {} this links a kpoint (in frac coords or 229 cartesian coordinates depending on the coords) to a label. 230 coords_are_cartesian: Whether coordinates are cartesian. 231 structure: The crystal structure (as a pymatgen Structure object) 232 associated with the band structure. This is needed if we 233 provide projections to the band structure 234 projections: dict of orbital projections as {spin: ndarray}. The 235 indices of the ndarrayare [band_index, kpoint_index, orbital_index, 236 ion_index].If the band structure is not spin polarized, we only 237 store one data set under Spin.up. 
238 """ 239 self.efermi = efermi 240 self.lattice_rec = lattice 241 self.kpoints = [] 242 self.labels_dict = {} 243 self.structure = structure 244 self.projections = projections or {} 245 self.projections = {k: np.array(v) for k, v in self.projections.items()} 246 247 if labels_dict is None: 248 labels_dict = {} 249 250 if len(self.projections) != 0 and self.structure is None: 251 raise Exception("if projections are provided a structure object" " needs also to be given") 252 253 for k in kpoints: 254 # let see if this kpoint has been assigned a label 255 label = None 256 for c in labels_dict: 257 if np.linalg.norm(k - np.array(labels_dict[c])) < 0.0001: 258 label = c 259 self.labels_dict[label] = Kpoint( 260 k, 261 lattice, 262 label=label, 263 coords_are_cartesian=coords_are_cartesian, 264 ) 265 self.kpoints.append(Kpoint(k, lattice, label=label, coords_are_cartesian=coords_are_cartesian)) 266 self.bands = {spin: np.array(v) for spin, v in eigenvals.items()} 267 self.nb_bands = len(eigenvals[Spin.up]) 268 self.is_spin_polarized = len(self.bands) == 2 269 270 def get_projection_on_elements(self): 271 """ 272 Method returning a dictionary of projections on elements. 
273 274 Returns: 275 a dictionary in the {Spin.up:[][{Element:values}], 276 Spin.down:[][{Element:values}]} format 277 if there is no projections in the band structure 278 returns an empty dict 279 """ 280 result = {} 281 structure = self.structure 282 for spin, v in self.projections.items(): 283 result[spin] = [ 284 [collections.defaultdict(float) for i in range(len(self.kpoints))] for j in range(self.nb_bands) 285 ] 286 for i, j, k in itertools.product( 287 range(self.nb_bands), 288 range(len(self.kpoints)), 289 range(structure.num_sites), 290 ): 291 result[spin][i][j][str(structure[k].specie)] += np.sum(v[i, j, :, k]) 292 return result 293 294 def get_projections_on_elements_and_orbitals(self, el_orb_spec): 295 """ 296 Method returning a dictionary of projections on elements and specific 297 orbitals 298 299 Args: 300 el_orb_spec: A dictionary of Elements and Orbitals for which we want 301 to have projections on. It is given as: {Element:[orbitals]}, 302 e.g., {'Cu':['d','s']} 303 304 Returns: 305 A dictionary of projections on elements in the 306 {Spin.up:[][{Element:{orb:values}}], 307 Spin.down:[][{Element:{orb:values}}]} format 308 if there is no projections in the band structure returns an empty 309 dict. 
310 """ 311 result = {} 312 structure = self.structure 313 el_orb_spec = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()} 314 for spin, v in self.projections.items(): 315 result[spin] = [ 316 [{str(e): collections.defaultdict(float) for e in el_orb_spec} for i in range(len(self.kpoints))] 317 for j in range(self.nb_bands) 318 ] 319 320 for i, j, k in itertools.product( 321 range(self.nb_bands), 322 range(len(self.kpoints)), 323 range(structure.num_sites), 324 ): 325 sp = structure[k].specie 326 for orb_i in range(len(v[i][j])): 327 o = Orbital(orb_i).name[0] 328 if sp in el_orb_spec: 329 if o in el_orb_spec[sp]: 330 result[spin][i][j][str(sp)][o] += v[i][j][orb_i][k] 331 return result 332 333 def is_metal(self, efermi_tol=1e-4): 334 """ 335 Check if the band structure indicates a metal by looking if the fermi 336 level crosses a band. 337 338 Returns: 339 True if a metal, False if not 340 """ 341 for spin, values in self.bands.items(): 342 for i in range(self.nb_bands): 343 if np.any(values[i, :] - self.efermi < -efermi_tol) and np.any(values[i, :] - self.efermi > efermi_tol): 344 return True 345 return False 346 347 def get_vbm(self): 348 """ 349 Returns data about the VBM. 350 351 Returns: 352 dict as {"band_index","kpoint_index","kpoint","energy"} 353 - "band_index": A dict with spin keys pointing to a list of the 354 indices of the band containing the VBM (please note that you 355 can have several bands sharing the VBM) {Spin.up:[], 356 Spin.down:[]} 357 - "kpoint_index": The list of indices in self.kpoints for the 358 kpoint VBM. Please note that there can be several 359 kpoint_indices relating to the same kpoint (e.g., Gamma can 360 occur at different spots in the band structure line plot) 361 - "kpoint": The kpoint (as a kpoint object) 362 - "energy": The energy of the VBM 363 - "projections": The projections along sites and orbitals of the 364 VBM if any projection data is available (else it is an empty 365 dictionnary). 
The format is similar to the projections field in 366 BandStructure: {spin:{'Orbital': [proj]}} where the array 367 [proj] is ordered according to the sites in structure 368 """ 369 if self.is_metal(): 370 return { 371 "band_index": [], 372 "kpoint_index": [], 373 "kpoint": [], 374 "energy": None, 375 "projections": {}, 376 } 377 max_tmp = -float("inf") 378 index = None 379 kpointvbm = None 380 for spin, v in self.bands.items(): 381 for i, j in zip(*np.where(v < self.efermi)): 382 if v[i, j] > max_tmp: 383 max_tmp = float(v[i, j]) 384 index = j 385 kpointvbm = self.kpoints[j] 386 387 list_ind_kpts = [] 388 if kpointvbm.label is not None: 389 for i, kpt in enumerate(self.kpoints): 390 if kpt.label == kpointvbm.label: 391 list_ind_kpts.append(i) 392 else: 393 list_ind_kpts.append(index) 394 # get all other bands sharing the vbm 395 list_ind_band = collections.defaultdict(list) 396 for spin in self.bands: 397 for i in range(self.nb_bands): 398 if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001: 399 list_ind_band[spin].append(i) 400 proj = {} 401 for spin, v in self.projections.items(): 402 if len(list_ind_band[spin]) == 0: 403 continue 404 proj[spin] = v[list_ind_band[spin][0]][list_ind_kpts[0]] 405 return { 406 "band_index": list_ind_band, 407 "kpoint_index": list_ind_kpts, 408 "kpoint": kpointvbm, 409 "energy": max_tmp, 410 "projections": proj, 411 } 412 413 def get_cbm(self): 414 """ 415 Returns data about the CBM. 416 417 Returns: 418 {"band_index","kpoint_index","kpoint","energy"} 419 - "band_index": A dict with spin keys pointing to a list of the 420 indices of the band containing the CBM (please note that you 421 can have several bands sharing the CBM) {Spin.up:[], 422 Spin.down:[]} 423 - "kpoint_index": The list of indices in self.kpoints for the 424 kpoint CBM. 
Please note that there can be several 425 kpoint_indices relating to the same kpoint (e.g., Gamma can 426 occur at different spots in the band structure line plot) 427 - "kpoint": The kpoint (as a kpoint object) 428 - "energy": The energy of the CBM 429 - "projections": The projections along sites and orbitals of the 430 CBM if any projection data is available (else it is an empty 431 dictionnary). The format is similar to the projections field in 432 BandStructure: {spin:{'Orbital': [proj]}} where the array 433 [proj] is ordered according to the sites in structure 434 """ 435 if self.is_metal(): 436 return { 437 "band_index": [], 438 "kpoint_index": [], 439 "kpoint": [], 440 "energy": None, 441 "projections": {}, 442 } 443 max_tmp = float("inf") 444 445 index = None 446 kpointcbm = None 447 for spin, v in self.bands.items(): 448 for i, j in zip(*np.where(v >= self.efermi)): 449 if v[i, j] < max_tmp: 450 max_tmp = float(v[i, j]) 451 index = j 452 kpointcbm = self.kpoints[j] 453 454 list_index_kpoints = [] 455 if kpointcbm.label is not None: 456 for i, kpt in enumerate(self.kpoints): 457 if kpt.label == kpointcbm.label: 458 list_index_kpoints.append(i) 459 else: 460 list_index_kpoints.append(index) 461 462 # get all other bands sharing the cbm 463 list_index_band = collections.defaultdict(list) 464 for spin in self.bands: 465 for i in range(self.nb_bands): 466 if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001: 467 list_index_band[spin].append(i) 468 proj = {} 469 for spin, v in self.projections.items(): 470 if len(list_index_band[spin]) == 0: 471 continue 472 proj[spin] = v[list_index_band[spin][0]][list_index_kpoints[0]] 473 474 return { 475 "band_index": list_index_band, 476 "kpoint_index": list_index_kpoints, 477 "kpoint": kpointcbm, 478 "energy": max_tmp, 479 "projections": proj, 480 } 481 482 def get_band_gap(self): 483 r""" 484 Returns band gap data. 
485 486 Returns: 487 A dict {"energy","direct","transition"}: 488 "energy": band gap energy 489 "direct": A boolean telling if the gap is direct or not 490 "transition": kpoint labels of the transition (e.g., "\\Gamma-X") 491 """ 492 if self.is_metal(): 493 return {"energy": 0.0, "direct": False, "transition": None} 494 cbm = self.get_cbm() 495 vbm = self.get_vbm() 496 result = dict(direct=False, energy=0.0, transition=None) 497 498 result["energy"] = cbm["energy"] - vbm["energy"] 499 500 if (cbm["kpoint"].label is not None and cbm["kpoint"].label == vbm["kpoint"].label) or np.linalg.norm( 501 cbm["kpoint"].cart_coords - vbm["kpoint"].cart_coords 502 ) < 0.01: 503 result["direct"] = True 504 505 result["transition"] = "-".join( 506 [ 507 str(c.label) 508 if c.label is not None 509 else str("(") + ",".join(["{0:.3f}".format(c.frac_coords[i]) for i in range(3)]) + str(")") 510 for c in [vbm["kpoint"], cbm["kpoint"]] 511 ] 512 ) 513 514 return result 515 516 def get_direct_band_gap_dict(self): 517 """ 518 Returns a dictionary of information about the direct 519 band gap 520 521 Returns: 522 a dictionary of the band gaps indexed by spin 523 along with their band indices and k-point index 524 """ 525 if self.is_metal(): 526 raise ValueError("get_direct_band_gap_dict should only be used with non-metals") 527 direct_gap_dict = {} 528 for spin, v in self.bands.items(): 529 above = v[np.all(v > self.efermi, axis=1)] 530 min_above = np.min(above, axis=0) 531 below = v[np.all(v < self.efermi, axis=1)] 532 max_below = np.max(below, axis=0) 533 diff = min_above - max_below 534 kpoint_index = np.argmin(diff) 535 band_indices = [ 536 np.argmax(below[:, kpoint_index]), 537 np.argmin(above[:, kpoint_index]) + len(below), 538 ] 539 direct_gap_dict[spin] = { 540 "value": diff[kpoint_index], 541 "kpoint_index": kpoint_index, 542 "band_indices": band_indices, 543 } 544 return direct_gap_dict 545 546 def get_direct_band_gap(self): 547 """ 548 Returns the direct band gap. 
549 550 Returns: 551 the value of the direct band gap 552 """ 553 if self.is_metal(): 554 return 0.0 555 dg = self.get_direct_band_gap_dict() 556 return min(v["value"] for v in dg.values()) 557 558 def get_sym_eq_kpoints(self, kpoint, cartesian=False, tol=1e-2): 559 """ 560 Returns a list of unique symmetrically equivalent k-points. 561 562 Args: 563 kpoint (1x3 array): coordinate of the k-point 564 cartesian (bool): kpoint is in cartesian or fractional coordinates 565 tol (float): tolerance below which coordinates are considered equal 566 567 Returns: 568 ([1x3 array] or None): if structure is not available returns None 569 """ 570 if not self.structure: 571 return None 572 sg = SpacegroupAnalyzer(self.structure) 573 symmops = sg.get_point_group_operations(cartesian=cartesian) 574 points = np.dot(kpoint, [m.rotation_matrix for m in symmops]) 575 rm_list = [] 576 # identify and remove duplicates from the list of equivalent k-points: 577 for i in range(len(points) - 1): 578 for j in range(i + 1, len(points)): 579 if np.allclose(pbc_diff(points[i], points[j]), [0, 0, 0], tol): 580 rm_list.append(i) 581 break 582 return np.delete(points, rm_list, axis=0) 583 584 def get_kpoint_degeneracy(self, kpoint, cartesian=False, tol=1e-2): 585 """ 586 Returns degeneracy of a given k-point based on structure symmetry 587 Args: 588 kpoint (1x3 array): coordinate of the k-point 589 cartesian (bool): kpoint is in cartesian or fractional coordinates 590 tol (float): tolerance below which coordinates are considered equal 591 592 Returns: 593 (int or None): degeneracy or None if structure is not available 594 """ 595 all_kpts = self.get_sym_eq_kpoints(kpoint, cartesian, tol=tol) 596 if all_kpts is not None: 597 return len(all_kpts) 598 return None 599 600 def as_dict(self): 601 """ 602 Json-serializable dict representation of BandStructure. 
603 """ 604 d = { 605 "@module": self.__class__.__module__, 606 "@class": self.__class__.__name__, 607 "lattice_rec": self.lattice_rec.as_dict(), 608 "efermi": self.efermi, 609 "kpoints": [], 610 } 611 # kpoints are not kpoint objects dicts but are frac coords (this makes 612 # the dict smaller and avoids the repetition of the lattice 613 for k in self.kpoints: 614 d["kpoints"].append(k.as_dict()["fcoords"]) 615 616 d["bands"] = {str(int(spin)): self.bands[spin].tolist() for spin in self.bands} 617 d["is_metal"] = self.is_metal() 618 vbm = self.get_vbm() 619 d["vbm"] = { 620 "energy": vbm["energy"], 621 "kpoint_index": vbm["kpoint_index"], 622 "band_index": {str(int(spin)): vbm["band_index"][spin] for spin in vbm["band_index"]}, 623 "projections": {str(spin): v.tolist() for spin, v in vbm["projections"].items()}, 624 } 625 cbm = self.get_cbm() 626 d["cbm"] = { 627 "energy": cbm["energy"], 628 "kpoint_index": cbm["kpoint_index"], 629 "band_index": {str(int(spin)): cbm["band_index"][spin] for spin in cbm["band_index"]}, 630 "projections": {str(spin): v.tolist() for spin, v in cbm["projections"].items()}, 631 } 632 d["band_gap"] = self.get_band_gap() 633 d["labels_dict"] = {} 634 d["is_spin_polarized"] = self.is_spin_polarized 635 636 # MongoDB does not accept keys starting with $. Add a blanck space to fix the problem 637 for c, label in self.labels_dict.items(): 638 mongo_key = c if not c.startswith("$") else " " + c 639 d["labels_dict"][mongo_key] = label.as_dict()["fcoords"] 640 d["projections"] = {} 641 if len(self.projections) != 0: 642 d["structure"] = self.structure.as_dict() 643 d["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()} 644 return d 645 646 @classmethod 647 def from_dict(cls, d): 648 """ 649 Create from dict. 650 651 Args: 652 A dict with all data for a band structure object. 
653 654 Returns: 655 A BandStructure object 656 """ 657 # Strip the label to recover initial string 658 # (see trick used in as_dict to handle $ chars) 659 labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()} 660 projections = {} 661 structure = None 662 if isinstance(list(d["bands"].values())[0], dict): 663 eigenvals = {Spin(int(k)): np.array(d["bands"][k]["data"]) for k in d["bands"]} 664 else: 665 eigenvals = {Spin(int(k)): d["bands"][k] for k in d["bands"]} 666 667 if "structure" in d: 668 structure = Structure.from_dict(d["structure"]) 669 670 try: 671 if d.get("projections"): 672 if isinstance(d["projections"]["1"][0][0], dict): 673 raise ValueError("Old band structure dict format detected!") 674 projections = {Spin(int(spin)): np.array(v) for spin, v in d["projections"].items()} 675 676 return cls( 677 d["kpoints"], 678 eigenvals, 679 Lattice(d["lattice_rec"]["matrix"]), 680 d["efermi"], 681 labels_dict, 682 structure=structure, 683 projections=projections, 684 ) 685 686 except Exception: 687 warnings.warn( 688 "Trying from_dict failed. Now we are trying the old " 689 "format. Please convert your BS dicts to the new " 690 "format. The old format will be retired in pymatgen " 691 "5.0." 692 ) 693 return cls.from_old_dict(d) 694 695 @classmethod 696 def from_old_dict(cls, d): 697 """ 698 Args: 699 d (dict): A dict with all data for a band structure symm line 700 object. 
701 Returns: 702 A BandStructureSymmLine object 703 """ 704 # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) 705 labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()} 706 projections = {} 707 structure = None 708 if "projections" in d and len(d["projections"]) != 0: 709 structure = Structure.from_dict(d["structure"]) 710 projections = {} 711 for spin in d["projections"]: 712 dd = [] 713 for i in range(len(d["projections"][spin])): 714 ddd = [] 715 for j in range(len(d["projections"][spin][i])): 716 dddd = [] 717 for k in range(len(d["projections"][spin][i][j])): 718 ddddd = [] 719 orb = Orbital(k).name 720 for l in range(len(d["projections"][spin][i][j][orb])): 721 ddddd.append(d["projections"][spin][i][j][orb][l]) 722 dddd.append(np.array(ddddd)) 723 ddd.append(np.array(dddd)) 724 dd.append(np.array(ddd)) 725 projections[Spin(int(spin))] = np.array(dd) 726 727 return BandStructure( 728 d["kpoints"], 729 {Spin(int(k)): d["bands"][k] for k in d["bands"]}, 730 Lattice(d["lattice_rec"]["matrix"]), 731 d["efermi"], 732 labels_dict, 733 structure=structure, 734 projections=projections, 735 ) 736 737 738class BandStructureSymmLine(BandStructure, MSONable): 739 r""" 740 This object stores band structures along selected (symmetry) lines in the 741 Brillouin zone. We call the different symmetry lines (ex: \\Gamma to Z) 742 "branches". 743 """ 744 745 def __init__( 746 self, 747 kpoints, 748 eigenvals, 749 lattice, 750 efermi, 751 labels_dict, 752 coords_are_cartesian=False, 753 structure=None, 754 projections=None, 755 ): 756 """ 757 Args: 758 kpoints: list of kpoint as numpy arrays, in frac_coords of the 759 given lattice by default 760 eigenvals: dict of energies for spin up and spin down 761 {Spin.up:[][],Spin.down:[][]}, the first index of the array 762 [][] refers to the band and the second to the index of the 763 kpoint. The kpoints are ordered according to the order of the 764 kpoints array. 
If the band structure is not spin polarized, we 765 only store one data set under Spin.up. 766 lattice: The reciprocal lattice. 767 Pymatgen uses the physics convention of reciprocal lattice vectors 768 WITH a 2*pi coefficient 769 efermi: fermi energy 770 label_dict: (dict) of {} this link a kpoint (in frac coords or 771 cartesian coordinates depending on the coords). 772 coords_are_cartesian: Whether coordinates are cartesian. 773 structure: The crystal structure (as a pymatgen Structure object) 774 associated with the band structure. This is needed if we 775 provide projections to the band structure. 776 projections: dict of orbital projections as {spin: ndarray}. The 777 indices of the ndarrayare [band_index, kpoint_index, orbital_index, 778 ion_index].If the band structure is not spin polarized, we only 779 store one data set under Spin.up. 780 """ 781 super().__init__( 782 kpoints, 783 eigenvals, 784 lattice, 785 efermi, 786 labels_dict, 787 coords_are_cartesian, 788 structure, 789 projections, 790 ) 791 self.distance = [] 792 self.branches = [] 793 one_group = [] 794 branches_tmp = [] 795 # get labels and distance for each kpoint 796 previous_kpoint = self.kpoints[0] 797 previous_distance = 0.0 798 799 previous_label = self.kpoints[0].label 800 for i, kpt in enumerate(self.kpoints): 801 label = kpt.label 802 if label is not None and previous_label is not None: 803 self.distance.append(previous_distance) 804 else: 805 self.distance.append(np.linalg.norm(kpt.cart_coords - previous_kpoint.cart_coords) + previous_distance) 806 previous_kpoint = kpt 807 previous_distance = self.distance[i] 808 if label: 809 if previous_label: 810 if len(one_group) != 0: 811 branches_tmp.append(one_group) 812 one_group = [] 813 previous_label = label 814 one_group.append(i) 815 816 if len(one_group) != 0: 817 branches_tmp.append(one_group) 818 for b in branches_tmp: 819 self.branches.append( 820 { 821 "start_index": b[0], 822 "end_index": b[-1], 823 "name": 
str(self.kpoints[b[0]].label) + "-" + str(self.kpoints[b[-1]].label), 824 } 825 ) 826 827 self.is_spin_polarized = False 828 if len(self.bands) == 2: 829 self.is_spin_polarized = True 830 831 def get_equivalent_kpoints(self, index): 832 """ 833 Returns the list of kpoint indices equivalent (meaning they are the 834 same frac coords) to the given one. 835 836 Args: 837 index: the kpoint index 838 839 Returns: 840 a list of equivalent indices 841 842 TODO: now it uses the label we might want to use coordinates instead 843 (in case there was a mislabel) 844 """ 845 # if the kpoint has no label it can"t have a repetition along the band 846 # structure line object 847 848 if self.kpoints[index].label is None: 849 return [index] 850 851 list_index_kpoints = [] 852 for i, kpt in enumerate(self.kpoints): 853 if kpt.label == self.kpoints[index].label: 854 list_index_kpoints.append(i) 855 856 return list_index_kpoints 857 858 def get_branch(self, index): 859 r""" 860 Returns in what branch(es) is the kpoint. There can be several 861 branches. 862 863 Args: 864 index: the kpoint index 865 866 Returns: 867 A list of dictionaries [{"name","start_index","end_index","index"}] 868 indicating all branches in which the k_point is. It takes into 869 account the fact that one kpoint (e.g., \\Gamma) can be in several 870 branches 871 """ 872 to_return = [] 873 for i in self.get_equivalent_kpoints(index): 874 for b in self.branches: 875 if b["start_index"] <= i <= b["end_index"]: 876 to_return.append( 877 { 878 "name": b["name"], 879 "start_index": b["start_index"], 880 "end_index": b["end_index"], 881 "index": i, 882 } 883 ) 884 return to_return 885 886 def apply_scissor(self, new_band_gap): 887 """ 888 Apply a scissor operator (shift of the CBM) to fit the given band gap. 889 If it's a metal. We look for the band crossing the fermi level 890 and shift this one up. This will not work all the time for metals! 
891 892 Args: 893 new_band_gap: the band gap the scissor band structure need to have. 894 895 Returns: 896 a BandStructureSymmLine object with the applied scissor shift 897 """ 898 if self.is_metal(): 899 # moves then the highest index band crossing the fermi level 900 # find this band... 901 max_index = -1000 902 # spin_index = None 903 for i in range(self.nb_bands): 904 below = False 905 above = False 906 for j in range(len(self.kpoints)): 907 if self.bands[Spin.up][i][j] < self.efermi: 908 below = True 909 if self.bands[Spin.up][i][j] > self.efermi: 910 above = True 911 if above and below: 912 if i > max_index: 913 max_index = i 914 # spin_index = Spin.up 915 if self.is_spin_polarized: 916 below = False 917 above = False 918 for j in range(len(self.kpoints)): 919 if self.bands[Spin.down][i][j] < self.efermi: 920 below = True 921 if self.bands[Spin.down][i][j] > self.efermi: 922 above = True 923 if above and below: 924 if i > max_index: 925 max_index = i 926 # spin_index = Spin.down 927 old_dict = self.as_dict() 928 shift = new_band_gap 929 for spin in old_dict["bands"]: 930 for k in range(len(old_dict["bands"][spin])): 931 for v in range(len(old_dict["bands"][spin][k])): 932 if k >= max_index: 933 old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift 934 else: 935 936 shift = new_band_gap - self.get_band_gap()["energy"] 937 old_dict = self.as_dict() 938 for spin in old_dict["bands"]: 939 for k in range(len(old_dict["bands"][spin])): 940 for v in range(len(old_dict["bands"][spin][k])): 941 if old_dict["bands"][spin][k][v] >= old_dict["cbm"]["energy"]: 942 old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift 943 old_dict["efermi"] = old_dict["efermi"] + shift 944 return self.from_dict(old_dict) 945 946 def as_dict(self): 947 """ 948 Json-serializable dict representation of BandStructureSymmLine. 
949 """ 950 d = super().as_dict() 951 d["branches"] = self.branches 952 return d 953 954 955class LobsterBandStructureSymmLine(BandStructureSymmLine): 956 """ 957 Lobster subclass of BandStructure with customized functions. 958 """ 959 960 def as_dict(self): 961 """ 962 Json-serializable dict representation of BandStructureSymmLine. 963 """ 964 965 d = { 966 "@module": self.__class__.__module__, 967 "@class": self.__class__.__name__, 968 "lattice_rec": self.lattice_rec.as_dict(), 969 "efermi": self.efermi, 970 "kpoints": [], 971 } 972 # kpoints are not kpoint objects dicts but are frac coords (this makes 973 # the dict smaller and avoids the repetition of the lattice 974 for k in self.kpoints: 975 d["kpoints"].append(k.as_dict()["fcoords"]) 976 d["branches"] = self.branches 977 d["bands"] = {str(int(spin)): self.bands[spin].tolist() for spin in self.bands} 978 d["is_metal"] = self.is_metal() 979 vbm = self.get_vbm() 980 d["vbm"] = { 981 "energy": vbm["energy"], 982 "kpoint_index": [int(x) for x in vbm["kpoint_index"]], 983 "band_index": {str(int(spin)): vbm["band_index"][spin] for spin in vbm["band_index"]}, 984 "projections": {str(spin): v for spin, v in vbm["projections"].items()}, 985 } 986 cbm = self.get_cbm() 987 d["cbm"] = { 988 "energy": cbm["energy"], 989 "kpoint_index": [int(x) for x in cbm["kpoint_index"]], 990 "band_index": {str(int(spin)): cbm["band_index"][spin] for spin in cbm["band_index"]}, 991 "projections": {str(spin): v for spin, v in cbm["projections"].items()}, 992 } 993 d["band_gap"] = self.get_band_gap() 994 d["labels_dict"] = {} 995 d["is_spin_polarized"] = self.is_spin_polarized 996 # MongoDB does not accept keys starting with $. 
Add a blanck space to fix the problem 997 for c, label in self.labels_dict.items(): 998 mongo_key = c if not c.startswith("$") else " " + c 999 d["labels_dict"][mongo_key] = label.as_dict()["fcoords"] 1000 if len(self.projections) != 0: 1001 d["structure"] = self.structure.as_dict() 1002 d["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()} 1003 return d 1004 1005 @classmethod 1006 def from_dict(cls, d): 1007 """ 1008 Args: 1009 d (dict): A dict with all data for a band structure symm line 1010 object. 1011 1012 Returns: 1013 A BandStructureSymmLine object 1014 """ 1015 try: 1016 # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) 1017 labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()} 1018 projections = {} 1019 structure = None 1020 if d.get("projections"): 1021 if isinstance(d["projections"]["1"][0][0], dict): 1022 raise ValueError("Old band structure dict format detected!") 1023 structure = Structure.from_dict(d["structure"]) 1024 projections = {Spin(int(spin)): np.array(v) for spin, v in d["projections"].items()} 1025 1026 return LobsterBandStructureSymmLine( 1027 d["kpoints"], 1028 {Spin(int(k)): d["bands"][k] for k in d["bands"]}, 1029 Lattice(d["lattice_rec"]["matrix"]), 1030 d["efermi"], 1031 labels_dict, 1032 structure=structure, 1033 projections=projections, 1034 ) 1035 except Exception: 1036 warnings.warn( 1037 "Trying from_dict failed. Now we are trying the old " 1038 "format. Please convert your BS dicts to the new " 1039 "format. The old format will be retired in pymatgen " 1040 "5.0." 1041 ) 1042 return LobsterBandStructureSymmLine.from_old_dict(d) 1043 1044 @classmethod 1045 def from_old_dict(cls, d): 1046 """ 1047 Args: 1048 d (dict): A dict with all data for a band structure symm line 1049 object. 
1050 Returns: 1051 A BandStructureSymmLine object 1052 """ 1053 # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) 1054 labels_dict = {k.strip(): v for k, v in d["labels_dict"].items()} 1055 projections = {} 1056 structure = None 1057 if "projections" in d and len(d["projections"]) != 0: 1058 structure = Structure.from_dict(d["structure"]) 1059 projections = {} 1060 for spin in d["projections"]: 1061 dd = [] 1062 for i in range(len(d["projections"][spin])): 1063 ddd = [] 1064 for j in range(len(d["projections"][spin][i])): 1065 ddd.append(d["projections"][spin][i][j]) 1066 dd.append(np.array(ddd)) 1067 projections[Spin(int(spin))] = np.array(dd) 1068 1069 return LobsterBandStructureSymmLine( 1070 d["kpoints"], 1071 {Spin(int(k)): d["bands"][k] for k in d["bands"]}, 1072 Lattice(d["lattice_rec"]["matrix"]), 1073 d["efermi"], 1074 labels_dict, 1075 structure=structure, 1076 projections=projections, 1077 ) 1078 1079 def get_projection_on_elements(self): 1080 """ 1081 Method returning a dictionary of projections on elements. 1082 It sums over all available orbitals for each element. 
1083 1084 Returns: 1085 a dictionary in the {Spin.up:[][{Element:values}], 1086 Spin.down:[][{Element:values}]} format 1087 if there is no projections in the band structure 1088 returns an empty dict 1089 """ 1090 result = {} 1091 for spin, v in self.projections.items(): 1092 result[spin] = [ 1093 [collections.defaultdict(float) for i in range(len(self.kpoints))] for j in range(self.nb_bands) 1094 ] 1095 for i, j in itertools.product(range(self.nb_bands), range(len(self.kpoints))): 1096 for key, item in v[i][j].items(): 1097 for key2, item2 in item.items(): 1098 specie = str(Element(re.split(r"[0-9]+", key)[0])) 1099 result[spin][i][j][specie] += item2 1100 return result 1101 1102 def get_projections_on_elements_and_orbitals(self, el_orb_spec): 1103 """ 1104 Method returning a dictionary of projections on elements and specific 1105 orbitals 1106 1107 Args: 1108 el_orb_spec: A dictionary of Elements and Orbitals for which we want 1109 to have projections on. It is given as: {Element:[orbitals]}, 1110 e.g., {'Si':['3s','3p']} or {'Si':['3s','3p_x', '3p_y', '3p_z']} depending on input files 1111 1112 Returns: 1113 A dictionary of projections on elements in the 1114 {Spin.up:[][{Element:{orb:values}}], 1115 Spin.down:[][{Element:{orb:values}}]} format 1116 if there is no projections in the band structure returns an empty 1117 dict. 
1118 """ 1119 result = {} 1120 el_orb_spec = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()} 1121 for spin, v in self.projections.items(): 1122 result[spin] = [ 1123 [{str(e): collections.defaultdict(float) for e in el_orb_spec} for i in range(len(self.kpoints))] 1124 for j in range(self.nb_bands) 1125 ] 1126 1127 for i, j in itertools.product(range(self.nb_bands), range(len(self.kpoints))): 1128 for key, item in v[i][j].items(): 1129 for key2, item2 in item.items(): 1130 specie = str(Element(re.split(r"[0-9]+", key)[0])) 1131 if get_el_sp(str(specie)) in el_orb_spec: 1132 if key2 in el_orb_spec[get_el_sp(str(specie))]: 1133 result[spin][i][j][specie][key2] += item2 1134 return result 1135 1136 1137def get_reconstructed_band_structure(list_bs, efermi=None): 1138 """ 1139 This method takes a list of band structures and reconstructs 1140 one band structure object from all of them. 1141 1142 This is typically very useful when you split non self consistent 1143 band structure runs in several independent jobs and want to merge back 1144 the results 1145 1146 Args: 1147 list_bs: A list of BandStructure or BandStructureSymmLine objects. 1148 efermi: The Fermi energy of the reconstructed band structure. If 1149 None is assigned an average of all the Fermi energy in each 1150 object in the list_bs is used. 
1151 1152 Returns: 1153 A BandStructure or BandStructureSymmLine object (depending on 1154 the type of the list_bs objects) 1155 """ 1156 if efermi is None: 1157 efermi = sum([b.efermi for b in list_bs]) / len(list_bs) 1158 1159 kpoints = [] 1160 labels_dict = {} 1161 rec_lattice = list_bs[0].lattice_rec 1162 nb_bands = min([list_bs[i].nb_bands for i in range(len(list_bs))]) 1163 1164 kpoints = np.concatenate([[k.frac_coords for k in bs.kpoints] for bs in list_bs]) 1165 dicts = [bs.labels_dict for bs in list_bs] 1166 labels_dict = {k: v.frac_coords for d in dicts for k, v in d.items()} 1167 1168 eigenvals = {} 1169 eigenvals[Spin.up] = np.concatenate([bs.bands[Spin.up][:nb_bands] for bs in list_bs], axis=1) 1170 1171 if list_bs[0].is_spin_polarized: 1172 eigenvals[Spin.down] = np.concatenate([bs.bands[Spin.down][:nb_bands] for bs in list_bs], axis=1) 1173 1174 projections = {} 1175 if len(list_bs[0].projections) != 0: 1176 projs = [bs.projections[Spin.up][:nb_bands] for bs in list_bs] 1177 projections[Spin.up] = np.concatenate(projs, axis=1) 1178 1179 if list_bs[0].is_spin_polarized: 1180 projs = [bs.projections[Spin.down][:nb_bands] for bs in list_bs] 1181 projections[Spin.down] = np.concatenate(projs, axis=1) 1182 1183 if isinstance(list_bs[0], BandStructureSymmLine): 1184 return BandStructureSymmLine( 1185 kpoints, 1186 eigenvals, 1187 rec_lattice, 1188 efermi, 1189 labels_dict, 1190 structure=list_bs[0].structure, 1191 projections=projections, 1192 ) 1193 return BandStructure( 1194 kpoints, 1195 eigenvals, 1196 rec_lattice, 1197 efermi, 1198 labels_dict, 1199 structure=list_bs[0].structure, 1200 projections=projections, 1201 ) 1202