1# coding: utf-8 2# Copyright (c) Pymatgen Development Team. 3# Distributed under the terms of the MIT License. 4""" 5This module implements plotter for DOS and band structure. 6""" 7 8import copy 9import itertools 10import logging 11import math 12import warnings 13from collections import Counter, OrderedDict 14 15import matplotlib.lines as mlines 16import numpy as np 17import scipy.interpolate as scint 18from monty.dev import requires 19from monty.json import jsanitize 20 21try: 22 from mayavi import mlab 23except ImportError: 24 mlab = None 25 26from pymatgen.core.periodic_table import Element 27from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine 28from pymatgen.electronic_structure.boltztrap import BoltztrapError 29from pymatgen.electronic_structure.core import OrbitalType, Spin 30from pymatgen.util.plotting import add_fig_kwargs, get_ax3d_fig_plt, pretty_plot 31 32__author__ = "Shyue Ping Ong, Geoffroy Hautier, Anubhav Jain" 33__copyright__ = "Copyright 2012, The Materials Project" 34__version__ = "0.1" 35__maintainer__ = "Shyue Ping Ong" 36__email__ = "shyuep@gmail.com" 37__date__ = "May 1, 2012" 38 39logger = logging.getLogger(__name__) 40 41 42class DosPlotter: 43 """ 44 Class for plotting DOSs. Note that the interface is extremely flexible 45 given that there are many different ways in which people want to view 46 DOS. The typical usage is:: 47 48 # Initializes plotter with some optional args. Defaults are usually 49 # fine, 50 plotter = DosPlotter() 51 52 # Adds a DOS with a label. 53 plotter.add_dos("Total DOS", dos) 54 55 # Alternatively, you can add a dict of DOSs. This is the typical 56 # form returned by CompleteDos.get_spd/element/others_dos(). 57 plotter.add_dos_dict({"dos1": dos1, "dos2": dos2}) 58 plotter.add_dos_dict(complete_dos.get_spd_dos()) 59 """ 60 61 def __init__(self, zero_at_efermi=True, stack=False, sigma=None): 62 """ 63 Args: 64 zero_at_efermi: Whether to shift all Dos to have zero energy at the 65 fermi energy. Defaults to True. 66 stack: Whether to plot the DOS as a stacked area graph 67 key_sort_func: function used to sort the dos_dict keys. 68 sigma: A float specifying a standard deviation for Gaussian smearing 69 the DOS for nicer looking plots. Defaults to None for no 70 smearing. 71 """ 72 self.zero_at_efermi = zero_at_efermi 73 self.stack = stack 74 self.sigma = sigma 75 self._doses = OrderedDict() 76 77 def add_dos(self, label, dos): 78 """ 79 Adds a dos for plotting. 80 81 Args: 82 label: 83 label for the DOS. Must be unique. 84 dos: 85 Dos object 86 """ 87 energies = dos.energies - dos.efermi if self.zero_at_efermi else dos.energies 88 densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities 89 efermi = dos.efermi 90 self._doses[label] = { 91 "energies": energies, 92 "densities": densities, 93 "efermi": efermi, 94 } 95 96 def add_dos_dict(self, dos_dict, key_sort_func=None): 97 """ 98 Add a dictionary of doses, with an optional sorting function for the 99 keys. 100 101 Args: 102 dos_dict: dict of {label: Dos} 103 key_sort_func: function used to sort the dos_dict keys. 104 """ 105 if key_sort_func: 106 keys = sorted(dos_dict.keys(), key=key_sort_func) 107 else: 108 keys = dos_dict.keys() 109 for label in keys: 110 self.add_dos(label, dos_dict[label]) 111 112 def get_dos_dict(self): 113 """ 114 Returns the added doses as a json-serializable dict. Note that if you 115 have specified smearing for the DOS plot, the densities returned will 116 be the smeared densities, not the original densities. 117 118 Returns: 119 dict: Dict of dos data. Generally of the form 120 {label: {'energies':..., 'densities': {'up':...}, 'efermi':efermi}} 121 """ 122 return jsanitize(self._doses) 123 124 def get_plot(self, xlim=None, ylim=None): 125 """ 126 Get a matplotlib plot showing the DOS. 127 128 Args: 129 xlim: Specifies the x-axis limits. Set to None for automatic 130 determination. 131 ylim: Specifies the y-axis limits. 132 """ 133 134 ncolors = max(3, len(self._doses)) 135 ncolors = min(9, ncolors) 136 137 import palettable 138 139 # pylint: disable=E1101 140 colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors 141 142 y = None 143 alldensities = [] 144 allenergies = [] 145 plt = pretty_plot(12, 8) 146 147 # Note that this complicated processing of energies is to allow for 148 # stacked plots in matplotlib. 149 for key, dos in self._doses.items(): 150 energies = dos["energies"] 151 densities = dos["densities"] 152 if not y: 153 y = { 154 Spin.up: np.zeros(energies.shape), 155 Spin.down: np.zeros(energies.shape), 156 } 157 newdens = {} 158 for spin in [Spin.up, Spin.down]: 159 if spin in densities: 160 if self.stack: 161 y[spin] += densities[spin] 162 newdens[spin] = y[spin].copy() 163 else: 164 newdens[spin] = densities[spin] 165 allenergies.append(energies) 166 alldensities.append(newdens) 167 168 keys = list(self._doses.keys()) 169 keys.reverse() 170 alldensities.reverse() 171 allenergies.reverse() 172 allpts = [] 173 for i, key in enumerate(keys): 174 x = [] 175 y = [] 176 for spin in [Spin.up, Spin.down]: 177 if spin in alldensities[i]: 178 densities = list(int(spin) * alldensities[i][spin]) 179 energies = list(allenergies[i]) 180 if spin == Spin.down: 181 energies.reverse() 182 densities.reverse() 183 x.extend(energies) 184 y.extend(densities) 185 allpts.extend(list(zip(x, y))) 186 if self.stack: 187 plt.fill(x, y, color=colors[i % ncolors], label=str(key)) 188 else: 189 plt.plot(x, y, color=colors[i % ncolors], label=str(key), linewidth=3) 190 if not self.zero_at_efermi: 191 ylim = plt.ylim() 192 plt.plot( 193 [self._doses[key]["efermi"], self._doses[key]["efermi"]], 194 ylim, 195 color=colors[i % ncolors], 196 linestyle="--", 197 linewidth=2, 198 ) 199 200 if xlim: 201 plt.xlim(xlim) 202 if ylim: 203 plt.ylim(ylim) 204 else: 205 xlim = plt.xlim() 206 relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]] 207 plt.ylim((min(relevanty), max(relevanty))) 208 209 if self.zero_at_efermi: 210 ylim = plt.ylim() 211 plt.plot([0, 0], ylim, "k--", linewidth=2) 212 213 plt.xlabel("Energies (eV)") 214 plt.ylabel("Density of states") 215 216 plt.axhline(y=0, color="k", linestyle="--", linewidth=2) 217 plt.legend() 218 leg = plt.gca().get_legend() 219 ltext = leg.get_texts() # all the text.Text instance in the legend 220 plt.setp(ltext, fontsize=30) 221 plt.tight_layout() 222 return plt 223 224 def save_plot(self, filename, img_format="eps", xlim=None, ylim=None): 225 """ 226 Save matplotlib plot to a file. 227 228 Args: 229 filename: Filename to write to. 230 img_format: Image format to use. Defaults to EPS. 231 xlim: Specifies the x-axis limits. Set to None for automatic 232 determination. 233 ylim: Specifies the y-axis limits. 234 """ 235 plt = self.get_plot(xlim, ylim) 236 plt.savefig(filename, format=img_format) 237 238 def show(self, xlim=None, ylim=None): 239 """ 240 Show the plot using matplotlib. 241 242 Args: 243 xlim: Specifies the x-axis limits. Set to None for automatic 244 determination. 245 ylim: Specifies the y-axis limits. 246 """ 247 plt = self.get_plot(xlim, ylim) 248 plt.show() 249 250 251class BSPlotter: 252 """ 253 Class to plot or get data to facilitate the plot of band structure objects. 254 """ 255 256 def __init__(self, bs): 257 """ 258 Args: 259 bs: A BandStructureSymmLine object. 260 """ 261 262 self._bs = [] 263 self._nb_bands = [] 264 265 self.add_bs(bs) 266 267 def _check_bs_kpath(self, bs_list): 268 """ 269 Helper method that chack the all the band objs in bs_list are 270 BandStructureSymmLine objs and they all have the same kpath. 271 """ 272 273 # check obj type 274 for bs in bs_list: 275 if not isinstance(bs, BandStructureSymmLine): 276 raise ValueError( 277 "BSPlotter only works with BandStructureSymmLine objects. " 278 "A BandStructure object (on a uniform grid for instance and " 279 "not along symmetry lines won't work)" 280 ) 281 282 # check the kpath 283 if len(bs_list) == 1 and self._bs == []: 284 return True 285 286 if self._bs == []: 287 kpath_ref = [br["name"] for br in bs_list[0].branches] 288 else: 289 kpath_ref = [br["name"] for br in self._bs[0].branches] 290 291 for bs in bs_list: 292 if kpath_ref != [br["name"] for br in bs.branches]: 293 msg = ( 294 f"BSPlotter only works with BandStructureSymmLine " 295 f"which have the same kpath. \n{bs} has a different kpath!" 296 ) 297 raise ValueError(msg) 298 299 return True 300 301 def add_bs(self, bs): 302 """ 303 Method to add bands objects to the BSPlotter 304 """ 305 if not isinstance(bs, list): 306 bs = [bs] 307 308 if self._check_bs_kpath(bs): 309 self._bs.extend(bs) 310 # TODO: come with an intelligent way to cut the highest unconverged 311 # bands 312 self._nb_bands.extend([b.nb_bands for b in bs]) 313 314 def _maketicks(self, plt): 315 """ 316 utility private method to add ticks to a band structure 317 """ 318 ticks = self.get_ticks() 319 # Sanitize only plot the uniq values 320 uniq_d = [] 321 uniq_l = [] 322 temp_ticks = list(zip(ticks["distance"], ticks["label"])) 323 for i, t in enumerate(temp_ticks): 324 if i == 0: 325 uniq_d.append(t[0]) 326 uniq_l.append(t[1]) 327 logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1])) 328 else: 329 if t[1] == temp_ticks[i - 1][1]: 330 logger.debug("Skipping label {i}".format(i=t[1])) 331 else: 332 logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1])) 333 uniq_d.append(t[0]) 334 uniq_l.append(t[1]) 335 336 logger.debug("Unique labels are %s" % list(zip(uniq_d, uniq_l))) 337 plt.gca().set_xticks(uniq_d) 338 plt.gca().set_xticklabels(uniq_l) 339 340 for i in range(len(ticks["label"])): 341 if ticks["label"][i] is not None: 342 # don't print the same label twice 343 if i != 0: 344 if ticks["label"][i] == ticks["label"][i - 1]: 345 logger.debug("already print label... " "skipping label {i}".format(i=ticks["label"][i])) 346 else: 347 logger.debug( 348 "Adding a line at {d}" " for label {l}".format(d=ticks["distance"][i], l=ticks["label"][i]) 349 ) 350 plt.axvline(ticks["distance"][i], color="k") 351 else: 352 logger.debug( 353 "Adding a line at {d} for label {l}".format(d=ticks["distance"][i], l=ticks["label"][i]) 354 ) 355 plt.axvline(ticks["distance"][i], color="k") 356 return plt 357 358 @staticmethod 359 def _get_branch_steps(branches): 360 """ 361 Method to find discontinuous branches 362 """ 363 steps = [0] 364 for b1, b2 in zip(branches[:-1], branches[1:]): 365 if b2["name"].split("-")[0] != b1["name"].split("-")[-1]: 366 steps.append(b2["start_index"]) 367 steps.append(branches[-1]["end_index"] + 1) 368 return steps 369 370 @staticmethod 371 def _rescale_distances(bs_ref, bs): 372 """ 373 Method to rescale distances of bs to distances in bs_ref. 374 This is used for plotting two bandstructures (same k-path) 375 of different materials. 376 """ 377 scaled_distances = [] 378 379 for br, br2 in zip(bs_ref.branches, bs.branches): 380 s = br["start_index"] 381 e = br["end_index"] 382 max_d = bs_ref.distance[e] 383 min_d = bs_ref.distance[s] 384 s2 = br2["start_index"] 385 e2 = br2["end_index"] 386 np = e2 - s2 387 if np == 0: 388 # it deals with single point branches 389 scaled_distances.extend([min_d]) 390 else: 391 scaled_distances.extend([(max_d - min_d) / np * i + min_d for i in range(np + 1)]) 392 393 return scaled_distances 394 395 def bs_plot_data(self, zero_to_efermi=True, bs=None, bs_ref=None, split_branches=True): 396 """ 397 Get the data nicely formatted for a plot 398 399 Args: 400 zero_to_efermi: Automatically subtract off the Fermi energy from the 401 eigenvalues and plot. 402 bs: the bandstructure to get the data from. If not provided, the first 403 one in the self._bs list will be used. 404 bs_ref: is the bandstructure of reference when a rescale of the distances 405 is need to plot multiple bands 406 split_branches: if True distances and energies are split according to the 407 branches. If False distances and energies are split only where branches 408 are discontinuous (reducing the number of lines to plot). 409 410 Returns: 411 dict: A dictionary of the following format: 412 ticks: A dict with the 'distances' at which there is a kpoint (the 413 x axis) and the labels (None if no label). 414 energy: A dict storing bands for spin up and spin down data 415 {Spin:[np.array(nb_bands,kpoints),...]} as a list of discontinuous kpath 416 of energies. The energy of multiple continuous branches are stored together. 417 vbm: A list of tuples (distance,energy) marking the vbms. The 418 energies are shifted with respect to the fermi level is the 419 option has been selected. 420 cbm: A list of tuples (distance,energy) marking the cbms. The 421 energies are shifted with respect to the fermi level is the 422 option has been selected. 423 lattice: The reciprocal lattice. 424 zero_energy: This is the energy used as zero for the plot. 425 band_gap:A string indicating the band gap and its nature (empty if 426 it's a metal). 427 is_metal: True if the band structure is metallic (i.e., there is at 428 least one band crossing the fermi level). 429 """ 430 431 if bs is None: 432 if isinstance(self._bs, list): 433 # if BSPlotter 434 bs = self._bs[0] 435 else: 436 # if BSPlotterProjected 437 bs = self._bs 438 439 energies = {str(sp): [] for sp in bs.bands.keys()} 440 441 bs_is_metal = bs.is_metal() 442 443 if not bs_is_metal: 444 vbm = bs.get_vbm() 445 cbm = bs.get_cbm() 446 447 zero_energy = 0.0 448 if zero_to_efermi: 449 if bs_is_metal: 450 zero_energy = bs.efermi 451 else: 452 zero_energy = vbm["energy"] 453 454 # rescale distances when a bs_ref is given as reference, 455 # and when bs and bs_ref have different points in branches. 456 # Usually bs_ref is the first one in self._bs list is bs_ref 457 distances = bs.distance 458 if bs_ref is not None: 459 if bs_ref.branches != bs.branches: 460 distances = self._rescale_distances(bs_ref, bs) 461 462 if split_branches: 463 steps = [br["end_index"] + 1 for br in bs.branches][:-1] 464 else: 465 # join all the continuous branches 466 # to reduce the total number of branches to plot 467 steps = self._get_branch_steps(bs.branches)[1:-1] 468 469 distances = np.split(distances, steps) 470 for sp in bs.bands.keys(): 471 energies[str(sp)] = np.hsplit(bs.bands[sp] - zero_energy, steps) 472 473 ticks = self.get_ticks() 474 475 vbm_plot = [] 476 cbm_plot = [] 477 bg_str = "" 478 479 if not bs_is_metal: 480 for index in cbm["kpoint_index"]: 481 cbm_plot.append( 482 ( 483 bs.distance[index], 484 cbm["energy"] - zero_energy if zero_to_efermi else cbm["energy"], 485 ) 486 ) 487 488 for index in vbm["kpoint_index"]: 489 vbm_plot.append( 490 ( 491 bs.distance[index], 492 vbm["energy"] - zero_energy if zero_to_efermi else vbm["energy"], 493 ) 494 ) 495 496 bg = bs.get_band_gap() 497 direct = "Indirect" 498 if bg["direct"]: 499 direct = "Direct" 500 501 bg_str = "{} {} bandgap = {}".format(direct, bg["transition"], bg["energy"]) 502 503 return { 504 "ticks": ticks, 505 "distances": distances, 506 "energy": energies, 507 "vbm": vbm_plot, 508 "cbm": cbm_plot, 509 "lattice": bs.lattice_rec.as_dict(), 510 "zero_energy": zero_energy, 511 "is_metal": bs_is_metal, 512 "band_gap": bg_str, 513 } 514 515 @staticmethod 516 def _interpolate_bands(distances, energies, smooth_tol=0, smooth_k=3, smooth_np=100): 517 """ 518 Method that interpolates the provided energies using B-splines as 519 implemented in scipy.interpolate. Distances and energies has to provided 520 already split into pieces (branches work good, for longer segments 521 the interpolation may fail). 522 523 Interpolation failure can be caused by trying to fit an entire 524 band with one spline rather than fitting with piecewise splines 525 (splines are ill-suited to fit discontinuities). 526 527 The number of splines used to fit a band is determined by the 528 number of branches (high symmetry lines) defined in the 529 BandStructureSymmLine object (see BandStructureSymmLine._branches). 530 """ 531 532 int_energies, int_distances = [], [] 533 smooth_k_orig = smooth_k 534 535 for dist, ene in zip(distances, energies): 536 br_en = [] 537 warning_nan = ( 538 f"WARNING! Distance / branch, band cannot be " 539 f"interpolated. See full warning in source. " 540 f"If this is not a mistake, try increasing " 541 f"smooth_tol. Current smooth_tol is {smooth_tol}." 542 ) 543 544 warning_m_fewer_k = ( 545 f"The number of points (m) has to be higher then " 546 f"the order (k) of the splines. In this branch {len(dist)} " 547 f"points are found, while k is set to {smooth_k}. " 548 f"Smooth_k will be reduced to {smooth_k - 1} for this branch." 549 ) 550 551 # skip single point branches 552 if len(dist) in (2, 3): 553 # reducing smooth_k when the number 554 # of points are fewer then k 555 smooth_k = len(dist) - 1 556 warnings.warn(warning_m_fewer_k) 557 elif len(dist) == 1: 558 warnings.warn("Skipping single point branch") 559 continue 560 561 int_distances.append(np.linspace(dist[0], dist[-1], smooth_np)) 562 563 for ien in ene: 564 tck = scint.splrep(dist, ien, s=smooth_tol, k=smooth_k) 565 566 br_en.append(scint.splev(int_distances[-1], tck)) 567 568 smooth_k = smooth_k_orig 569 570 int_energies.append(np.vstack(br_en)) 571 572 if np.any(np.isnan(int_energies[-1])): 573 warnings.warn(warning_nan) 574 575 return int_distances, int_energies 576 577 def get_plot( 578 self, 579 zero_to_efermi=True, 580 ylim=None, 581 smooth=False, 582 vbm_cbm_marker=False, 583 smooth_tol=0, 584 smooth_k=3, 585 smooth_np=100, 586 bs_labels=[], 587 ): 588 """ 589 Get a matplotlib object for the bandstructures plot. 590 Multiple bandstructure objs are plotted together if they have the 591 same high symm path. 592 593 Args: 594 zero_to_efermi: Automatically subtract off the Fermi energy from 595 the eigenvalues and plot (E-Ef). 596 ylim: Specify the y-axis (energy) limits; by default None let 597 the code choose. It is vbm-4 and cbm+4 if insulator 598 efermi-10 and efermi+10 if metal 599 smooth (bool or list(bools)): interpolates the bands by a spline cubic. 600 A single bool values means to interpolate all the bandstructure objs. 601 A list of bools allows to select the bandstructure obs to interpolate. 602 smooth_tol (float) : tolerance for fitting spline to band data. 603 Default is None such that no tolerance will be used. 604 smooth_k (int): degree of splines 1<k<5 605 smooth_np (int): number of interpolated points per each branch. 606 bs_labels: labels for each band for the plot legend. 607 """ 608 plt = pretty_plot(12, 8) 609 610 if isinstance(smooth, bool): 611 smooth = [smooth] * len(self._bs) 612 613 handles = [] 614 vbm_min, cbm_max = [], [] 615 616 colors = list(plt.rcParams["axes.prop_cycle"].by_key().values())[0] 617 for ibs, bs in enumerate(self._bs): 618 619 # set first bs in the list as ref for rescaling the distances of the other bands 620 bs_ref = self._bs[0] if len(self._bs) > 1 and ibs > 0 else None 621 622 if smooth[ibs]: 623 # interpolation works good on short segments like branches 624 data = self.bs_plot_data(zero_to_efermi, bs, bs_ref, split_branches=True) 625 else: 626 data = self.bs_plot_data(zero_to_efermi, bs, bs_ref, split_branches=False) 627 628 # remember if one bs is a metal for setting the ylim later 629 one_is_metal = False 630 if not one_is_metal and data["is_metal"]: 631 one_is_metal = data["is_metal"] 632 633 # remember all the cbm and vbm for setting the ylim later 634 if not data["is_metal"]: 635 cbm_max.append(data["cbm"][0][1]) 636 vbm_min.append(data["vbm"][0][1]) 637 638 for sp in bs.bands.keys(): 639 ls = "-" if str(sp) == "1" else "--" 640 641 if bs_labels != []: 642 bs_label = f"{bs_labels[ibs]} {sp.name}" 643 else: 644 bs_label = f"Band {ibs} {sp.name}" 645 646 handles.append(mlines.Line2D([], [], lw=2, ls=ls, color=colors[ibs], label=bs_label)) 647 648 distances, energies = data["distances"], data["energy"][str(sp)] 649 650 if smooth[ibs]: 651 distances, energies = self._interpolate_bands( 652 distances, 653 energies, 654 smooth_tol=smooth_tol, 655 smooth_k=smooth_k, 656 smooth_np=smooth_np, 657 ) 658 # join all branches together 659 distances = np.hstack(distances) 660 energies = np.hstack(energies) 661 # split only discontinuous branches 662 steps = self._get_branch_steps(bs.branches)[1:-1] 663 distances = np.split(distances, steps) 664 energies = np.hsplit(energies, steps) 665 666 for dist, ene in zip(distances, energies): 667 plt.plot(dist, ene.T, c=colors[ibs], ls=ls) 668 669 # plot markers for vbm and cbm 670 if vbm_cbm_marker: 671 for cbm in data["cbm"]: 672 plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100) 673 for vbm in data["vbm"]: 674 plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100) 675 676 # Draw Fermi energy, only if not the zero 677 if not zero_to_efermi: 678 ef = bs.efermi 679 plt.axhline(ef, lw=2, ls="-.", color=colors[ibs]) 680 681 # defaults for ylim 682 e_min = -4 683 e_max = 4 684 if one_is_metal: 685 e_min = -10 686 e_max = 10 687 688 if ylim is None: 689 if zero_to_efermi: 690 if one_is_metal: 691 # Plot A Metal 692 plt.ylim(e_min, e_max) 693 else: 694 plt.ylim(e_min, max(cbm_max) + e_max) 695 else: 696 all_efermi = [b.efermi for b in self._bs] 697 ll = min([min(vbm_min), min(all_efermi)]) 698 hh = max([max(cbm_max), max(all_efermi)]) 699 plt.ylim(ll + e_min, hh + e_max) 700 else: 701 plt.ylim(ylim) 702 703 self._maketicks(plt) 704 705 # Main X and Y Labels 706 plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30) 707 ylabel = r"$\mathrm{E\ -\ E_f\ (eV)}$" if zero_to_efermi else r"$\mathrm{Energy\ (eV)}$" 708 plt.ylabel(ylabel, fontsize=30) 709 710 # X range (K) 711 # last distance point 712 x_max = data["distances"][-1][-1] 713 plt.xlim(0, x_max) 714 715 plt.legend(handles=handles) 716 717 plt.tight_layout() 718 719 # auto tight_layout when resizing or pressing t 720 def fix_layout(event): 721 if (event.name == "key_press_event" and event.key == "t") or event.name == "resize_event": 722 plt.gcf().tight_layout() 723 plt.gcf().canvas.draw() 724 725 plt.gcf().canvas.mpl_connect("key_press_event", fix_layout) 726 plt.gcf().canvas.mpl_connect("resize_event", fix_layout) 727 728 return plt 729 730 def show(self, zero_to_efermi=True, ylim=None, smooth=False, smooth_tol=None): 731 """ 732 Show the plot using matplotlib. 733 734 Args: 735 zero_to_efermi: Automatically subtract off the Fermi energy from 736 the eigenvalues and plot (E-Ef). 737 ylim: Specify the y-axis (energy) limits; by default None let 738 the code choose. It is vbm-4 and cbm+4 if insulator 739 efermi-10 and efermi+10 if metal 740 smooth: interpolates the bands by a spline cubic 741 smooth_tol (float) : tolerance for fitting spline to band data. 742 Default is None such that no tolerance will be used. 743 """ 744 plt = self.get_plot(zero_to_efermi, ylim, smooth) 745 plt.show() 746 747 def save_plot(self, filename, img_format="eps", ylim=None, zero_to_efermi=True, smooth=False): 748 """ 749 Save matplotlib plot to a file. 750 751 Args: 752 filename: Filename to write to. 753 img_format: Image format to use. Defaults to EPS. 754 ylim: Specifies the y-axis limits. 755 """ 756 plt = self.get_plot(ylim=ylim, zero_to_efermi=zero_to_efermi, smooth=smooth) 757 plt.savefig(filename, format=img_format) 758 plt.close() 759 760 def get_ticks(self): 761 """ 762 Get all ticks and labels for a band structure plot. 763 764 Returns: 765 dict: A dictionary with 'distance': a list of distance at which 766 ticks should be set and 'label': a list of label for each of those 767 ticks. 768 """ 769 bs = self._bs[0] if isinstance(self._bs, list) else self._bs 770 ticks, distance = [], [] 771 for br in bs.branches: 772 s, e = br["start_index"], br["end_index"] 773 774 labels = br["name"].split("-") 775 776 # skip those branches with only one point 777 if labels[0] == labels[1]: 778 continue 779 780 # add latex $$ 781 for i, l in enumerate(labels): 782 if l.startswith("\\") or "_" in l: 783 labels[i] = "$" + l + "$" 784 785 # If next branch is not continuous, 786 # join the firts lbl to the previous tick label 787 # and add the second lbl to ticks list 788 # otherwise add to ticks list both new labels. 789 # Similar for distances. 790 if ticks != [] and labels[0] != ticks[-1]: 791 ticks[-1] += "$\\mid$" + labels[0] 792 ticks.append(labels[1]) 793 distance.append(bs.distance[e]) 794 else: 795 ticks.extend(labels) 796 distance.extend([bs.distance[s], bs.distance[e]]) 797 798 return {"distance": distance, "label": ticks} 799 800 def get_ticks_old(self): 801 """ 802 Get all ticks and labels for a band structure plot. 803 804 Returns: 805 dict: A dictionary with 'distance': a list of distance at which 806 ticks should be set and 'label': a list of label for each of those 807 ticks. 808 """ 809 bs = self._bs[0] 810 tick_distance = [] 811 tick_labels = [] 812 previous_label = bs.kpoints[0].label 813 previous_branch = bs.branches[0]["name"] 814 for i, c in enumerate(bs.kpoints): 815 if c.label is not None: 816 tick_distance.append(bs.distance[i]) 817 this_branch = None 818 for b in bs.branches: 819 if b["start_index"] <= i <= b["end_index"]: 820 this_branch = b["name"] 821 break 822 if c.label != previous_label and previous_branch != this_branch: 823 label1 = c.label 824 if label1.startswith("\\") or label1.find("_") != -1: 825 label1 = "$" + label1 + "$" 826 label0 = previous_label 827 if label0.startswith("\\") or label0.find("_") != -1: 828 label0 = "$" + label0 + "$" 829 tick_labels.pop() 830 tick_distance.pop() 831 tick_labels.append(label0 + "$\\mid$" + label1) 832 else: 833 if c.label.startswith("\\") or c.label.find("_") != -1: 834 tick_labels.append("$" + c.label + "$") 835 else: 836 tick_labels.append(c.label) 837 previous_label = c.label 838 previous_branch = this_branch 839 return {"distance": tick_distance, "label": tick_labels} 840 841 def plot_compare(self, other_plotter, legend=True): 842 """ 843 plot two band structure for comparison. One is in red the other in blue 844 (no difference in spins). The two band structures need to be defined 845 on the same symmetry lines! and the distance between symmetry lines is 846 the one of the band structure used to build the BSPlotter 847 848 Args: 849 another band structure object defined along the same symmetry lines 850 851 Returns: 852 a matplotlib object with both band structures 853 854 """ 855 warnings.warn("Deprecated method. " "Use BSPlotter([sbs1,sbs2,...]).get_plot() instead.") 856 857 # TODO: add exception if the band structures are not compatible 858 import matplotlib.lines as mlines 859 860 plt = self.get_plot() 861 data_orig = self.bs_plot_data() 862 data = other_plotter.bs_plot_data() 863 band_linewidth = 1 864 for i in range(other_plotter._nb_bands): 865 for d in range(len(data_orig["distances"])): 866 plt.plot( 867 data_orig["distances"][d], 868 [e[str(Spin.up)][i] for e in data["energy"]][d], 869 "c-", 870 linewidth=band_linewidth, 871 ) 872 if other_plotter._bs.is_spin_polarized: 873 plt.plot( 874 data_orig["distances"][d], 875 [e[str(Spin.down)][i] for e in data["energy"]][d], 876 "m--", 877 linewidth=band_linewidth, 878 ) 879 if legend: 880 handles = [ 881 mlines.Line2D([], [], linewidth=2, color="b", label="bs 1 up"), 882 mlines.Line2D([], [], linewidth=2, color="r", label="bs 1 down", linestyle="--"), 883 mlines.Line2D([], [], linewidth=2, color="c", label="bs 2 up"), 884 mlines.Line2D([], [], linewidth=2, color="m", linestyle="--", label="bs 2 down"), 885 ] 886 887 plt.legend(handles=handles) 888 return plt 889 890 def plot_brillouin(self): 891 """ 892 plot the Brillouin zone 893 """ 894 895 # get labels and lines 896 labels = {} 897 for k in self._bs[0].kpoints: 898 if k.label: 899 labels[k.label] = k.frac_coords 900 901 lines = [] 902 for b in self._bs[0].branches: 903 lines.append( 904 [ 905 self._bs[0].kpoints[b["start_index"]].frac_coords, 906 self._bs[0].kpoints[b["end_index"]].frac_coords, 907 ] 908 ) 909 910 plot_brillouin_zone(self._bs[0].lattice_rec, lines=lines, labels=labels) 911 912 913class BSPlotterProjected(BSPlotter): 914 """ 915 Class to plot or get data to facilitate the plot of band structure objects 916 projected along orbitals, elements or sites. 917 """ 918 919 def __init__(self, bs): 920 """ 921 Args: 922 bs: A BandStructureSymmLine object with projections. 923 """ 924 if isinstance(bs, list): 925 warnings.warn( 926 "Multiple bands are not handled by BSPlotterProjected." "The first band in the list will be considered" 927 ) 928 bs = bs[0] 929 930 if len(bs.projections) == 0: 931 raise ValueError("try to plot projections on a band structure without any") 932 933 self._bs = bs 934 self._nb_bands = bs.nb_bands 935 936 def _get_projections_by_branches(self, dictio): 937 proj = self._bs.get_projections_on_elements_and_orbitals(dictio) 938 proj_br = [] 939 for b in self._bs.branches: 940 if self._bs.is_spin_polarized: 941 proj_br.append( 942 { 943 str(Spin.up): [[] for l in range(self._nb_bands)], 944 str(Spin.down): [[] for l in range(self._nb_bands)], 945 } 946 ) 947 else: 948 proj_br.append({str(Spin.up): [[] for l in range(self._nb_bands)]}) 949 950 for i in range(self._nb_bands): 951 for j in range(b["start_index"], b["end_index"] + 1): 952 proj_br[-1][str(Spin.up)][i].append( 953 {e: {o: proj[Spin.up][i][j][e][o] for o in proj[Spin.up][i][j][e]} for e in proj[Spin.up][i][j]} 954 ) 955 if self._bs.is_spin_polarized: 956 for b in self._bs.branches: 957 for i in range(self._nb_bands): 958 for j in range(b["start_index"], b["end_index"] + 1): 959 proj_br[-1][str(Spin.down)][i].append( 960 { 961 e: {o: proj[Spin.down][i][j][e][o] for o in proj[Spin.down][i][j][e]} 962 for e in proj[Spin.down][i][j] 963 } 964 ) 965 return proj_br 966 967 def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False): 968 """ 969 Method returning a plot composed of subplots along different elements 970 and orbitals. 971 972 Args: 973 dictio: The element and orbitals you want a projection on. The 974 format is {Element:[Orbitals]} for instance 975 {'Cu':['d','s'],'O':['p']} will give projections for Cu on 976 d and s orbitals and on oxygen p. 977 If you use this class to plot LobsterBandStructureSymmLine, 978 the orbitals are named as in the FATBAND filename, e.g. 979 "2p" or "2p_x" 980 981 Returns: 982 a pylab object with different subfigures for each projection 983 The blue and red colors are for spin up and spin down. 984 The bigger the red or blue dot in the band structure the higher 985 character for the corresponding element and orbital. 986 """ 987 band_linewidth = 1.0 988 fig_cols = len(dictio) * 100 989 fig_rows = max([len(v) for v in dictio.values()]) * 10 990 proj = self._get_projections_by_branches(dictio) 991 data = self.bs_plot_data(zero_to_efermi) 992 plt = pretty_plot(12, 8) 993 e_min = -4 994 e_max = 4 995 if self._bs.is_metal(): 996 e_min = -10 997 e_max = 10 998 count = 1 999 1000 for el in dictio: 1001 for o in dictio[el]: 1002 plt.subplot(fig_rows + fig_cols + count) 1003 self._maketicks(plt) 1004 for b in range(len(data["distances"])): 1005 for i in range(self._nb_bands): 1006 plt.plot( 1007 data["distances"][b], 1008 data["energy"][str(Spin.up)][b][i], 1009 "b-", 1010 linewidth=band_linewidth, 1011 ) 1012 if self._bs.is_spin_polarized: 1013 plt.plot( 1014 data["distances"][b], 1015 data["energy"][str(Spin.down)][b][i], 1016 "r--", 1017 linewidth=band_linewidth, 1018 ) 1019 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1020 plt.plot( 1021 data["distances"][b][j], 1022 data["energy"][str(Spin.down)][b][i][j], 1023 "ro", 1024 markersize=proj[b][str(Spin.down)][i][j][str(el)][o] * 15.0, 1025 ) 1026 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1027 plt.plot( 1028 data["distances"][b][j], 1029 data["energy"][str(Spin.up)][b][i][j], 1030 "bo", 1031 markersize=proj[b][str(Spin.up)][i][j][str(el)][o] * 15.0, 1032 ) 1033 if ylim is None: 1034 if self._bs.is_metal(): 1035 if zero_to_efermi: 1036 plt.ylim(e_min, e_max) 1037 else: 1038 plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max) 1039 else: 1040 if vbm_cbm_marker: 1041 for cbm in data["cbm"]: 1042 plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100) 1043 1044 for vbm in data["vbm"]: 1045 plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100) 1046 1047 plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max) 1048 else: 1049 plt.ylim(ylim) 1050 plt.title(str(el) + " " + str(o)) 1051 count += 1 1052 return plt 1053 1054 def get_elt_projected_plots(self, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False): 1055 """ 1056 Method returning a plot composed of subplots along different elements 1057 1058 Returns: 1059 a pylab object with different subfigures for each projection 1060 The blue and red colors are for spin up and spin down 1061 The bigger the red or blue dot in the band structure the higher 1062 character for the corresponding element and orbital 1063 """ 1064 band_linewidth = 1.0 1065 proj = self._get_projections_by_branches( 1066 {e.symbol: ["s", "p", "d"] for e in self._bs.structure.composition.elements} 1067 ) 1068 data = self.bs_plot_data(zero_to_efermi) 1069 plt = pretty_plot(12, 8) 1070 e_min = -4 1071 e_max = 4 1072 if self._bs.is_metal(): 1073 e_min = -10 1074 e_max = 10 1075 count = 1 1076 for el in self._bs.structure.composition.elements: 1077 plt.subplot(220 + count) 1078 self._maketicks(plt) 1079 for b in range(len(data["distances"])): 1080 for i in range(self._nb_bands): 1081 plt.plot( 1082 data["distances"][b], 1083 data["energy"][str(Spin.up)][b][i], 1084 "-", 1085 color=[192 / 255, 192 / 255, 192 / 255], 1086 linewidth=band_linewidth, 1087 ) 1088 if self._bs.is_spin_polarized: 1089 plt.plot( 1090 data["distances"][b], 1091 data["energy"][str(Spin.down)][b][i], 1092 "--", 1093 color=[128 / 255, 128 / 255, 128 / 255], 1094 linewidth=band_linewidth, 1095 ) 1096 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1097 markerscale = sum( 1098 [ 1099 proj[b][str(Spin.down)][i][j][str(el)][o] 1100 for o in proj[b][str(Spin.down)][i][j][str(el)] 1101 ] 1102 ) 1103 plt.plot( 1104 data["distances"][b][j], 1105 data["energy"][str(Spin.down)][b][i][j], 1106 "bo", 1107 markersize=markerscale * 15.0, 1108 color=[ 1109 markerscale, 1110 0.3 * markerscale, 1111 0.4 * markerscale, 1112 ], 1113 ) 1114 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1115 markerscale = sum( 1116 [proj[b][str(Spin.up)][i][j][str(el)][o] for o in proj[b][str(Spin.up)][i][j][str(el)]] 1117 ) 1118 plt.plot( 1119 data["distances"][b][j], 1120 data["energy"][str(Spin.up)][b][i][j], 1121 "o", 1122 markersize=markerscale * 15.0, 1123 color=[markerscale, 0.3 * markerscale, 0.4 * markerscale], 1124 ) 1125 if ylim is None: 1126 if self._bs.is_metal(): 1127 if zero_to_efermi: 1128 plt.ylim(e_min, e_max) 1129 else: 1130 plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max) 1131 else: 1132 if vbm_cbm_marker: 1133 for cbm in data["cbm"]: 1134 plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100) 1135 1136 for vbm in data["vbm"]: 1137 plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100) 1138 1139 plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max) 1140 else: 1141 plt.ylim(ylim) 1142 plt.title(str(el)) 1143 count += 1 1144 1145 return plt 1146 1147 def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None): 1148 """ 1149 returns a pylab plot object with one plot where the band structure 1150 line color depends on the character of the band (along different 1151 elements). Each element is associated with red, green or blue 1152 and the corresponding rgb color depending on the character of the band 1153 is used. The method can only deal with binary and ternary compounds 1154 1155 spin up and spin down are differientiated by a '-' and a '--' line 1156 1157 Args: 1158 elt_ordered: A list of Element ordered. The first one is red, 1159 second green, last blue 1160 1161 Returns: 1162 a pylab object 1163 1164 """ 1165 band_linewidth = 3.0 1166 if len(self._bs.structure.composition.elements) > 3: 1167 raise ValueError 1168 if elt_ordered is None: 1169 elt_ordered = self._bs.structure.composition.elements 1170 proj = self._get_projections_by_branches( 1171 {e.symbol: ["s", "p", "d"] for e in self._bs.structure.composition.elements} 1172 ) 1173 data = self.bs_plot_data(zero_to_efermi) 1174 plt = pretty_plot(12, 8) 1175 1176 spins = [Spin.up] 1177 if self._bs.is_spin_polarized: 1178 spins = [Spin.up, Spin.down] 1179 self._maketicks(plt) 1180 for s in spins: 1181 for b in range(len(data["distances"])): 1182 for i in range(self._nb_bands): 1183 for j in range(len(data["energy"][str(s)][b][i]) - 1): 1184 sum_e = 0.0 1185 for el in elt_ordered: 1186 sum_e = sum_e + sum( 1187 [proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)]] 1188 ) 1189 if sum_e == 0.0: 1190 color = [0.0] * len(elt_ordered) 1191 else: 1192 color = [ 1193 sum([proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)]]) / sum_e 1194 for el in elt_ordered 1195 ] 1196 if len(color) == 2: 1197 color.append(0.0) 1198 color[2] = color[1] 1199 color[1] = 0.0 1200 sign = "-" 1201 if s == Spin.down: 1202 sign = "--" 1203 plt.plot( 1204 [data["distances"][b][j], data["distances"][b][j + 1]], 1205 [ 1206 data["energy"][str(s)][b][i][j], 1207 data["energy"][str(s)][b][i][j + 1], 1208 ], 1209 sign, 1210 color=color, 1211 linewidth=band_linewidth, 1212 ) 1213 1214 if self._bs.is_metal(): 1215 if zero_to_efermi: 1216 e_min = -10 1217 e_max = 10 1218 plt.ylim(e_min, e_max) 1219 plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max) 1220 else: 1221 plt.ylim(data["vbm"][0][1] - 4.0, data["cbm"][0][1] + 2.0) 1222 return plt 1223 1224 def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, sum_morbs, selected_branches): 1225 import copy 1226 1227 setos = { 1228 "s": 0, 1229 "py": 1, 1230 "pz": 2, 1231 "px": 3, 1232 "dxy": 4, 1233 "dyz": 5, 1234 "dz2": 6, 1235 "dxz": 7, 1236 "dx2": 8, 1237 "f_3": 9, 1238 "f_2": 10, 1239 "f_1": 11, 1240 "f0": 12, 1241 "f1": 13, 1242 "f2": 14, 1243 "f3": 15, 1244 } 1245 1246 num_branches = len(self._bs.branches) 1247 if selected_branches is not None: 1248 indices = [] 1249 if not isinstance(selected_branches, list): 1250 raise TypeError("You do not give a correct type of 'selected_branches'. It should be 'list' type.") 1251 if len(selected_branches) == 0: 1252 raise ValueError("The 'selected_branches' is empty. We cannot do anything.") 1253 for index in selected_branches: 1254 if not isinstance(index, int): 1255 raise ValueError( 1256 "You do not give a correct type of index of symmetry lines. It should be " "'int' type" 1257 ) 1258 if index > num_branches or index < 1: 1259 raise ValueError( 1260 "You give a incorrect index of symmetry lines: %s. The index should be in " 1261 "range of [1, %s]." % (str(index), str(num_branches)) 1262 ) 1263 indices.append(index - 1) 1264 else: 1265 indices = range(0, num_branches) 1266 1267 proj = self._bs.projections 1268 proj_br = [] 1269 for index in indices: 1270 b = self._bs.branches[index] 1271 print(b) 1272 if self._bs.is_spin_polarized: 1273 proj_br.append( 1274 { 1275 str(Spin.up): [[] for l in range(self._nb_bands)], 1276 str(Spin.down): [[] for l in range(self._nb_bands)], 1277 } 1278 ) 1279 else: 1280 proj_br.append({str(Spin.up): [[] for l in range(self._nb_bands)]}) 1281 1282 for i in range(self._nb_bands): 1283 for j in range(b["start_index"], b["end_index"] + 1): 1284 edict = {} 1285 for elt in dictpa: 1286 for anum in dictpa[elt]: 1287 edict[elt + str(anum)] = {} 1288 for morb in dictio[elt]: 1289 edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1] 1290 proj_br[-1][str(Spin.up)][i].append(edict) 1291 1292 if self._bs.is_spin_polarized: 1293 for i in range(self._nb_bands): 1294 for j in range(b["start_index"], b["end_index"] + 1): 1295 edict = {} 1296 for elt in dictpa: 1297 for anum in dictpa[elt]: 1298 edict[elt + str(anum)] = {} 1299 for morb in dictio[elt]: 1300 edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1] 1301 proj_br[-1][str(Spin.down)][i].append(edict) 1302 1303 # Adjusting projections for plot 1304 dictio_d, dictpa_d = self._summarize_keys_for_plot(dictio, dictpa, sum_atoms, sum_morbs) 1305 print("dictio_d: %s" % str(dictio_d)) 1306 print("dictpa_d: %s" % str(dictpa_d)) 1307 1308 if (sum_atoms is None) and (sum_morbs is None): 1309 proj_br_d = copy.deepcopy(proj_br) 1310 else: 1311 proj_br_d = [] 1312 branch = -1 1313 for index in indices: 1314 branch += 1 1315 br = self._bs.branches[index] 1316 if self._bs.is_spin_polarized: 1317 proj_br_d.append( 1318 { 1319 str(Spin.up): [[] for l in range(self._nb_bands)], 1320 str(Spin.down): [[] for l in range(self._nb_bands)], 1321 } 1322 ) 1323 else: 1324 proj_br_d.append({str(Spin.up): [[] for l in range(self._nb_bands)]}) 1325 1326 if (sum_atoms is not None) and (sum_morbs is None): 1327 for i in range(self._nb_bands): 1328 for j in range(br["end_index"] - br["start_index"] + 1): 1329 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) 1330 edict = {} 1331 for elt in dictpa: 1332 if elt in sum_atoms: 1333 for anum in dictpa_d[elt][:-1]: 1334 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1335 edict[elt + dictpa_d[elt][-1]] = {} 1336 for morb in dictio[elt]: 1337 sprojection = 0.0 1338 for anum in sum_atoms[elt]: 1339 sprojection += atoms_morbs[elt + str(anum)][morb] 1340 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1341 else: 1342 for anum in dictpa_d[elt]: 1343 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1344 proj_br_d[-1][str(Spin.up)][i].append(edict) 1345 if self._bs.is_spin_polarized: 1346 for i in range(self._nb_bands): 1347 for j in range(br["end_index"] - br["start_index"] + 1): 1348 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) 1349 edict = {} 1350 for elt in dictpa: 1351 if elt in sum_atoms: 1352 for anum in dictpa_d[elt][:-1]: 1353 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1354 edict[elt + dictpa_d[elt][-1]] = {} 1355 for morb in dictio[elt]: 1356 sprojection = 0.0 1357 for anum in sum_atoms[elt]: 1358 sprojection += atoms_morbs[elt + str(anum)][morb] 1359 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1360 else: 1361 for anum in dictpa_d[elt]: 1362 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1363 proj_br_d[-1][str(Spin.down)][i].append(edict) 1364 1365 elif (sum_atoms is None) and (sum_morbs is not None): 1366 for i in range(self._nb_bands): 1367 for j in range(br["end_index"] - br["start_index"] + 1): 1368 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) 1369 edict = {} 1370 for elt in dictpa: 1371 if elt in sum_morbs: 1372 for anum in dictpa_d[elt]: 1373 edict[elt + anum] = {} 1374 for morb in dictio_d[elt][:-1]: 1375 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1376 sprojection = 0.0 1377 for morb in sum_morbs[elt]: 1378 sprojection += atoms_morbs[elt + anum][morb] 1379 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1380 else: 1381 for anum in dictpa_d[elt]: 1382 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1383 proj_br_d[-1][str(Spin.up)][i].append(edict) 1384 if self._bs.is_spin_polarized: 1385 for i in range(self._nb_bands): 1386 for j in range(br["end_index"] - br["start_index"] + 1): 1387 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) 1388 edict = {} 1389 for elt in dictpa: 1390 if elt in sum_morbs: 1391 for anum in dictpa_d[elt]: 1392 edict[elt + anum] = {} 1393 for morb in dictio_d[elt][:-1]: 1394 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1395 sprojection = 0.0 1396 for morb in sum_morbs[elt]: 1397 sprojection += atoms_morbs[elt + anum][morb] 1398 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1399 else: 1400 for anum in dictpa_d[elt]: 1401 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1402 proj_br_d[-1][str(Spin.down)][i].append(edict) 1403 1404 else: 1405 for i in range(self._nb_bands): 1406 for j in range(br["end_index"] - br["start_index"] + 1): 1407 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) 1408 edict = {} 1409 for elt in dictpa: 1410 if (elt in sum_atoms) and (elt in sum_morbs): 1411 for anum in dictpa_d[elt][:-1]: 1412 edict[elt + anum] = {} 1413 for morb in dictio_d[elt][:-1]: 1414 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1415 sprojection = 0.0 1416 for morb in sum_morbs[elt]: 1417 sprojection += atoms_morbs[elt + anum][morb] 1418 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1419 1420 edict[elt + dictpa_d[elt][-1]] = {} 1421 for morb in dictio_d[elt][:-1]: 1422 sprojection = 0.0 1423 for anum in sum_atoms[elt]: 1424 sprojection += atoms_morbs[elt + str(anum)][morb] 1425 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1426 1427 sprojection = 0.0 1428 for anum in sum_atoms[elt]: 1429 for morb in sum_morbs[elt]: 1430 sprojection += atoms_morbs[elt + str(anum)][morb] 1431 edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection 1432 1433 elif (elt in sum_atoms) and (elt not in sum_morbs): 1434 for anum in dictpa_d[elt][:-1]: 1435 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1436 edict[elt + dictpa_d[elt][-1]] = {} 1437 for morb in dictio[elt]: 1438 sprojection = 0.0 1439 for anum in sum_atoms[elt]: 1440 sprojection += atoms_morbs[elt + str(anum)][morb] 1441 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1442 1443 elif (elt not in sum_atoms) and (elt in sum_morbs): 1444 for anum in dictpa_d[elt]: 1445 edict[elt + anum] = {} 1446 for morb in dictio_d[elt][:-1]: 1447 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1448 sprojection = 0.0 1449 for morb in sum_morbs[elt]: 1450 sprojection += atoms_morbs[elt + anum][morb] 1451 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1452 1453 else: 1454 for anum in dictpa_d[elt]: 1455 edict[elt + anum] = {} 1456 for morb in dictio_d[elt]: 1457 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1458 proj_br_d[-1][str(Spin.up)][i].append(edict) 1459 1460 if self._bs.is_spin_polarized: 1461 for i in range(self._nb_bands): 1462 for j in range(br["end_index"] - br["start_index"] + 1): 1463 atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) 1464 edict = {} 1465 for elt in dictpa: 1466 if (elt in sum_atoms) and (elt in sum_morbs): 1467 for anum in dictpa_d[elt][:-1]: 1468 edict[elt + anum] = {} 1469 for morb in dictio_d[elt][:-1]: 1470 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1471 sprojection = 0.0 1472 for morb in sum_morbs[elt]: 1473 sprojection += atoms_morbs[elt + anum][morb] 1474 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1475 1476 edict[elt + dictpa_d[elt][-1]] = {} 1477 for morb in dictio_d[elt][:-1]: 1478 sprojection = 0.0 1479 for anum in sum_atoms[elt]: 1480 sprojection += atoms_morbs[elt + str(anum)][morb] 1481 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1482 1483 sprojection = 0.0 1484 for anum in sum_atoms[elt]: 1485 for morb in sum_morbs[elt]: 1486 sprojection += atoms_morbs[elt + str(anum)][morb] 1487 edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection 1488 1489 elif (elt in sum_atoms) and (elt not in sum_morbs): 1490 for anum in dictpa_d[elt][:-1]: 1491 edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) 1492 edict[elt + dictpa_d[elt][-1]] = {} 1493 for morb in dictio[elt]: 1494 sprojection = 0.0 1495 for anum in sum_atoms[elt]: 1496 sprojection += atoms_morbs[elt + str(anum)][morb] 1497 edict[elt + dictpa_d[elt][-1]][morb] = sprojection 1498 1499 elif (elt not in sum_atoms) and (elt in sum_morbs): 1500 for anum in dictpa_d[elt]: 1501 edict[elt + anum] = {} 1502 for morb in dictio_d[elt][:-1]: 1503 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1504 sprojection = 0.0 1505 for morb in sum_morbs[elt]: 1506 sprojection += atoms_morbs[elt + anum][morb] 1507 edict[elt + anum][dictio_d[elt][-1]] = sprojection 1508 1509 else: 1510 for anum in dictpa_d[elt]: 1511 edict[elt + anum] = {} 1512 for morb in dictio_d[elt]: 1513 edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] 1514 proj_br_d[-1][str(Spin.down)][i].append(edict) 1515 1516 return proj_br_d, dictio_d, dictpa_d, indices 1517 1518 def get_projected_plots_dots_patom_pmorb( 1519 self, 1520 dictio, 1521 dictpa, 1522 sum_atoms=None, 1523 sum_morbs=None, 1524 zero_to_efermi=True, 1525 ylim=None, 1526 vbm_cbm_marker=False, 1527 selected_branches=None, 1528 w_h_size=(12, 8), 1529 num_column=None, 1530 ): 1531 """ 1532 Method returns a plot composed of subplots for different atoms and 1533 orbitals (subshell orbitals such as 's', 'p', 'd' and 'f' defined by 1534 azimuthal quantum numbers l = 0, 1, 2 and 3, respectively or 1535 individual orbitals like 'px', 'py' and 'pz' defined by magnetic 1536 quantum numbers m = -1, 1 and 0, respectively). 1537 This is an extension of "get_projected_plots_dots" method. 1538 1539 Args: 1540 dictio: The elements and the orbitals you need to project on. The 1541 format is {Element:[Orbitals]}, for instance: 1542 {'Cu':['dxy','s','px'],'O':['px','py','pz']} will give 1543 projections for Cu on orbitals dxy, s, px and 1544 for O on orbitals px, py, pz. If you want to sum over all 1545 individual orbitals of subshell orbitals, 1546 for example, 'px', 'py' and 'pz' of O, just simply set 1547 {'Cu':['dxy','s','px'],'O':['p']} and set sum_morbs (see 1548 explanations below) as {'O':[p],...}. 1549 Otherwise, you will get an error. 1550 dictpa: The elements and their sites (defined by site numbers) you 1551 need to project on. The format is 1552 {Element: [Site numbers]}, for instance: {'Cu':[1,5],'O':[3,4]} 1553 will give projections for Cu on site-1 1554 and on site-5, O on site-3 and on site-4 in the cell. 1555 Attention: 1556 The correct site numbers of atoms are consistent with 1557 themselves in the structure computed. Normally, 1558 the structure should be totally similar with POSCAR file, 1559 however, sometimes VASP can rotate or 1560 translate the cell. Thus, it would be safe if using Vasprun 1561 class to get the final_structure and as a 1562 result, correct index numbers of atoms. 1563 sum_atoms: Sum projection of the similar atoms together (e.g.: Cu 1564 on site-1 and Cu on site-5). The format is 1565 {Element: [Site numbers]}, for instance: 1566 {'Cu': [1,5], 'O': [3,4]} means summing projections over Cu on 1567 site-1 and Cu on site-5 and O on site-3 1568 and on site-4. If you do not want to use this functional, just 1569 turn it off by setting sum_atoms = None. 1570 sum_morbs: Sum projections of individual orbitals of similar atoms 1571 together (e.g.: 'dxy' and 'dxz'). The 1572 format is {Element: [individual orbitals]}, for instance: 1573 {'Cu': ['dxy', 'dxz'], 'O': ['px', 'py']} means summing 1574 projections over 'dxy' and 'dxz' of Cu and 'px' 1575 and 'py' of O. If you do not want to use this functional, just 1576 turn it off by setting sum_morbs = None. 1577 selected_branches: The index of symmetry lines you chose for 1578 plotting. This can be useful when the number of 1579 symmetry lines (in KPOINTS file) are manny while you only want 1580 to show for certain ones. The format is 1581 [index of line], for instance: 1582 [1, 3, 4] means you just need to do projection along lines 1583 number 1, 3 and 4 while neglecting lines 1584 number 2 and so on. By default, this is None type and all 1585 symmetry lines will be plotted. 1586 w_h_size: This variable help you to control the width and height 1587 of figure. By default, width = 12 and 1588 height = 8 (inches). The width/height ratio is kept the same 1589 for subfigures and the size of each depends 1590 on how many number of subfigures are plotted. 1591 num_column: This variable help you to manage how the subfigures are 1592 arranged in the figure by setting 1593 up the number of columns of subfigures. The value should be an 1594 int number. For example, num_column = 3 1595 means you want to plot subfigures in 3 columns. By default, 1596 num_column = None and subfigures are 1597 aligned in 2 columns. 1598 1599 Returns: 1600 A pylab object with different subfigures for different projections. 1601 The blue and red colors lines are bands 1602 for spin up and spin down. The green and cyan dots are projections 1603 for spin up and spin down. The bigger 1604 the green or cyan dots in the projected band structures, the higher 1605 character for the corresponding elements 1606 and orbitals. List of individual orbitals and their numbers (set up 1607 by VASP and no special meaning): 1608 s = 0; py = 1 pz = 2 px = 3; dxy = 4 dyz = 5 dz2 = 6 dxz = 7 dx2 = 8; 1609 f_3 = 9 f_2 = 10 f_1 = 11 f0 = 12 f1 = 13 f2 = 14 f3 = 15 1610 """ 1611 dictio, sum_morbs = self._Orbitals_SumOrbitals(dictio, sum_morbs) 1612 dictpa, sum_atoms, number_figs = self._number_of_subfigures(dictio, dictpa, sum_atoms, sum_morbs) 1613 print("Number of subfigures: %s" % str(number_figs)) 1614 if number_figs > 9: 1615 print( 1616 "The number of sub-figures %s might be too manny and the implementation might take a long time.\n" 1617 "A smaller number or a plot with selected symmetry lines (selected_branches) might be better.\n" 1618 % str(number_figs) 1619 ) 1620 from pymatgen.util.plotting import pretty_plot 1621 1622 band_linewidth = 0.5 1623 plt = pretty_plot(w_h_size[0], w_h_size[1]) 1624 ( 1625 proj_br_d, 1626 dictio_d, 1627 dictpa_d, 1628 branches, 1629 ) = self._get_projections_by_branches_patom_pmorb(dictio, dictpa, sum_atoms, sum_morbs, selected_branches) 1630 data = self.bs_plot_data(zero_to_efermi) 1631 e_min = -4 1632 e_max = 4 1633 if self._bs.is_metal(): 1634 e_min = -10 1635 e_max = 10 1636 1637 count = 0 1638 for elt in dictpa_d: 1639 for numa in dictpa_d[elt]: 1640 for o in dictio_d[elt]: 1641 1642 count += 1 1643 if num_column is None: 1644 if number_figs == 1: 1645 plt.subplot(1, 1, 1) 1646 else: 1647 row = number_figs / 2 1648 if number_figs % 2 == 0: 1649 plt.subplot(row, 2, count) 1650 else: 1651 plt.subplot(row + 1, 2, count) 1652 elif isinstance(num_column, int): 1653 row = number_figs / num_column 1654 if number_figs % num_column == 0: 1655 plt.subplot(row, num_column, count) 1656 else: 1657 plt.subplot(row + 1, num_column, count) 1658 else: 1659 raise ValueError("The invalid 'num_column' is assigned. It should be an integer.") 1660 1661 plt, shift = self._maketicks_selected(plt, branches) 1662 br = -1 1663 for b in branches: 1664 br += 1 1665 for i in range(self._nb_bands): 1666 plt.plot( 1667 list(map(lambda x: x - shift[br], data["distances"][b])), 1668 [data["energy"][str(Spin.up)][b][i][j] for j in range(len(data["distances"][b]))], 1669 "b-", 1670 linewidth=band_linewidth, 1671 ) 1672 1673 if self._bs.is_spin_polarized: 1674 plt.plot( 1675 list( 1676 map( 1677 lambda x: x - shift[br], 1678 data["distances"][b], 1679 ) 1680 ), 1681 [data["energy"][str(Spin.down)][b][i][j] for j in range(len(data["distances"][b]))], 1682 "r--", 1683 linewidth=band_linewidth, 1684 ) 1685 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1686 plt.plot( 1687 data["distances"][b][j] - shift[br], 1688 data["energy"][str(Spin.down)][b][i][j], 1689 "co", 1690 markersize=proj_br_d[br][str(Spin.down)][i][j][elt + numa][o] * 15.0, 1691 ) 1692 1693 for j in range(len(data["energy"][str(Spin.up)][b][i])): 1694 plt.plot( 1695 data["distances"][b][j] - shift[br], 1696 data["energy"][str(Spin.up)][b][i][j], 1697 "go", 1698 markersize=proj_br_d[br][str(Spin.up)][i][j][elt + numa][o] * 15.0, 1699 ) 1700 1701 if ylim is None: 1702 if self._bs.is_metal(): 1703 if zero_to_efermi: 1704 plt.ylim(e_min, e_max) 1705 else: 1706 plt.ylim(self._bs.efermi + e_min, self._bs._efermi + e_max) 1707 else: 1708 if vbm_cbm_marker: 1709 for cbm in data["cbm"]: 1710 plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100) 1711 1712 for vbm in data["vbm"]: 1713 plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100) 1714 1715 plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max) 1716 else: 1717 plt.ylim(ylim) 1718 plt.title(elt + " " + numa + " " + str(o)) 1719 1720 return plt 1721 1722 @classmethod 1723 def _Orbitals_SumOrbitals(cls, dictio, sum_morbs): 1724 all_orbitals = [ 1725 "s", 1726 "p", 1727 "d", 1728 "f", 1729 "px", 1730 "py", 1731 "pz", 1732 "dxy", 1733 "dyz", 1734 "dxz", 1735 "dx2", 1736 "dz2", 1737 "f_3", 1738 "f_2", 1739 "f_1", 1740 "f0", 1741 "f1", 1742 "f2", 1743 "f3", 1744 ] 1745 individual_orbs = { 1746 "p": ["px", "py", "pz"], 1747 "d": ["dxy", "dyz", "dxz", "dx2", "dz2"], 1748 "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"], 1749 } 1750 1751 if not isinstance(dictio, dict): 1752 raise TypeError("The invalid type of 'dictio' was bound. It should be dict type.") 1753 if len(dictio.keys()) == 0: 1754 raise KeyError("The 'dictio' is empty. We cannot do anything.") 1755 1756 for elt in dictio: 1757 if Element.is_valid_symbol(elt): 1758 if isinstance(dictio[elt], list): 1759 if len(dictio[elt]) == 0: 1760 raise ValueError("The dictio[%s] is empty. We cannot do anything" % elt) 1761 for orb in dictio[elt]: 1762 if not isinstance(orb, str): 1763 raise ValueError( 1764 "The invalid format of orbitals is in 'dictio[%s]': %s. " 1765 "They should be string." % (elt, str(orb)) 1766 ) 1767 if orb not in all_orbitals: 1768 raise ValueError("The invalid name of orbital is given in 'dictio[%s]'." % elt) 1769 if orb in individual_orbs.keys(): 1770 if len(set(dictio[elt]).intersection(individual_orbs[orb])) != 0: 1771 raise ValueError("The 'dictio[%s]' contains orbitals repeated." % elt) 1772 nelems = Counter(dictio[elt]).values() 1773 if sum(nelems) > len(nelems): 1774 raise ValueError("You put in at least two similar orbitals in dictio[%s]." % elt) 1775 else: 1776 raise TypeError( 1777 "The invalid type of value was put into 'dictio[%s]'. It should be list " "type." % elt 1778 ) 1779 else: 1780 raise KeyError("The invalid element was put into 'dictio' as a key: %s" % elt) 1781 1782 if sum_morbs is None: 1783 print("You do not want to sum projection over orbitals.") 1784 elif not isinstance(sum_morbs, dict): 1785 raise TypeError("The invalid type of 'sum_orbs' was bound. It should be dict or 'None' type.") 1786 elif len(sum_morbs.keys()) == 0: 1787 raise KeyError("The 'sum_morbs' is empty. We cannot do anything") 1788 else: 1789 for elt in sum_morbs: 1790 if Element.is_valid_symbol(elt): 1791 if isinstance(sum_morbs[elt], list): 1792 for orb in sum_morbs[elt]: 1793 if not isinstance(orb, str): 1794 raise TypeError( 1795 "The invalid format of orbitals is in 'sum_morbs[%s]': %s. " 1796 "They should be string." % (elt, str(orb)) 1797 ) 1798 if orb not in all_orbitals: 1799 raise ValueError("The invalid name of orbital in 'sum_morbs[%s]' is given." % elt) 1800 if orb in individual_orbs.keys(): 1801 if len(set(sum_morbs[elt]).intersection(individual_orbs[orb])) != 0: 1802 raise ValueError("The 'sum_morbs[%s]' contains orbitals repeated." % elt) 1803 nelems = Counter(sum_morbs[elt]).values() 1804 if sum(nelems) > len(nelems): 1805 raise ValueError("You put in at least two similar orbitals in sum_morbs[%s]." % elt) 1806 else: 1807 raise TypeError( 1808 "The invalid type of value was put into 'sum_morbs[%s]'. It should be list " "type." % elt 1809 ) 1810 if elt not in dictio.keys(): 1811 raise ValueError( 1812 "You cannot sum projection over orbitals of atoms '%s' because they are not " 1813 "mentioned in 'dictio'." % elt 1814 ) 1815 else: 1816 raise KeyError("The invalid element was put into 'sum_morbs' as a key: %s" % elt) 1817 1818 for elt in dictio: 1819 if len(dictio[elt]) == 1: 1820 if len(dictio[elt][0]) > 1: 1821 if elt in sum_morbs.keys(): 1822 raise ValueError( 1823 "You cannot sum projection over one individual orbital '%s' of '%s'." 1824 % (dictio[elt][0], elt) 1825 ) 1826 else: 1827 if sum_morbs is None: 1828 pass 1829 elif elt not in sum_morbs.keys(): 1830 print("You do not want to sum projection over orbitals of element: %s" % elt) 1831 else: 1832 if len(sum_morbs[elt]) == 0: 1833 raise ValueError("The empty list is an invalid value for sum_morbs[%s]." % elt) 1834 if len(sum_morbs[elt]) > 1: 1835 for orb in sum_morbs[elt]: 1836 if dictio[elt][0] not in orb: 1837 raise ValueError( 1838 "The invalid orbital '%s' was put into 'sum_morbs[%s]'." % (orb, elt) 1839 ) 1840 else: 1841 if orb == "s" or len(orb) > 1: 1842 raise ValueError("The invalid orbital '%s' was put into sum_orbs['%s']." % (orb, elt)) 1843 sum_morbs[elt] = individual_orbs[dictio[elt][0]] 1844 dictio[elt] = individual_orbs[dictio[elt][0]] 1845 else: 1846 duplicate = copy.deepcopy(dictio[elt]) 1847 for orb in dictio[elt]: 1848 if orb in individual_orbs.keys(): 1849 duplicate.remove(orb) 1850 for o in individual_orbs[orb]: 1851 duplicate.append(o) 1852 dictio[elt] = copy.deepcopy(duplicate) 1853 1854 if sum_morbs is None: 1855 pass 1856 elif elt not in sum_morbs.keys(): 1857 print("You do not want to sum projection over orbitals of element: %s" % elt) 1858 else: 1859 if len(sum_morbs[elt]) == 0: 1860 raise ValueError("The empty list is an invalid value for sum_morbs[%s]." % elt) 1861 if len(sum_morbs[elt]) == 1: 1862 orb = sum_morbs[elt][0] 1863 if orb == "s": 1864 raise ValueError( 1865 "We do not sum projection over only 's' orbital of the same " "type of element." 1866 ) 1867 if orb in individual_orbs.keys(): 1868 sum_morbs[elt].pop(0) 1869 for o in individual_orbs[orb]: 1870 sum_morbs[elt].append(o) 1871 else: 1872 raise ValueError("You never sum projection over one orbital in sum_morbs[%s]" % elt) 1873 else: 1874 duplicate = copy.deepcopy(sum_morbs[elt]) 1875 for orb in sum_morbs[elt]: 1876 if orb in individual_orbs.keys(): 1877 duplicate.remove(orb) 1878 for o in individual_orbs[orb]: 1879 duplicate.append(o) 1880 sum_morbs[elt] = copy.deepcopy(duplicate) 1881 1882 for orb in sum_morbs[elt]: 1883 if orb not in dictio[elt]: 1884 raise ValueError( 1885 "The orbitals of sum_morbs[%s] conflict with those of dictio[%s]." % (elt, elt) 1886 ) 1887 1888 return dictio, sum_morbs 1889 1890 def _number_of_subfigures(self, dictio, dictpa, sum_atoms, sum_morbs): 1891 from collections import Counter 1892 1893 from pymatgen.core.periodic_table import Element 1894 1895 if not isinstance(dictpa, dict): 1896 raise TypeError("The invalid type of 'dictpa' was bound. It should be dict type.") 1897 if len(dictpa.keys()) == 0: 1898 raise KeyError("The 'dictpa' is empty. We cannot do anything.") 1899 for elt in dictpa: 1900 if Element.is_valid_symbol(elt): 1901 if isinstance(dictpa[elt], list): 1902 if len(dictpa[elt]) == 0: 1903 raise ValueError("The dictpa[%s] is empty. We cannot do anything" % elt) 1904 _sites = self._bs.structure.sites 1905 indices = [] 1906 for i in range(0, len(_sites)): # pylint: disable=C0200 1907 if list(_sites[i]._species.keys())[0].__eq__(Element(elt)): 1908 indices.append(i + 1) 1909 for number in dictpa[elt]: 1910 if isinstance(number, str): 1911 if number.lower() == "all": 1912 dictpa[elt] = indices 1913 print("You want to consider all '%s' atoms." % elt) 1914 break 1915 1916 raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number))) 1917 if isinstance(number, int): 1918 if number not in indices: 1919 raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number))) 1920 else: 1921 raise ValueError("You put wrong site numbers in 'dictpa[%s]': %s." % (elt, str(number))) 1922 nelems = Counter(dictpa[elt]).values() 1923 if sum(nelems) > len(nelems): 1924 raise ValueError("You put at least two similar site numbers into 'dictpa[%s]'." % elt) 1925 else: 1926 raise TypeError( 1927 "The invalid type of value was put into 'dictpa[%s]'. It should be list " "type." % elt 1928 ) 1929 else: 1930 raise KeyError("The invalid element was put into 'dictpa' as a key: %s" % elt) 1931 1932 if len(list(dictio.keys())) != len(list(dictpa.keys())): 1933 raise KeyError("The number of keys in 'dictio' and 'dictpa' are not the same.") 1934 for elt in dictio.keys(): 1935 if elt not in dictpa.keys(): 1936 raise KeyError("The element '%s' is not in both dictpa and dictio." % elt) 1937 for elt in dictpa.keys(): 1938 if elt not in dictio.keys(): 1939 raise KeyError("The element '%s' in not in both dictpa and dictio." % elt) 1940 1941 if sum_atoms is None: 1942 print("You do not want to sum projection over atoms.") 1943 elif not isinstance(sum_atoms, dict): 1944 raise TypeError("The invalid type of 'sum_atoms' was bound. It should be dict type.") 1945 elif len(sum_atoms.keys()) == 0: 1946 raise KeyError("The 'sum_atoms' is empty. We cannot do anything.") 1947 else: 1948 for elt in sum_atoms: 1949 if Element.is_valid_symbol(elt): 1950 if isinstance(sum_atoms[elt], list): 1951 if len(sum_atoms[elt]) == 0: 1952 raise ValueError("The sum_atoms[%s] is empty. We cannot do anything" % elt) 1953 _sites = self._bs.structure.sites 1954 indices = [] 1955 for i in range(0, len(_sites)): # pylint: disable=C0200 1956 if list(_sites[i]._species.keys())[0].__eq__(Element(elt)): 1957 indices.append(i + 1) 1958 for number in sum_atoms[elt]: 1959 if isinstance(number, str): 1960 if number.lower() == "all": 1961 sum_atoms[elt] = indices 1962 print("You want to sum projection over all '%s' atoms." % elt) 1963 break 1964 raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt) 1965 if isinstance(number, int): 1966 if number not in indices: 1967 raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt) 1968 if number not in dictpa[elt]: 1969 raise ValueError( 1970 "You cannot sum projection with atom number '%s' because it is not " 1971 "metioned in dicpta[%s]" % (str(number), elt) 1972 ) 1973 else: 1974 raise ValueError("You put wrong site numbers in 'sum_atoms[%s]'." % elt) 1975 nelems = Counter(sum_atoms[elt]).values() 1976 if sum(nelems) > len(nelems): 1977 raise ValueError("You put at least two similar site numbers into 'sum_atoms[%s]'." % elt) 1978 else: 1979 raise TypeError( 1980 "The invalid type of value was put into 'sum_atoms[%s]'. It should be list " "type." % elt 1981 ) 1982 if elt not in dictpa.keys(): 1983 raise ValueError( 1984 "You cannot sum projection over atoms '%s' because it is not " 1985 "mentioned in 'dictio'." % elt 1986 ) 1987 else: 1988 raise KeyError("The invalid element was put into 'sum_atoms' as a key: %s" % elt) 1989 if len(sum_atoms[elt]) == 1: 1990 raise ValueError("We do not sum projection over only one atom: %s" % elt) 1991 1992 max_number_figs = 0 1993 decrease = 0 1994 for elt in dictio: 1995 max_number_figs += len(dictio[elt]) * len(dictpa[elt]) 1996 1997 if (sum_atoms is None) and (sum_morbs is None): 1998 number_figs = max_number_figs 1999 elif (sum_atoms is not None) and (sum_morbs is None): 2000 for elt in sum_atoms: 2001 decrease += (len(sum_atoms[elt]) - 1) * len(dictio[elt]) 2002 number_figs = max_number_figs - decrease 2003 elif (sum_atoms is None) and (sum_morbs is not None): 2004 for elt in sum_morbs: 2005 decrease += (len(sum_morbs[elt]) - 1) * len(dictpa[elt]) 2006 number_figs = max_number_figs - decrease 2007 elif (sum_atoms is not None) and (sum_morbs is not None): 2008 for elt in sum_atoms: 2009 decrease += (len(sum_atoms[elt]) - 1) * len(dictio[elt]) 2010 for elt in sum_morbs: 2011 if elt in sum_atoms: 2012 decrease += (len(sum_morbs[elt]) - 1) * (len(dictpa[elt]) - len(sum_atoms[elt]) + 1) 2013 else: 2014 decrease += (len(sum_morbs[elt]) - 1) * len(dictpa[elt]) 2015 number_figs = max_number_figs - decrease 2016 else: 2017 raise ValueError("Invalid format of 'sum_atoms' and 'sum_morbs'.") 2018 2019 return dictpa, sum_atoms, number_figs 2020 2021 def _summarize_keys_for_plot(self, dictio, dictpa, sum_atoms, sum_morbs): 2022 from pymatgen.core.periodic_table import Element 2023 2024 individual_orbs = { 2025 "p": ["px", "py", "pz"], 2026 "d": ["dxy", "dyz", "dxz", "dx2", "dz2"], 2027 "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"], 2028 } 2029 2030 def number_label(list_numbers): 2031 list_numbers = sorted(list_numbers) 2032 divide = [[]] 2033 divide[0].append(list_numbers[0]) 2034 group = 0 2035 for i in range(1, len(list_numbers)): 2036 if list_numbers[i] == list_numbers[i - 1] + 1: 2037 divide[group].append(list_numbers[i]) 2038 else: 2039 group += 1 2040 divide.append([list_numbers[i]]) 2041 label = "" 2042 for elem in divide: 2043 if len(elem) > 1: 2044 label += str(elem[0]) + "-" + str(elem[-1]) + "," 2045 else: 2046 label += str(elem[0]) + "," 2047 return label[:-1] 2048 2049 def orbital_label(list_orbitals): 2050 divide = {} 2051 for orb in list_orbitals: 2052 if orb[0] in divide: 2053 divide[orb[0]].append(orb) 2054 else: 2055 divide[orb[0]] = [] 2056 divide[orb[0]].append(orb) 2057 label = "" 2058 for elem, v in divide.items(): 2059 if elem == "s": 2060 label += "s" + "," 2061 else: 2062 if len(v) == len(individual_orbs[elem]): 2063 label += elem + "," 2064 else: 2065 l = [o[1:] for o in v] 2066 label += elem + str(l).replace("['", "").replace("']", "").replace("', '", "-") + "," 2067 return label[:-1] 2068 2069 if (sum_atoms is None) and (sum_morbs is None): 2070 dictio_d = dictio 2071 dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa} 2072 2073 elif (sum_atoms is not None) and (sum_morbs is None): 2074 dictio_d = dictio 2075 dictpa_d = {} 2076 for elt in dictpa: 2077 dictpa_d[elt] = [] 2078 if elt in sum_atoms: 2079 _sites = self._bs.structure.sites 2080 indices = [] 2081 for i in range(0, len(_sites)): # pylint: disable=C0200 2082 if list(_sites[i]._species.keys())[0].__eq__(Element(elt)): 2083 indices.append(i + 1) 2084 flag_1 = len(set(dictpa[elt]).intersection(indices)) 2085 flag_2 = len(set(sum_atoms[elt]).intersection(indices)) 2086 if flag_1 == len(indices) and flag_2 == len(indices): 2087 dictpa_d[elt].append("all") 2088 else: 2089 for anum in dictpa[elt]: 2090 if anum not in sum_atoms[elt]: 2091 dictpa_d[elt].append(str(anum)) 2092 label = number_label(sum_atoms[elt]) 2093 dictpa_d[elt].append(label) 2094 else: 2095 for anum in dictpa[elt]: 2096 dictpa_d[elt].append(str(anum)) 2097 2098 elif (sum_atoms is None) and (sum_morbs is not None): 2099 dictio_d = {} 2100 for elt in dictio: 2101 dictio_d[elt] = [] 2102 if elt in sum_morbs: 2103 for morb in dictio[elt]: 2104 if morb not in sum_morbs[elt]: 2105 dictio_d[elt].append(morb) 2106 label = orbital_label(sum_morbs[elt]) 2107 dictio_d[elt].append(label) 2108 else: 2109 dictio_d[elt] = dictio[elt] 2110 dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa} 2111 2112 else: 2113 dictio_d = {} 2114 for elt in dictio: 2115 dictio_d[elt] = [] 2116 if elt in sum_morbs: 2117 for morb in dictio[elt]: 2118 if morb not in sum_morbs[elt]: 2119 dictio_d[elt].append(morb) 2120 label = orbital_label(sum_morbs[elt]) 2121 dictio_d[elt].append(label) 2122 else: 2123 dictio_d[elt] = dictio[elt] 2124 dictpa_d = {} 2125 for elt in dictpa: 2126 dictpa_d[elt] = [] 2127 if elt in sum_atoms: 2128 _sites = self._bs.structure.sites 2129 indices = [] 2130 for i in range(0, len(_sites)): # pylint: disable=C0200 2131 if list(_sites[i]._species.keys())[0].__eq__(Element(elt)): 2132 indices.append(i + 1) 2133 flag_1 = len(set(dictpa[elt]).intersection(indices)) 2134 flag_2 = len(set(sum_atoms[elt]).intersection(indices)) 2135 if flag_1 == len(indices) and flag_2 == len(indices): 2136 dictpa_d[elt].append("all") 2137 else: 2138 for anum in dictpa[elt]: 2139 if anum not in sum_atoms[elt]: 2140 dictpa_d[elt].append(str(anum)) 2141 label = number_label(sum_atoms[elt]) 2142 dictpa_d[elt].append(label) 2143 else: 2144 for anum in dictpa[elt]: 2145 dictpa_d[elt].append(str(anum)) 2146 2147 return dictio_d, dictpa_d 2148 2149 def _maketicks_selected(self, plt, branches): 2150 """ 2151 utility private method to add ticks to a band structure with selected branches 2152 """ 2153 ticks = self.get_ticks() 2154 distance = [] 2155 label = [] 2156 rm_elems = [] 2157 for i in range(1, len(ticks["distance"])): 2158 if ticks["label"][i] == ticks["label"][i - 1]: 2159 rm_elems.append(i) 2160 for i in range(len(ticks["distance"])): 2161 if i not in rm_elems: 2162 distance.append(ticks["distance"][i]) 2163 label.append(ticks["label"][i]) 2164 l_branches = [distance[i] - distance[i - 1] for i in range(1, len(distance))] 2165 n_distance = [] 2166 n_label = [] 2167 for branch in branches: 2168 n_distance.append(l_branches[branch]) 2169 if ("$\\mid$" not in label[branch]) and ("$\\mid$" not in label[branch + 1]): 2170 n_label.append([label[branch], label[branch + 1]]) 2171 elif ("$\\mid$" in label[branch]) and ("$\\mid$" not in label[branch + 1]): 2172 n_label.append([label[branch].split("$")[-1], label[branch + 1]]) 2173 elif ("$\\mid$" not in label[branch]) and ("$\\mid$" in label[branch + 1]): 2174 n_label.append([label[branch], label[branch + 1].split("$")[0]]) 2175 else: 2176 n_label.append([label[branch].split("$")[-1], label[branch + 1].split("$")[0]]) 2177 2178 f_distance = [] 2179 rf_distance = [] 2180 f_label = [] 2181 f_label.append(n_label[0][0]) 2182 f_label.append(n_label[0][1]) 2183 f_distance.append(0.0) 2184 f_distance.append(n_distance[0]) 2185 rf_distance.append(0.0) 2186 rf_distance.append(n_distance[0]) 2187 length = n_distance[0] 2188 for i in range(1, len(n_distance)): 2189 if n_label[i][0] == n_label[i - 1][1]: 2190 f_distance.append(length) 2191 f_distance.append(length + n_distance[i]) 2192 f_label.append(n_label[i][0]) 2193 f_label.append(n_label[i][1]) 2194 else: 2195 f_distance.append(length + n_distance[i]) 2196 f_label[-1] = n_label[i - 1][1] + "$\\mid$" + n_label[i][0] 2197 f_label.append(n_label[i][1]) 2198 rf_distance.append(length + n_distance[i]) 2199 length += n_distance[i] 2200 2201 n_ticks = {"distance": f_distance, "label": f_label} 2202 uniq_d = [] 2203 uniq_l = [] 2204 temp_ticks = list(zip(n_ticks["distance"], n_ticks["label"])) 2205 for i, t in enumerate(temp_ticks): 2206 if i == 0: 2207 uniq_d.append(t[0]) 2208 uniq_l.append(t[1]) 2209 logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1])) 2210 else: 2211 if t[1] == temp_ticks[i - 1][1]: 2212 logger.debug("Skipping label {i}".format(i=t[1])) 2213 else: 2214 logger.debug("Adding label {l} at {d}".format(l=t[0], d=t[1])) 2215 uniq_d.append(t[0]) 2216 uniq_l.append(t[1]) 2217 2218 logger.debug("Unique labels are %s" % list(zip(uniq_d, uniq_l))) 2219 plt.gca().set_xticks(uniq_d) 2220 plt.gca().set_xticklabels(uniq_l) 2221 2222 for i in range(len(n_ticks["label"])): 2223 if n_ticks["label"][i] is not None: 2224 # don't print the same label twice 2225 if i != 0: 2226 if n_ticks["label"][i] == n_ticks["label"][i - 1]: 2227 logger.debug("already print label... " "skipping label {i}".format(i=n_ticks["label"][i])) 2228 else: 2229 logger.debug( 2230 "Adding a line at {d}" 2231 " for label {l}".format(d=n_ticks["distance"][i], l=n_ticks["label"][i]) 2232 ) 2233 plt.axvline(n_ticks["distance"][i], color="k") 2234 else: 2235 logger.debug( 2236 "Adding a line at {d} for label {l}".format(d=n_ticks["distance"][i], l=n_ticks["label"][i]) 2237 ) 2238 plt.axvline(n_ticks["distance"][i], color="k") 2239 2240 shift = [] 2241 br = -1 2242 for branch in branches: 2243 br += 1 2244 shift.append(distance[branch] - rf_distance[br]) 2245 2246 return plt, shift 2247 2248 2249class BSDOSPlotter: 2250 """ 2251 A joint, aligned band structure and density of states plot. Contributions 2252 from Jan Pohls as well as the online example from Germain Salvato-Vallverdu: 2253 http://gvallver.perso.univ-pau.fr/?p=587 2254 """ 2255 2256 def __init__( 2257 self, 2258 bs_projection="elements", 2259 dos_projection="elements", 2260 vb_energy_range=4, 2261 cb_energy_range=4, 2262 fixed_cb_energy=False, 2263 egrid_interval=1, 2264 font="Times New Roman", 2265 axis_fontsize=20, 2266 tick_fontsize=15, 2267 legend_fontsize=14, 2268 bs_legend="best", 2269 dos_legend="best", 2270 rgb_legend=True, 2271 fig_size=(11, 8.5), 2272 ): 2273 """ 2274 Instantiate plotter settings. 2275 2276 Args: 2277 bs_projection (str): "elements" or None 2278 dos_projection (str): "elements", "orbitals", or None 2279 vb_energy_range (float): energy in eV to show of valence bands 2280 cb_energy_range (float): energy in eV to show of conduction bands 2281 fixed_cb_energy (bool): If true, the cb_energy_range will be interpreted 2282 as constant (i.e., no gap correction for cb energy) 2283 egrid_interval (float): interval for grid marks 2284 font (str): font family 2285 axis_fontsize (float): font size for axis 2286 tick_fontsize (float): font size for axis tick labels 2287 legend_fontsize (float): font size for legends 2288 bs_legend (str): matplotlib string location for legend or None 2289 dos_legend (str): matplotlib string location for legend or None 2290 rgb_legend (bool): (T/F) whether to draw RGB triangle/bar for element proj. 2291 fig_size(tuple): dimensions of figure size (width, height) 2292 """ 2293 self.bs_projection = bs_projection 2294 self.dos_projection = dos_projection 2295 self.vb_energy_range = vb_energy_range 2296 self.cb_energy_range = cb_energy_range 2297 self.fixed_cb_energy = fixed_cb_energy 2298 self.egrid_interval = egrid_interval 2299 self.font = font 2300 self.axis_fontsize = axis_fontsize 2301 self.tick_fontsize = tick_fontsize 2302 self.legend_fontsize = legend_fontsize 2303 self.bs_legend = bs_legend 2304 self.dos_legend = dos_legend 2305 self.rgb_legend = rgb_legend 2306 self.fig_size = fig_size 2307 2308 def get_plot(self, bs, dos=None): 2309 """ 2310 Get a matplotlib plot object. 2311 Args: 2312 bs (BandStructureSymmLine): the bandstructure to plot. Projection 2313 data must exist for projected plots. 2314 dos (Dos): the Dos to plot. Projection data must exist (i.e., 2315 CompleteDos) for projected plots. 2316 2317 Returns: 2318 matplotlib.pyplot object on which you can call commands like show() 2319 and savefig() 2320 """ 2321 import matplotlib.lines as mlines 2322 import matplotlib.pyplot as mplt 2323 from matplotlib.gridspec import GridSpec 2324 2325 # make sure the user-specified band structure projection is valid 2326 bs_projection = self.bs_projection 2327 if dos: 2328 elements = [e.symbol for e in dos.structure.composition.elements] 2329 elif bs_projection and bs.structure: 2330 elements = [e.symbol for e in bs.structure.composition.elements] 2331 else: 2332 elements = [] 2333 2334 rgb_legend = ( 2335 self.rgb_legend and bs_projection and bs_projection.lower() == "elements" and len(elements) in [2, 3] 2336 ) 2337 2338 if ( 2339 bs_projection 2340 and bs_projection.lower() == "elements" 2341 and (len(elements) not in [2, 3] or not bs.get_projection_on_elements()) 2342 ): 2343 warnings.warn( 2344 "Cannot get element projected data; either the projection data " 2345 "doesn't exist, or you don't have a compound with exactly 2 " 2346 "or 3 unique elements." 2347 ) 2348 bs_projection = None 2349 2350 # specify energy range of plot 2351 emin = -self.vb_energy_range 2352 emax = self.cb_energy_range if self.fixed_cb_energy else self.cb_energy_range + bs.get_band_gap()["energy"] 2353 2354 # initialize all the k-point labels and k-point x-distances for bs plot 2355 xlabels = [] # all symmetry point labels on x-axis 2356 xlabel_distances = [] # positions of symmetry point x-labels 2357 2358 x_distances_list = [] 2359 prev_right_klabel = None # used to determine which branches require a midline separator 2360 2361 for idx, l in enumerate(bs.branches): 2362 x_distances = [] 2363 2364 # get left and right kpoint labels of this branch 2365 left_k, right_k = l["name"].split("-") 2366 2367 # add $ notation for LaTeX kpoint labels 2368 if left_k[0] == "\\" or "_" in left_k: 2369 left_k = "$" + left_k + "$" 2370 if right_k[0] == "\\" or "_" in right_k: 2371 right_k = "$" + right_k + "$" 2372 2373 # add left k label to list of labels 2374 if prev_right_klabel is None: 2375 xlabels.append(left_k) 2376 xlabel_distances.append(0) 2377 elif prev_right_klabel != left_k: # used for pipe separator 2378 xlabels[-1] = xlabels[-1] + "$\\mid$ " + left_k 2379 2380 # add right k label to list of labels 2381 xlabels.append(right_k) 2382 prev_right_klabel = right_k 2383 2384 # add x-coordinates for labels 2385 left_kpoint = bs.kpoints[l["start_index"]].cart_coords 2386 right_kpoint = bs.kpoints[l["end_index"]].cart_coords 2387 distance = np.linalg.norm(right_kpoint - left_kpoint) 2388 xlabel_distances.append(xlabel_distances[-1] + distance) 2389 2390 # add x-coordinates for kpoint data 2391 npts = l["end_index"] - l["start_index"] 2392 distance_interval = distance / npts 2393 x_distances.append(xlabel_distances[-2]) 2394 for i in range(npts): 2395 x_distances.append(x_distances[-1] + distance_interval) 2396 x_distances_list.append(x_distances) 2397 2398 # set up bs and dos plot 2399 gs = GridSpec(1, 2, width_ratios=[2, 1]) if dos else GridSpec(1, 1) 2400 2401 fig = mplt.figure(figsize=self.fig_size) 2402 fig.patch.set_facecolor("white") 2403 bs_ax = mplt.subplot(gs[0]) 2404 if dos: 2405 dos_ax = mplt.subplot(gs[1]) 2406 2407 # set basic axes limits for the plot 2408 bs_ax.set_xlim(0, x_distances_list[-1][-1]) 2409 bs_ax.set_ylim(emin, emax) 2410 if dos: 2411 dos_ax.set_ylim(emin, emax) 2412 2413 # add BS xticks, labels, etc. 2414 bs_ax.set_xticks(xlabel_distances) 2415 bs_ax.set_xticklabels(xlabels, size=self.tick_fontsize) 2416 bs_ax.set_xlabel("Wavevector $k$", fontsize=self.axis_fontsize, family=self.font) 2417 bs_ax.set_ylabel("$E-E_F$ / eV", fontsize=self.axis_fontsize, family=self.font) 2418 2419 # add BS fermi level line at E=0 and gridlines 2420 bs_ax.hlines(y=0, xmin=0, xmax=x_distances_list[-1][-1], color="k", lw=2) 2421 bs_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval)) 2422 bs_ax.set_yticklabels(np.arange(emin, emax + 1e-5, self.egrid_interval), size=self.tick_fontsize) 2423 bs_ax.set_axisbelow(True) 2424 bs_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1) 2425 if dos: 2426 dos_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval)) 2427 dos_ax.set_yticklabels([]) 2428 dos_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1) 2429 2430 # renormalize the band energy to the Fermi level 2431 band_energies = {} 2432 for spin in (Spin.up, Spin.down): 2433 if spin in bs.bands: 2434 band_energies[spin] = [] 2435 for band in bs.bands[spin]: 2436 band_energies[spin].append([e - bs.efermi for e in band]) 2437 2438 # renormalize the DOS energies to Fermi level 2439 if dos: 2440 dos_energies = [e - dos.efermi for e in dos.energies] 2441 2442 # get the projection data to set colors for the band structure 2443 colordata = self._get_colordata(bs, elements, bs_projection) 2444 2445 # plot the colored band structure lines 2446 for spin in (Spin.up, Spin.down): 2447 if spin in band_energies: 2448 linestyles = "solid" if spin == Spin.up else "dotted" 2449 for band_idx, band in enumerate(band_energies[spin]): 2450 current_pos = 0 2451 for x_distances in x_distances_list: 2452 sub_band = band[current_pos : current_pos + len(x_distances)] 2453 2454 self._rgbline( 2455 bs_ax, 2456 x_distances, 2457 sub_band, 2458 colordata[spin][band_idx, :, 0][current_pos : current_pos + len(x_distances)], 2459 colordata[spin][band_idx, :, 1][current_pos : current_pos + len(x_distances)], 2460 colordata[spin][band_idx, :, 2][current_pos : current_pos + len(x_distances)], 2461 linestyles=linestyles, 2462 ) 2463 2464 current_pos += len(x_distances) 2465 2466 if dos: 2467 # Plot the DOS and projected DOS 2468 for spin in (Spin.up, Spin.down): 2469 if spin in dos.densities: 2470 # plot the total DOS 2471 dos_densities = dos.densities[spin] * int(spin) 2472 label = "total" if spin == Spin.up else None 2473 dos_ax.plot(dos_densities, dos_energies, color=(0.6, 0.6, 0.6), label=label) 2474 dos_ax.fill_betweenx( 2475 dos_energies, 2476 0, 2477 dos_densities, 2478 color=(0.7, 0.7, 0.7), 2479 facecolor=(0.7, 0.7, 0.7), 2480 ) 2481 2482 if self.dos_projection is None: 2483 pass 2484 2485 elif self.dos_projection.lower() == "elements": 2486 # plot the atom-projected DOS 2487 colors = ["b", "r", "g", "m", "y", "c", "k", "w"] 2488 el_dos = dos.get_element_dos() 2489 for idx, el in enumerate(elements): 2490 dos_densities = el_dos[Element(el)].densities[spin] * int(spin) 2491 label = el if spin == Spin.up else None 2492 dos_ax.plot( 2493 dos_densities, 2494 dos_energies, 2495 color=colors[idx], 2496 label=label, 2497 ) 2498 2499 elif self.dos_projection.lower() == "orbitals": 2500 # plot each of the atomic projected DOS 2501 colors = ["b", "r", "g", "m"] 2502 spd_dos = dos.get_spd_dos() 2503 for idx, orb in enumerate([OrbitalType.s, OrbitalType.p, OrbitalType.d, OrbitalType.f]): 2504 if orb in spd_dos: 2505 dos_densities = spd_dos[orb].densities[spin] * int(spin) 2506 label = orb if spin == Spin.up else None 2507 dos_ax.plot( 2508 dos_densities, 2509 dos_energies, 2510 color=colors[idx], 2511 label=label, 2512 ) 2513 2514 # get index of lowest and highest energy being plotted, used to help auto-scale DOS x-axis 2515 emin_idx = next(x[0] for x in enumerate(dos_energies) if x[1] >= emin) 2516 emax_idx = len(dos_energies) - next(x[0] for x in enumerate(reversed(dos_energies)) if x[1] <= emax) 2517 2518 # determine DOS x-axis range 2519 dos_xmin = ( 2520 0 if Spin.down not in dos.densities else -max(dos.densities[Spin.down][emin_idx : emax_idx + 1] * 1.05) 2521 ) 2522 dos_xmax = max([max(dos.densities[Spin.up][emin_idx:emax_idx]) * 1.05, abs(dos_xmin)]) 2523 2524 # set up the DOS x-axis and add Fermi level line 2525 dos_ax.set_xlim(dos_xmin, dos_xmax) 2526 dos_ax.set_xticklabels([]) 2527 dos_ax.hlines(y=0, xmin=dos_xmin, xmax=dos_xmax, color="k", lw=2) 2528 dos_ax.set_xlabel("DOS", fontsize=self.axis_fontsize, family=self.font) 2529 2530 # add legend for band structure 2531 if self.bs_legend and not rgb_legend: 2532 handles = [] 2533 2534 if bs_projection is None: 2535 handles = [ 2536 mlines.Line2D([], [], linewidth=2, color="k", label="spin up"), 2537 mlines.Line2D( 2538 [], 2539 [], 2540 linewidth=2, 2541 color="b", 2542 linestyle="dotted", 2543 label="spin down", 2544 ), 2545 ] 2546 2547 elif bs_projection.lower() == "elements": 2548 colors = ["b", "r", "g"] 2549 for idx, el in enumerate(elements): 2550 handles.append(mlines.Line2D([], [], linewidth=2, color=colors[idx], label=el)) 2551 2552 bs_ax.legend( 2553 handles=handles, 2554 fancybox=True, 2555 prop={"size": self.legend_fontsize, "family": self.font}, 2556 loc=self.bs_legend, 2557 ) 2558 2559 elif self.bs_legend and rgb_legend: 2560 if len(elements) == 2: 2561 self._rb_line(bs_ax, elements[1], elements[0], loc=self.bs_legend) 2562 elif len(elements) == 3: 2563 self._rgb_triangle(bs_ax, elements[1], elements[2], elements[0], loc=self.bs_legend) 2564 2565 # add legend for DOS 2566 if dos and self.dos_legend: 2567 dos_ax.legend( 2568 fancybox=True, 2569 prop={"size": self.legend_fontsize, "family": self.font}, 2570 loc=self.dos_legend, 2571 ) 2572 2573 mplt.subplots_adjust(wspace=0.1) 2574 return mplt 2575 2576 @staticmethod 2577 def _rgbline(ax, k, e, red, green, blue, alpha=1, linestyles="solid"): 2578 """ 2579 An RGB colored line for plotting. 2580 creation of segments based on: 2581 http://nbviewer.ipython.org/urls/raw.github.com/dpsanders/matplotlib-examples/master/colorline.ipynb 2582 Args: 2583 ax: matplotlib axis 2584 k: x-axis data (k-points) 2585 e: y-axis data (energies) 2586 red: red data 2587 green: green data 2588 blue: blue data 2589 alpha: alpha values data 2590 linestyles: linestyle for plot (e.g., "solid" or "dotted") 2591 """ 2592 from matplotlib.collections import LineCollection 2593 2594 pts = np.array([k, e]).T.reshape(-1, 1, 2) 2595 seg = np.concatenate([pts[:-1], pts[1:]], axis=1) 2596 2597 nseg = len(k) - 1 2598 r = [0.5 * (red[i] + red[i + 1]) for i in range(nseg)] 2599 g = [0.5 * (green[i] + green[i + 1]) for i in range(nseg)] 2600 b = [0.5 * (blue[i] + blue[i + 1]) for i in range(nseg)] 2601 a = np.ones(nseg, np.float_) * alpha 2602 lc = LineCollection(seg, colors=list(zip(r, g, b, a)), linewidth=2, linestyles=linestyles) 2603 ax.add_collection(lc) 2604 2605 @staticmethod 2606 def _get_colordata(bs, elements, bs_projection): 2607 """ 2608 Get color data, including projected band structures 2609 Args: 2610 bs: Bandstructure object 2611 elements: elements (in desired order) for setting to blue, red, green 2612 bs_projection: None for no projection, "elements" for element projection 2613 2614 Returns: 2615 2616 """ 2617 contribs = {} 2618 if bs_projection and bs_projection.lower() == "elements": 2619 projections = bs.get_projection_on_elements() 2620 2621 for spin in (Spin.up, Spin.down): 2622 if spin in bs.bands: 2623 contribs[spin] = [] 2624 for band_idx in range(bs.nb_bands): 2625 colors = [] 2626 for k_idx in range(len(bs.kpoints)): 2627 if bs_projection and bs_projection.lower() == "elements": 2628 c = [0, 0, 0] 2629 projs = projections[spin][band_idx][k_idx] 2630 # note: squared color interpolations are smoother 2631 # see: https://youtu.be/LKnqECcg6Gw 2632 projs = {k: v ** 2 for k, v in projs.items()} 2633 total = sum(projs.values()) 2634 if total > 0: 2635 for idx, e in enumerate(elements): 2636 c[idx] = math.sqrt(projs[e] / total) # min is to handle round errors 2637 2638 c = [c[1], c[2], c[0]] # prefer blue, then red, then green 2639 2640 else: 2641 c = [0, 0, 0] if spin == Spin.up else [0, 0, 1] # black for spin up, blue for spin down 2642 2643 colors.append(c) 2644 2645 contribs[spin].append(colors) 2646 contribs[spin] = np.array(contribs[spin]) 2647 2648 return contribs 2649 2650 @staticmethod 2651 def _rgb_triangle(ax, r_label, g_label, b_label, loc): 2652 """ 2653 Draw an RGB triangle legend on the desired axis 2654 """ 2655 if loc not in range(1, 11): 2656 loc = 2 2657 2658 from mpl_toolkits.axes_grid1.inset_locator import inset_axes 2659 2660 inset_ax = inset_axes(ax, width=1, height=1, loc=loc) 2661 mesh = 35 2662 x = [] 2663 y = [] 2664 color = [] 2665 for r in range(0, mesh): 2666 for g in range(0, mesh): 2667 for b in range(0, mesh): 2668 if not (r == 0 and b == 0 and g == 0): 2669 r1 = r / (r + g + b) 2670 g1 = g / (r + g + b) 2671 b1 = b / (r + g + b) 2672 x.append(0.33 * (2.0 * g1 + r1) / (r1 + b1 + g1)) 2673 y.append(0.33 * np.sqrt(3) * r1 / (r1 + b1 + g1)) 2674 rc = math.sqrt(r ** 2 / (r ** 2 + g ** 2 + b ** 2)) 2675 gc = math.sqrt(g ** 2 / (r ** 2 + g ** 2 + b ** 2)) 2676 bc = math.sqrt(b ** 2 / (r ** 2 + g ** 2 + b ** 2)) 2677 color.append([rc, gc, bc]) 2678 2679 # x = [n + 0.25 for n in x] # nudge x coordinates 2680 # y = [n + (max_y - 1) for n in y] # shift y coordinates to top 2681 # plot the triangle 2682 inset_ax.scatter(x, y, s=7, marker=".", edgecolor=color) # pylint: disable=E1101 2683 inset_ax.set_xlim([-0.35, 1.00]) # pylint: disable=E1101 2684 inset_ax.set_ylim([-0.35, 1.00]) # pylint: disable=E1101 2685 2686 # add the labels 2687 inset_ax.text( # pylint: disable=E1101 2688 0.70, 2689 -0.2, 2690 g_label, 2691 fontsize=13, 2692 family="Times New Roman", 2693 color=(0, 0, 0), 2694 horizontalalignment="left", 2695 ) 2696 inset_ax.text( # pylint: disable=E1101 2697 0.325, 2698 0.70, 2699 r_label, 2700 fontsize=13, 2701 family="Times New Roman", 2702 color=(0, 0, 0), 2703 horizontalalignment="center", 2704 ) 2705 inset_ax.text( # pylint: disable=E1101 2706 -0.05, 2707 -0.2, 2708 b_label, 2709 fontsize=13, 2710 family="Times New Roman", 2711 color=(0, 0, 0), 2712 horizontalalignment="right", 2713 ) 2714 2715 inset_ax.get_xaxis().set_visible(False) # pylint: disable=E1101 2716 inset_ax.get_yaxis().set_visible(False) # pylint: disable=E1101 2717 2718 @staticmethod 2719 def _rb_line(ax, r_label, b_label, loc): 2720 # Draw an rb bar legend on the desired axis 2721 2722 if loc not in range(1, 11): 2723 loc = 2 2724 from mpl_toolkits.axes_grid1.inset_locator import inset_axes 2725 2726 inset_ax = inset_axes(ax, width=1.2, height=0.4, loc=loc) 2727 2728 x = [] 2729 y = [] 2730 color = [] 2731 for i in range(0, 1000): 2732 x.append(i / 1800.0 + 0.55) 2733 y.append(0) 2734 color.append([math.sqrt(c) for c in [1 - (i / 1000) ** 2, 0, (i / 1000) ** 2]]) 2735 2736 # plot the bar 2737 # pylint: disable=E1101 2738 inset_ax.scatter(x, y, s=250.0, marker="s", c=color) 2739 inset_ax.set_xlim([-0.1, 1.7]) 2740 inset_ax.text( 2741 1.35, 2742 0, 2743 b_label, 2744 fontsize=13, 2745 family="Times New Roman", 2746 color=(0, 0, 0), 2747 horizontalalignment="left", 2748 verticalalignment="center", 2749 ) 2750 inset_ax.text( 2751 0.30, 2752 0, 2753 r_label, 2754 fontsize=13, 2755 family="Times New Roman", 2756 color=(0, 0, 0), 2757 horizontalalignment="right", 2758 verticalalignment="center", 2759 ) 2760 2761 inset_ax.get_xaxis().set_visible(False) 2762 inset_ax.get_yaxis().set_visible(False) 2763 2764 2765class BoltztrapPlotter: 2766 # TODO: We need a unittest for this. Come on folks. 2767 """ 2768 class containing methods to plot the data from Boltztrap. 2769 """ 2770 2771 def __init__(self, bz): 2772 """ 2773 Args: 2774 bz: a BoltztrapAnalyzer object 2775 """ 2776 self._bz = bz 2777 2778 def _plot_doping(self, plt, temp): 2779 if len(self._bz.doping) != 0: 2780 limit = 2.21e15 2781 plt.axvline(self._bz.mu_doping["n"][temp][0], linewidth=3.0, linestyle="--") 2782 plt.text( 2783 self._bz.mu_doping["n"][temp][0] + 0.01, 2784 limit, 2785 "$n$=10$^{" + str(math.log10(self._bz.doping["n"][0])) + "}$", 2786 color="b", 2787 ) 2788 plt.axvline(self._bz.mu_doping["n"][temp][-1], linewidth=3.0, linestyle="--") 2789 plt.text( 2790 self._bz.mu_doping["n"][temp][-1] + 0.01, 2791 limit, 2792 "$n$=10$^{" + str(math.log10(self._bz.doping["n"][-1])) + "}$", 2793 color="b", 2794 ) 2795 plt.axvline(self._bz.mu_doping["p"][temp][0], linewidth=3.0, linestyle="--") 2796 plt.text( 2797 self._bz.mu_doping["p"][temp][0] + 0.01, 2798 limit, 2799 "$p$=10$^{" + str(math.log10(self._bz.doping["p"][0])) + "}$", 2800 color="b", 2801 ) 2802 plt.axvline(self._bz.mu_doping["p"][temp][-1], linewidth=3.0, linestyle="--") 2803 plt.text( 2804 self._bz.mu_doping["p"][temp][-1] + 0.01, 2805 limit, 2806 "$p$=10$^{" + str(math.log10(self._bz.doping["p"][-1])) + "}$", 2807 color="b", 2808 ) 2809 2810 def _plot_bg_limits(self, plt): 2811 plt.axvline(0.0, color="k", linewidth=3.0) 2812 plt.axvline(self._bz.gap, color="k", linewidth=3.0) 2813 2814 def plot_seebeck_eff_mass_mu(self, temps=[300], output="average", Lambda=0.5): 2815 """ 2816 Plot respect to the chemical potential of the Seebeck effective mass 2817 calculated as explained in Ref. 2818 Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor 2819 from ab initio band structure calculations. 2820 npj Computational Materials 3, 8 (2017). 2821 2822 Args: 2823 output: 'average' returns the seebeck effective mass calculated 2824 using the average of the three diagonal components of the 2825 seebeck tensor. 'tensor' returns the seebeck effective mass 2826 respect to the three diagonal components of the seebeck tensor. 2827 temps: list of temperatures of calculated seebeck. 2828 Lambda: fitting parameter used to model the scattering (0.5 means 2829 constant relaxation time). 2830 Returns: 2831 a matplotlib object 2832 """ 2833 2834 plt = pretty_plot(9, 7) 2835 for T in temps: 2836 sbk_mass = self._bz.get_seebeck_eff_mass(output=output, temp=T, Lambda=0.5) 2837 # remove noise inside the gap 2838 start = self._bz.mu_doping["p"][T][0] 2839 stop = self._bz.mu_doping["n"][T][0] 2840 mu_steps_1 = [] 2841 mu_steps_2 = [] 2842 sbk_mass_1 = [] 2843 sbk_mass_2 = [] 2844 for i, mu in enumerate(self._bz.mu_steps): 2845 if mu <= start: 2846 mu_steps_1.append(mu) 2847 sbk_mass_1.append(sbk_mass[i]) 2848 elif mu >= stop: 2849 mu_steps_2.append(mu) 2850 sbk_mass_2.append(sbk_mass[i]) 2851 2852 plt.plot(mu_steps_1, sbk_mass_1, label=str(T) + "K", linewidth=3.0) 2853 plt.plot(mu_steps_2, sbk_mass_2, linewidth=3.0) 2854 if output == "average": 2855 plt.gca().get_lines()[1].set_c(plt.gca().get_lines()[0].get_c()) 2856 elif output == "tensor": 2857 plt.gca().get_lines()[3].set_c(plt.gca().get_lines()[0].get_c()) 2858 plt.gca().get_lines()[4].set_c(plt.gca().get_lines()[1].get_c()) 2859 plt.gca().get_lines()[5].set_c(plt.gca().get_lines()[2].get_c()) 2860 2861 plt.xlabel("E-E$_f$ (eV)", fontsize=30) 2862 plt.ylabel("Seebeck effective mass", fontsize=30) 2863 plt.xticks(fontsize=25) 2864 plt.yticks(fontsize=25) 2865 if output == "tensor": 2866 plt.legend( 2867 [str(i) + "_" + str(T) + "K" for T in temps for i in ("x", "y", "z")], 2868 fontsize=20, 2869 ) 2870 elif output == "average": 2871 plt.legend(fontsize=20) 2872 plt.tight_layout() 2873 return plt 2874 2875 def plot_complexity_factor_mu(self, temps=[300], output="average", Lambda=0.5): 2876 """ 2877 Plot respect to the chemical potential of the Fermi surface complexity 2878 factor calculated as explained in Ref. 2879 Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor 2880 from ab initio band structure calculations. 2881 npj Computational Materials 3, 8 (2017). 2882 2883 Args: 2884 output: 'average' returns the complexity factor calculated using the average 2885 of the three diagonal components of the seebeck and conductivity tensors. 2886 'tensor' returns the complexity factor respect to the three 2887 diagonal components of seebeck and conductivity tensors. 2888 temps: list of temperatures of calculated seebeck and conductivity. 2889 Lambda: fitting parameter used to model the scattering (0.5 means constant 2890 relaxation time). 2891 Returns: 2892 a matplotlib object 2893 """ 2894 plt = pretty_plot(9, 7) 2895 for T in temps: 2896 cmplx_fact = self._bz.get_complexity_factor(output=output, temp=T, Lambda=Lambda) 2897 start = self._bz.mu_doping["p"][T][0] 2898 stop = self._bz.mu_doping["n"][T][0] 2899 mu_steps_1 = [] 2900 mu_steps_2 = [] 2901 cmplx_fact_1 = [] 2902 cmplx_fact_2 = [] 2903 for i, mu in enumerate(self._bz.mu_steps): 2904 if mu <= start: 2905 mu_steps_1.append(mu) 2906 cmplx_fact_1.append(cmplx_fact[i]) 2907 elif mu >= stop: 2908 mu_steps_2.append(mu) 2909 cmplx_fact_2.append(cmplx_fact[i]) 2910 2911 plt.plot(mu_steps_1, cmplx_fact_1, label=str(T) + "K", linewidth=3.0) 2912 plt.plot(mu_steps_2, cmplx_fact_2, linewidth=3.0) 2913 if output == "average": 2914 plt.gca().get_lines()[1].set_c(plt.gca().get_lines()[0].get_c()) 2915 elif output == "tensor": 2916 plt.gca().get_lines()[3].set_c(plt.gca().get_lines()[0].get_c()) 2917 plt.gca().get_lines()[4].set_c(plt.gca().get_lines()[1].get_c()) 2918 plt.gca().get_lines()[5].set_c(plt.gca().get_lines()[2].get_c()) 2919 2920 plt.xlabel("E-E$_f$ (eV)", fontsize=30) 2921 plt.ylabel("Complexity Factor", fontsize=30) 2922 plt.xticks(fontsize=25) 2923 plt.yticks(fontsize=25) 2924 if output == "tensor": 2925 plt.legend( 2926 [str(i) + "_" + str(T) + "K" for T in temps for i in ("x", "y", "z")], 2927 fontsize=20, 2928 ) 2929 elif output == "average": 2930 plt.legend(fontsize=20) 2931 plt.tight_layout() 2932 return plt 2933 2934 def plot_seebeck_mu(self, temp=600, output="eig", xlim=None): 2935 """ 2936 Plot the seebeck coefficient in function of Fermi level 2937 2938 Args: 2939 temp: 2940 the temperature 2941 xlim: 2942 a list of min and max fermi energy by default (0, and band gap) 2943 Returns: 2944 a matplotlib object 2945 """ 2946 plt = pretty_plot(9, 7) 2947 seebeck = self._bz.get_seebeck(output=output, doping_levels=False)[temp] 2948 plt.plot(self._bz.mu_steps, seebeck, linewidth=3.0) 2949 2950 self._plot_bg_limits(plt) 2951 self._plot_doping(plt, temp) 2952 if output == "eig": 2953 plt.legend(["S$_1$", "S$_2$", "S$_3$"]) 2954 if xlim is None: 2955 plt.xlim(-0.5, self._bz.gap + 0.5) 2956 else: 2957 plt.xlim(xlim[0], xlim[1]) 2958 plt.ylabel("Seebeck \n coefficient ($\\mu$V/K)", fontsize=30.0) 2959 plt.xlabel("E-E$_f$ (eV)", fontsize=30) 2960 plt.xticks(fontsize=25) 2961 plt.yticks(fontsize=25) 2962 plt.tight_layout() 2963 return plt 2964 2965 def plot_conductivity_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None): 2966 """ 2967 Plot the conductivity in function of Fermi level. Semi-log plot 2968 2969 Args: 2970 temp: the temperature 2971 xlim: a list of min and max fermi energy by default (0, and band 2972 gap) 2973 tau: A relaxation time in s. By default none and the plot is by 2974 units of relaxation time 2975 2976 Returns: 2977 a matplotlib object 2978 """ 2979 cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp] 2980 plt = pretty_plot(9, 7) 2981 plt.semilogy(self._bz.mu_steps, cond, linewidth=3.0) 2982 self._plot_bg_limits(plt) 2983 self._plot_doping(plt, temp) 2984 if output == "eig": 2985 plt.legend(["$\\Sigma_1$", "$\\Sigma_2$", "$\\Sigma_3$"]) 2986 if xlim is None: 2987 plt.xlim(-0.5, self._bz.gap + 0.5) 2988 else: 2989 plt.xlim(xlim) 2990 plt.ylim([1e13 * relaxation_time, 1e20 * relaxation_time]) 2991 plt.ylabel("conductivity,\n $\\Sigma$ (1/($\\Omega$ m))", fontsize=30.0) 2992 plt.xlabel("E-E$_f$ (eV)", fontsize=30.0) 2993 plt.xticks(fontsize=25) 2994 plt.yticks(fontsize=25) 2995 plt.tight_layout() 2996 return plt 2997 2998 def plot_power_factor_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None): 2999 """ 3000 Plot the power factor in function of Fermi level. Semi-log plot 3001 3002 Args: 3003 temp: the temperature 3004 xlim: a list of min and max fermi energy by default (0, and band 3005 gap) 3006 tau: A relaxation time in s. By default none and the plot is by 3007 units of relaxation time 3008 3009 Returns: 3010 a matplotlib object 3011 """ 3012 plt = pretty_plot(9, 7) 3013 pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp] 3014 plt.semilogy(self._bz.mu_steps, pf, linewidth=3.0) 3015 self._plot_bg_limits(plt) 3016 self._plot_doping(plt, temp) 3017 if output == "eig": 3018 plt.legend(["PF$_1$", "PF$_2$", "PF$_3$"]) 3019 if xlim is None: 3020 plt.xlim(-0.5, self._bz.gap + 0.5) 3021 else: 3022 plt.xlim(xlim) 3023 plt.ylabel("Power factor, ($\\mu$W/(mK$^2$))", fontsize=30.0) 3024 plt.xlabel("E-E$_f$ (eV)", fontsize=30.0) 3025 plt.xticks(fontsize=25) 3026 plt.yticks(fontsize=25) 3027 plt.tight_layout() 3028 return plt 3029 3030 def plot_zt_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None): 3031 """ 3032 Plot the ZT in function of Fermi level. 3033 3034 Args: 3035 temp: the temperature 3036 xlim: a list of min and max fermi energy by default (0, and band 3037 gap) 3038 tau: A relaxation time in s. By default none and the plot is by 3039 units of relaxation time 3040 3041 Returns: 3042 a matplotlib object 3043 """ 3044 plt = pretty_plot(9, 7) 3045 zt = self._bz.get_zt(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp] 3046 plt.plot(self._bz.mu_steps, zt, linewidth=3.0) 3047 self._plot_bg_limits(plt) 3048 self._plot_doping(plt, temp) 3049 if output == "eig": 3050 plt.legend(["ZT$_1$", "ZT$_2$", "ZT$_3$"]) 3051 if xlim is None: 3052 plt.xlim(-0.5, self._bz.gap + 0.5) 3053 else: 3054 plt.xlim(xlim) 3055 plt.ylabel("ZT", fontsize=30.0) 3056 plt.xlabel("E-E$_f$ (eV)", fontsize=30.0) 3057 plt.xticks(fontsize=25) 3058 plt.yticks(fontsize=25) 3059 plt.tight_layout() 3060 return plt 3061 3062 def plot_seebeck_temp(self, doping="all", output="average"): 3063 """ 3064 Plot the Seebeck coefficient in function of temperature for different 3065 doping levels. 3066 3067 Args: 3068 dopings: the default 'all' plots all the doping levels in the analyzer. 3069 Specify a list of doping levels if you want to plot only some. 3070 output: with 'average' you get an average of the three directions 3071 with 'eigs' you get all the three directions. 3072 Returns: 3073 a matplotlib object 3074 """ 3075 3076 if output == "average": 3077 sbk = self._bz.get_seebeck(output="average") 3078 elif output == "eigs": 3079 sbk = self._bz.get_seebeck(output="eigs") 3080 3081 plt = pretty_plot(22, 14) 3082 tlist = sorted(sbk["n"].keys()) 3083 doping = self._bz.doping["n"] if doping == "all" else doping 3084 for i, dt in enumerate(["n", "p"]): 3085 plt.subplot(121 + i) 3086 for dop in doping: 3087 d = self._bz.doping[dt].index(dop) 3088 sbk_temp = [] 3089 for temp in tlist: 3090 sbk_temp.append(sbk[dt][temp][d]) 3091 if output == "average": 3092 plt.plot(tlist, sbk_temp, marker="s", label=str(dop) + " $cm^{-3}$") 3093 elif output == "eigs": 3094 for xyz in range(3): 3095 plt.plot( 3096 tlist, 3097 list(zip(*sbk_temp))[xyz], 3098 marker="s", 3099 label=str(xyz) + " " + str(dop) + " $cm^{-3}$", 3100 ) 3101 plt.title(dt + "-type", fontsize=20) 3102 if i == 0: 3103 plt.ylabel("Seebeck \n coefficient ($\\mu$V/K)", fontsize=30.0) 3104 plt.xlabel("Temperature (K)", fontsize=30.0) 3105 3106 p = "lower right" if i == 0 else "best" 3107 plt.legend(loc=p, fontsize=15) 3108 plt.grid() 3109 plt.xticks(fontsize=25) 3110 plt.yticks(fontsize=25) 3111 3112 plt.tight_layout() 3113 3114 return plt 3115 3116 def plot_conductivity_temp(self, doping="all", output="average", relaxation_time=1e-14): 3117 """ 3118 Plot the conductivity in function of temperature for different doping levels. 3119 3120 Args: 3121 dopings: the default 'all' plots all the doping levels in the analyzer. 3122 Specify a list of doping levels if you want to plot only some. 3123 output: with 'average' you get an average of the three directions 3124 with 'eigs' you get all the three directions. 3125 relaxation_time: specify a constant relaxation time value 3126 3127 Returns: 3128 a matplotlib object 3129 """ 3130 3131 if output == "average": 3132 cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="average") 3133 elif output == "eigs": 3134 cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="eigs") 3135 3136 plt = pretty_plot(22, 14) 3137 tlist = sorted(cond["n"].keys()) 3138 doping = self._bz.doping["n"] if doping == "all" else doping 3139 for i, dt in enumerate(["n", "p"]): 3140 plt.subplot(121 + i) 3141 for dop in doping: 3142 d = self._bz.doping[dt].index(dop) 3143 cond_temp = [] 3144 for temp in tlist: 3145 cond_temp.append(cond[dt][temp][d]) 3146 if output == "average": 3147 plt.plot(tlist, cond_temp, marker="s", label=str(dop) + " $cm^{-3}$") 3148 elif output == "eigs": 3149 for xyz in range(3): 3150 plt.plot( 3151 tlist, 3152 list(zip(*cond_temp))[xyz], 3153 marker="s", 3154 label=str(xyz) + " " + str(dop) + " $cm^{-3}$", 3155 ) 3156 plt.title(dt + "-type", fontsize=20) 3157 if i == 0: 3158 plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0) 3159 plt.xlabel("Temperature (K)", fontsize=30.0) 3160 3161 p = "best" # 'lower right' if i == 0 else '' 3162 plt.legend(loc=p, fontsize=15) 3163 plt.grid() 3164 plt.xticks(fontsize=25) 3165 plt.yticks(fontsize=25) 3166 plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) 3167 3168 plt.tight_layout() 3169 3170 return plt 3171 3172 def plot_power_factor_temp(self, doping="all", output="average", relaxation_time=1e-14): 3173 """ 3174 Plot the Power Factor in function of temperature for different doping levels. 3175 3176 Args: 3177 dopings: the default 'all' plots all the doping levels in the analyzer. 3178 Specify a list of doping levels if you want to plot only some. 3179 output: with 'average' you get an average of the three directions 3180 with 'eigs' you get all the three directions. 3181 relaxation_time: specify a constant relaxation time value 3182 3183 Returns: 3184 a matplotlib object 3185 """ 3186 3187 if output == "average": 3188 pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") 3189 elif output == "eigs": 3190 pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") 3191 3192 plt = pretty_plot(22, 14) 3193 tlist = sorted(pf["n"].keys()) 3194 doping = self._bz.doping["n"] if doping == "all" else doping 3195 for i, dt in enumerate(["n", "p"]): 3196 plt.subplot(121 + i) 3197 for dop in doping: 3198 d = self._bz.doping[dt].index(dop) 3199 pf_temp = [] 3200 for temp in tlist: 3201 pf_temp.append(pf[dt][temp][d]) 3202 if output == "average": 3203 plt.plot(tlist, pf_temp, marker="s", label=str(dop) + " $cm^{-3}$") 3204 elif output == "eigs": 3205 for xyz in range(3): 3206 plt.plot( 3207 tlist, 3208 list(zip(*pf_temp))[xyz], 3209 marker="s", 3210 label=str(xyz) + " " + str(dop) + " $cm^{-3}$", 3211 ) 3212 plt.title(dt + "-type", fontsize=20) 3213 if i == 0: 3214 plt.ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0) 3215 plt.xlabel("Temperature (K)", fontsize=30.0) 3216 3217 p = "best" # 'lower right' if i == 0 else '' 3218 plt.legend(loc=p, fontsize=15) 3219 plt.grid() 3220 plt.xticks(fontsize=25) 3221 plt.yticks(fontsize=25) 3222 plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) 3223 3224 plt.tight_layout() 3225 return plt 3226 3227 def plot_zt_temp(self, doping="all", output="average", relaxation_time=1e-14): 3228 """ 3229 Plot the figure of merit zT in function of temperature for different doping levels. 3230 3231 Args: 3232 dopings: the default 'all' plots all the doping levels in the analyzer. 3233 Specify a list of doping levels if you want to plot only some. 3234 output: with 'average' you get an average of the three directions 3235 with 'eigs' you get all the three directions. 3236 relaxation_time: specify a constant relaxation time value 3237 3238 Returns: 3239 a matplotlib object 3240 """ 3241 3242 if output == "average": 3243 zt = self._bz.get_zt(relaxation_time=relaxation_time, output="average") 3244 elif output == "eigs": 3245 zt = self._bz.get_zt(relaxation_time=relaxation_time, output="eigs") 3246 3247 plt = pretty_plot(22, 14) 3248 tlist = sorted(zt["n"].keys()) 3249 doping = self._bz.doping["n"] if doping == "all" else doping 3250 for i, dt in enumerate(["n", "p"]): 3251 plt.subplot(121 + i) 3252 for dop in doping: 3253 d = self._bz.doping[dt].index(dop) 3254 zt_temp = [] 3255 for temp in tlist: 3256 zt_temp.append(zt[dt][temp][d]) 3257 if output == "average": 3258 plt.plot(tlist, zt_temp, marker="s", label=str(dop) + " $cm^{-3}$") 3259 elif output == "eigs": 3260 for xyz in range(3): 3261 plt.plot( 3262 tlist, 3263 list(zip(*zt_temp))[xyz], 3264 marker="s", 3265 label=str(xyz) + " " + str(dop) + " $cm^{-3}$", 3266 ) 3267 plt.title(dt + "-type", fontsize=20) 3268 if i == 0: 3269 plt.ylabel("zT", fontsize=30.0) 3270 plt.xlabel("Temperature (K)", fontsize=30.0) 3271 3272 p = "best" # 'lower right' if i == 0 else '' 3273 plt.legend(loc=p, fontsize=15) 3274 plt.grid() 3275 plt.xticks(fontsize=25) 3276 plt.yticks(fontsize=25) 3277 3278 plt.tight_layout() 3279 return plt 3280 3281 def plot_eff_mass_temp(self, doping="all", output="average"): 3282 """ 3283 Plot the average effective mass in function of temperature 3284 for different doping levels. 3285 3286 Args: 3287 dopings: the default 'all' plots all the doping levels in the analyzer. 3288 Specify a list of doping levels if you want to plot only some. 3289 output: with 'average' you get an average of the three directions 3290 with 'eigs' you get all the three directions. 3291 3292 Returns: 3293 a matplotlib object 3294 """ 3295 3296 if output == "average": 3297 em = self._bz.get_average_eff_mass(output="average") 3298 elif output == "eigs": 3299 em = self._bz.get_average_eff_mass(output="eigs") 3300 3301 plt = pretty_plot(22, 14) 3302 tlist = sorted(em["n"].keys()) 3303 doping = self._bz.doping["n"] if doping == "all" else doping 3304 for i, dt in enumerate(["n", "p"]): 3305 plt.subplot(121 + i) 3306 for dop in doping: 3307 d = self._bz.doping[dt].index(dop) 3308 em_temp = [] 3309 for temp in tlist: 3310 em_temp.append(em[dt][temp][d]) 3311 if output == "average": 3312 plt.plot(tlist, em_temp, marker="s", label=str(dop) + " $cm^{-3}$") 3313 elif output == "eigs": 3314 for xyz in range(3): 3315 plt.plot( 3316 tlist, 3317 list(zip(*em_temp))[xyz], 3318 marker="s", 3319 label=str(xyz) + " " + str(dop) + " $cm^{-3}$", 3320 ) 3321 plt.title(dt + "-type", fontsize=20) 3322 if i == 0: 3323 plt.ylabel("Effective mass (m$_e$)", fontsize=30.0) 3324 plt.xlabel("Temperature (K)", fontsize=30.0) 3325 3326 p = "best" # 'lower right' if i == 0 else '' 3327 plt.legend(loc=p, fontsize=15) 3328 plt.grid() 3329 plt.xticks(fontsize=25) 3330 plt.yticks(fontsize=25) 3331 3332 plt.tight_layout() 3333 return plt 3334 3335 def plot_seebeck_dop(self, temps="all", output="average"): 3336 """ 3337 Plot the Seebeck in function of doping levels for different temperatures. 3338 3339 Args: 3340 temps: the default 'all' plots all the temperatures in the analyzer. 3341 Specify a list of temperatures if you want to plot only some. 3342 output: with 'average' you get an average of the three directions 3343 with 'eigs' you get all the three directions. 3344 3345 Returns: 3346 a matplotlib object 3347 """ 3348 3349 if output == "average": 3350 sbk = self._bz.get_seebeck(output="average") 3351 elif output == "eigs": 3352 sbk = self._bz.get_seebeck(output="eigs") 3353 3354 tlist = sorted(sbk["n"].keys()) if temps == "all" else temps 3355 plt = pretty_plot(22, 14) 3356 for i, dt in enumerate(["n", "p"]): 3357 plt.subplot(121 + i) 3358 for temp in tlist: 3359 if output == "eigs": 3360 for xyz in range(3): 3361 plt.semilogx( 3362 self._bz.doping[dt], 3363 list(zip(*sbk[dt][temp]))[xyz], 3364 marker="s", 3365 label=str(xyz) + " " + str(temp) + " K", 3366 ) 3367 elif output == "average": 3368 plt.semilogx( 3369 self._bz.doping[dt], 3370 sbk[dt][temp], 3371 marker="s", 3372 label=str(temp) + " K", 3373 ) 3374 plt.title(dt + "-type", fontsize=20) 3375 if i == 0: 3376 plt.ylabel("Seebeck coefficient ($\\mu$V/K)", fontsize=30.0) 3377 plt.xlabel("Doping concentration (cm$^{-3}$)", fontsize=30.0) 3378 3379 p = "lower right" if i == 0 else "best" 3380 plt.legend(loc=p, fontsize=15) 3381 plt.grid() 3382 plt.xticks(fontsize=25) 3383 plt.yticks(fontsize=25) 3384 3385 plt.tight_layout() 3386 3387 return plt 3388 3389 def plot_conductivity_dop(self, temps="all", output="average", relaxation_time=1e-14): 3390 """ 3391 Plot the conductivity in function of doping levels for different 3392 temperatures. 3393 3394 Args: 3395 temps: the default 'all' plots all the temperatures in the analyzer. 3396 Specify a list of temperatures if you want to plot only some. 3397 output: with 'average' you get an average of the three directions 3398 with 'eigs' you get all the three directions. 3399 relaxation_time: specify a constant relaxation time value 3400 3401 Returns: 3402 a matplotlib object 3403 """ 3404 if output == "average": 3405 cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="average") 3406 elif output == "eigs": 3407 cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="eigs") 3408 3409 tlist = sorted(cond["n"].keys()) if temps == "all" else temps 3410 plt = pretty_plot(22, 14) 3411 for i, dt in enumerate(["n", "p"]): 3412 plt.subplot(121 + i) 3413 for temp in tlist: 3414 if output == "eigs": 3415 for xyz in range(3): 3416 plt.semilogx( 3417 self._bz.doping[dt], 3418 list(zip(*cond[dt][temp]))[xyz], 3419 marker="s", 3420 label=str(xyz) + " " + str(temp) + " K", 3421 ) 3422 elif output == "average": 3423 plt.semilogx( 3424 self._bz.doping[dt], 3425 cond[dt][temp], 3426 marker="s", 3427 label=str(temp) + " K", 3428 ) 3429 plt.title(dt + "-type", fontsize=20) 3430 if i == 0: 3431 plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0) 3432 plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0) 3433 plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) 3434 plt.legend(fontsize=15) 3435 plt.grid() 3436 plt.xticks(fontsize=25) 3437 plt.yticks(fontsize=25) 3438 3439 plt.tight_layout() 3440 3441 return plt 3442 3443 def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1e-14): 3444 """ 3445 Plot the Power Factor in function of doping levels for different temperatures. 3446 3447 Args: 3448 temps: the default 'all' plots all the temperatures in the analyzer. 3449 Specify a list of temperatures if you want to plot only some. 3450 output: with 'average' you get an average of the three directions 3451 with 'eigs' you get all the three directions. 3452 relaxation_time: specify a constant relaxation time value 3453 3454 Returns: 3455 a matplotlib object 3456 """ 3457 if output == "average": 3458 pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") 3459 elif output == "eigs": 3460 pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") 3461 3462 tlist = sorted(pf["n"].keys()) if temps == "all" else temps 3463 plt = pretty_plot(22, 14) 3464 for i, dt in enumerate(["n", "p"]): 3465 plt.subplot(121 + i) 3466 for temp in tlist: 3467 if output == "eigs": 3468 for xyz in range(3): 3469 plt.semilogx( 3470 self._bz.doping[dt], 3471 list(zip(*pf[dt][temp]))[xyz], 3472 marker="s", 3473 label=str(xyz) + " " + str(temp) + " K", 3474 ) 3475 elif output == "average": 3476 plt.semilogx( 3477 self._bz.doping[dt], 3478 pf[dt][temp], 3479 marker="s", 3480 label=str(temp) + " K", 3481 ) 3482 plt.title(dt + "-type", fontsize=20) 3483 if i == 0: 3484 plt.ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0) 3485 plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0) 3486 plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) 3487 p = "best" # 'lower right' if i == 0 else '' 3488 plt.legend(loc=p, fontsize=15) 3489 plt.grid() 3490 plt.xticks(fontsize=25) 3491 plt.yticks(fontsize=25) 3492 3493 plt.tight_layout() 3494 3495 return plt 3496 3497 def plot_zt_dop(self, temps="all", output="average", relaxation_time=1e-14): 3498 """ 3499 Plot the figure of merit zT in function of doping levels for different 3500 temperatures. 3501 3502 Args: 3503 temps: the default 'all' plots all the temperatures in the analyzer. 3504 Specify a list of temperatures if you want to plot only some. 3505 output: with 'average' you get an average of the three directions 3506 with 'eigs' you get all the three directions. 3507 relaxation_time: specify a constant relaxation time value 3508 3509 Returns: 3510 a matplotlib object 3511 """ 3512 if output == "average": 3513 zt = self._bz.get_zt(relaxation_time=relaxation_time, output="average") 3514 elif output == "eigs": 3515 zt = self._bz.get_zt(relaxation_time=relaxation_time, output="eigs") 3516 3517 tlist = sorted(zt["n"].keys()) if temps == "all" else temps 3518 plt = pretty_plot(22, 14) 3519 for i, dt in enumerate(["n", "p"]): 3520 plt.subplot(121 + i) 3521 for temp in tlist: 3522 if output == "eigs": 3523 for xyz in range(3): 3524 plt.semilogx( 3525 self._bz.doping[dt], 3526 list(zip(*zt[dt][temp]))[xyz], 3527 marker="s", 3528 label=str(xyz) + " " + str(temp) + " K", 3529 ) 3530 elif output == "average": 3531 plt.semilogx( 3532 self._bz.doping[dt], 3533 zt[dt][temp], 3534 marker="s", 3535 label=str(temp) + " K", 3536 ) 3537 plt.title(dt + "-type", fontsize=20) 3538 if i == 0: 3539 plt.ylabel("zT", fontsize=30.0) 3540 plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0) 3541 3542 p = "lower right" if i == 0 else "best" 3543 plt.legend(loc=p, fontsize=15) 3544 plt.grid() 3545 plt.xticks(fontsize=25) 3546 plt.yticks(fontsize=25) 3547 3548 plt.tight_layout() 3549 3550 return plt 3551 3552 def plot_eff_mass_dop(self, temps="all", output="average"): 3553 """ 3554 Plot the average effective mass in function of doping levels 3555 for different temperatures. 3556 3557 Args: 3558 temps: the default 'all' plots all the temperatures in the analyzer. 3559 Specify a list of temperatures if you want to plot only some. 3560 output: with 'average' you get an average of the three directions 3561 with 'eigs' you get all the three directions. 3562 relaxation_time: specify a constant relaxation time value 3563 3564 Returns: 3565 a matplotlib object 3566 """ 3567 3568 if output == "average": 3569 em = self._bz.get_average_eff_mass(output="average") 3570 elif output == "eigs": 3571 em = self._bz.get_average_eff_mass(output="eigs") 3572 3573 tlist = sorted(em["n"].keys()) if temps == "all" else temps 3574 plt = pretty_plot(22, 14) 3575 for i, dt in enumerate(["n", "p"]): 3576 plt.subplot(121 + i) 3577 for temp in tlist: 3578 if output == "eigs": 3579 for xyz in range(3): 3580 plt.semilogx( 3581 self._bz.doping[dt], 3582 list(zip(*em[dt][temp]))[xyz], 3583 marker="s", 3584 label=str(xyz) + " " + str(temp) + " K", 3585 ) 3586 elif output == "average": 3587 plt.semilogx( 3588 self._bz.doping[dt], 3589 em[dt][temp], 3590 marker="s", 3591 label=str(temp) + " K", 3592 ) 3593 plt.title(dt + "-type", fontsize=20) 3594 if i == 0: 3595 plt.ylabel("Effective mass (m$_e$)", fontsize=30.0) 3596 plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0) 3597 3598 p = "lower right" if i == 0 else "best" 3599 plt.legend(loc=p, fontsize=15) 3600 plt.grid() 3601 plt.xticks(fontsize=25) 3602 plt.yticks(fontsize=25) 3603 3604 plt.tight_layout() 3605 3606 return plt 3607 3608 def plot_dos(self, sigma=0.05): 3609 """ 3610 plot dos 3611 3612 Args: 3613 sigma: a smearing 3614 3615 Returns: 3616 a matplotlib object 3617 """ 3618 plotter = DosPlotter(sigma=sigma) 3619 plotter.add_dos("t", self._bz.dos) 3620 return plotter.get_plot() 3621 3622 def plot_carriers(self, temp=300): 3623 """ 3624 Plot the carrier concentration in function of Fermi level 3625 3626 Args: 3627 temp: the temperature 3628 3629 Returns: 3630 a matplotlib object 3631 """ 3632 plt = pretty_plot(9, 7) 3633 carriers = [abs(c / (self._bz.vol * 1e-24)) for c in self._bz._carrier_conc[temp]] 3634 plt.semilogy(self._bz.mu_steps, carriers, linewidth=3.0, color="r") 3635 self._plot_bg_limits(plt) 3636 self._plot_doping(plt, temp) 3637 plt.xlim(-0.5, self._bz.gap + 0.5) 3638 plt.ylim(1e14, 1e22) 3639 plt.ylabel("carrier concentration (cm-3)", fontsize=30.0) 3640 plt.xlabel("E-E$_f$ (eV)", fontsize=30) 3641 plt.xticks(fontsize=25) 3642 plt.yticks(fontsize=25) 3643 plt.tight_layout() 3644 return plt 3645 3646 def plot_hall_carriers(self, temp=300): 3647 """ 3648 Plot the Hall carrier concentration in function of Fermi level 3649 3650 Args: 3651 temp: the temperature 3652 3653 Returns: 3654 a matplotlib object 3655 """ 3656 plt = pretty_plot(9, 7) 3657 hall_carriers = [abs(i) for i in self._bz.get_hall_carrier_concentration()[temp]] 3658 plt.semilogy(self._bz.mu_steps, hall_carriers, linewidth=3.0, color="r") 3659 self._plot_bg_limits(plt) 3660 self._plot_doping(plt, temp) 3661 plt.xlim(-0.5, self._bz.gap + 0.5) 3662 plt.ylim(1e14, 1e22) 3663 plt.ylabel("Hall carrier concentration (cm-3)", fontsize=30.0) 3664 plt.xlabel("E-E$_f$ (eV)", fontsize=30) 3665 plt.xticks(fontsize=25) 3666 plt.yticks(fontsize=25) 3667 plt.tight_layout() 3668 return plt 3669 3670 3671class CohpPlotter: 3672 """ 3673 Class for plotting crystal orbital Hamilton populations (COHPs) or 3674 crystal orbital overlap populations (COOPs). It is modeled after the 3675 DosPlotter object. 3676 """ 3677 3678 def __init__(self, zero_at_efermi=True, are_coops=False, are_cobis=False): 3679 """ 3680 Args: 3681 zero_at_efermi: Whether to shift all populations to have zero 3682 energy at the Fermi level. Defaults to True. 3683 are_coops: Switch to indicate that these are COOPs, not COHPs. 3684 Defaults to False for COHPs. 3685 are_cobis: Switch to indicate that these are COBIs, not COHPs/COOPs. 3686 Defaults to False for COHPs 3687 """ 3688 self.zero_at_efermi = zero_at_efermi 3689 self.are_coops = are_coops 3690 self.are_cobis = are_cobis 3691 self._cohps = OrderedDict() 3692 3693 def add_cohp(self, label, cohp): 3694 """ 3695 Adds a COHP for plotting. 3696 3697 Args: 3698 label: Label for the COHP. Must be unique. 3699 3700 cohp: COHP object. 3701 """ 3702 energies = cohp.energies - cohp.efermi if self.zero_at_efermi else cohp.energies 3703 populations = cohp.get_cohp() 3704 int_populations = cohp.get_icohp() 3705 self._cohps[label] = { 3706 "energies": energies, 3707 "COHP": populations, 3708 "ICOHP": int_populations, 3709 "efermi": cohp.efermi, 3710 } 3711 3712 def add_cohp_dict(self, cohp_dict, key_sort_func=None): 3713 """ 3714 Adds a dictionary of COHPs with an optional sorting function 3715 for the keys. 3716 3717 Args: 3718 cohp_dict: dict of the form {label: Cohp} 3719 3720 key_sort_func: function used to sort the cohp_dict keys. 3721 """ 3722 if key_sort_func: 3723 keys = sorted(cohp_dict.keys(), key=key_sort_func) 3724 else: 3725 keys = cohp_dict.keys() 3726 for label in keys: 3727 self.add_cohp(label, cohp_dict[label]) 3728 3729 def get_cohp_dict(self): 3730 """ 3731 Returns the added COHPs as a json-serializable dict. Note that if you 3732 have specified smearing for the COHP plot, the populations returned 3733 will be the smeared and not the original populations. 3734 3735 Returns: 3736 dict: Dict of COHP data of the form {label: {"efermi": efermi, 3737 "energies": ..., "COHP": {Spin.up: ...}, "ICOHP": ...}}. 3738 """ 3739 return jsanitize(self._cohps) 3740 3741 def get_plot( 3742 self, 3743 xlim=None, 3744 ylim=None, 3745 plot_negative=None, 3746 integrated=False, 3747 invert_axes=True, 3748 ): 3749 """ 3750 Get a matplotlib plot showing the COHP. 3751 3752 Args: 3753 xlim: Specifies the x-axis limits. Defaults to None for 3754 automatic determination. 3755 3756 ylim: Specifies the y-axis limits. Defaults to None for 3757 automatic determination. 3758 3759 plot_negative: It is common to plot -COHP(E) so that the 3760 sign means the same for COOPs and COHPs. Defaults to None 3761 for automatic determination: If are_coops is True, this 3762 will be set to False, else it will be set to True. 3763 3764 integrated: Switch to plot ICOHPs. Defaults to False. 3765 3766 invert_axes: Put the energies onto the y-axis, which is 3767 common in chemistry. 3768 3769 Returns: 3770 A matplotlib object. 3771 """ 3772 if self.are_coops: 3773 cohp_label = "COOP" 3774 elif self.are_cobis: 3775 cohp_label = "COBI" 3776 else: 3777 cohp_label = "COHP" 3778 3779 if plot_negative is None: 3780 plot_negative = (not self.are_coops) and (not self.are_cobis) 3781 3782 if integrated: 3783 cohp_label = "I" + cohp_label + " (eV)" 3784 3785 if plot_negative: 3786 cohp_label = "-" + cohp_label 3787 3788 if self.zero_at_efermi: 3789 energy_label = "$E - E_f$ (eV)" 3790 else: 3791 energy_label = "$E$ (eV)" 3792 3793 ncolors = max(3, len(self._cohps)) 3794 ncolors = min(9, ncolors) 3795 3796 import palettable 3797 3798 # pylint: disable=E1101 3799 colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors 3800 3801 plt = pretty_plot(12, 8) 3802 3803 allpts = [] 3804 keys = self._cohps.keys() 3805 for i, key in enumerate(keys): 3806 energies = self._cohps[key]["energies"] 3807 if not integrated: 3808 populations = self._cohps[key]["COHP"] 3809 else: 3810 populations = self._cohps[key]["ICOHP"] 3811 for spin in [Spin.up, Spin.down]: 3812 if spin in populations: 3813 if invert_axes: 3814 x = -populations[spin] if plot_negative else populations[spin] 3815 y = energies 3816 else: 3817 x = energies 3818 y = -populations[spin] if plot_negative else populations[spin] 3819 allpts.extend(list(zip(x, y))) 3820 if spin == Spin.up: 3821 plt.plot( 3822 x, 3823 y, 3824 color=colors[i % ncolors], 3825 linestyle="-", 3826 label=str(key), 3827 linewidth=3, 3828 ) 3829 else: 3830 plt.plot(x, y, color=colors[i % ncolors], linestyle="--", linewidth=3) 3831 3832 if xlim: 3833 plt.xlim(xlim) 3834 if ylim: 3835 plt.ylim(ylim) 3836 else: 3837 xlim = plt.xlim() 3838 relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]] 3839 plt.ylim((min(relevanty), max(relevanty))) 3840 3841 xlim = plt.xlim() 3842 ylim = plt.ylim() 3843 if not invert_axes: 3844 plt.plot(xlim, [0, 0], "k-", linewidth=2) 3845 if self.zero_at_efermi: 3846 plt.plot([0, 0], ylim, "k--", linewidth=2) 3847 else: 3848 plt.plot( 3849 [self._cohps[key]["efermi"], self._cohps[key]["efermi"]], 3850 ylim, 3851 color=colors[i % ncolors], 3852 linestyle="--", 3853 linewidth=2, 3854 ) 3855 else: 3856 plt.plot([0, 0], ylim, "k-", linewidth=2) 3857 if self.zero_at_efermi: 3858 plt.plot(xlim, [0, 0], "k--", linewidth=2) 3859 else: 3860 plt.plot( 3861 xlim, 3862 [self._cohps[key]["efermi"], self._cohps[key]["efermi"]], 3863 color=colors[i % ncolors], 3864 linestyle="--", 3865 linewidth=2, 3866 ) 3867 3868 if invert_axes: 3869 plt.xlabel(cohp_label) 3870 plt.ylabel(energy_label) 3871 else: 3872 plt.xlabel(energy_label) 3873 plt.ylabel(cohp_label) 3874 3875 plt.legend() 3876 leg = plt.gca().get_legend() 3877 ltext = leg.get_texts() 3878 plt.setp(ltext, fontsize=30) 3879 plt.tight_layout() 3880 return plt 3881 3882 def save_plot(self, filename, img_format="eps", xlim=None, ylim=None): 3883 """ 3884 Save matplotlib plot to a file. 3885 3886 Args: 3887 filename: File name to write to. 3888 img_format: Image format to use. Defaults to EPS. 3889 xlim: Specifies the x-axis limits. Defaults to None for 3890 automatic determination. 3891 ylim: Specifies the y-axis limits. Defaults to None for 3892 automatic determination. 3893 """ 3894 plt = self.get_plot(xlim, ylim) 3895 plt.savefig(filename, format=img_format) 3896 3897 def show(self, xlim=None, ylim=None): 3898 """ 3899 Show the plot using matplotlib. 3900 3901 Args: 3902 xlim: Specifies the x-axis limits. Defaults to None for 3903 automatic determination. 3904 ylim: Specifies the y-axis limits. Defaults to None for 3905 automatic determination. 3906 """ 3907 plt = self.get_plot(xlim, ylim) 3908 plt.show() 3909 3910 3911@requires(mlab is not None, "MayAvi mlab not imported! Please install mayavi.") 3912def plot_fermi_surface( 3913 data, 3914 structure, 3915 cbm, 3916 energy_levels=None, 3917 multiple_figure=True, 3918 mlab_figure=None, 3919 kpoints_dict=None, 3920 colors=None, 3921 transparency_factor=None, 3922 labels_scale_factor=0.05, 3923 points_scale_factor=0.02, 3924 interative=True, 3925): 3926 """ 3927 Plot the Fermi surface at specific energy value using Boltztrap 1 FERMI 3928 mode. 3929 3930 The easiest way to use this plotter is: 3931 3932 1. Run boltztrap in 'FERMI' mode using BoltztrapRunner, 3933 2. Load BoltztrapAnalyzer using your method of choice (e.g., from_files) 3934 3. Pass in your BoltztrapAnalyzer's fermi_surface_data as this 3935 function's data argument. 3936 3937 Args: 3938 data: energy values in a 3D grid from a CUBE file via read_cube_file 3939 function, or from a BoltztrapAnalyzer.fermi_surface_data 3940 structure: structure object of the material 3941 energy_levels ([float]): Energy values for plotting the fermi surface(s) 3942 By default 0 eV correspond to the VBM, as in the plot of band 3943 structure along symmetry line. 3944 Default: One surface, with max energy value + 0.01 eV 3945 cbm (bool): Boolean value to specify if the considered band is a 3946 conduction band or not 3947 multiple_figure (bool): If True a figure for each energy level will be 3948 shown. If False all the surfaces will be shown in the same figure. 3949 In this last case, tune the transparency factor. 3950 mlab_figure (mayavi.mlab.figure): A previous figure to plot a new 3951 surface on. 3952 kpoints_dict (dict): dictionary of kpoints to label in the plot. 3953 Example: {"K":[0.5,0.0,0.5]}, coords are fractional 3954 colors ([tuple]): Iterable of 3-tuples (r,g,b) of integers to define 3955 the colors of each surface (one per energy level). 3956 Should be the same length as the number of surfaces being plotted. 3957 Example (3 surfaces): colors=[(1,0,0), (0,1,0), (0,0,1)] 3958 Example (2 surfaces): colors=[(0, 0.5, 0.5)] 3959 transparency_factor [float]: Values in the range [0,1] to tune the 3960 opacity of each surface. Should be one transparency_factor per 3961 surface. 3962 labels_scale_factor (float): factor to tune size of the kpoint labels 3963 points_scale_factor (float): factor to tune size of the kpoint points 3964 interative (bool): if True an interactive figure will be shown. 3965 If False a non interactive figure will be shown, but it is possible 3966 to plot other surfaces on the same figure. To make it interactive, 3967 run mlab.show(). 3968 Returns: 3969 ((mayavi.mlab.figure, mayavi.mlab)): The mlab plotter and an interactive 3970 figure to control the plot. 3971 3972 Note: Experimental. 3973 Please, double check the surface shown by using some 3974 other software and report issues. 3975 """ 3976 bz = structure.lattice.reciprocal_lattice.get_wigner_seitz_cell() 3977 cell = structure.lattice.reciprocal_lattice.matrix 3978 3979 fact = 1 if not cbm else -1 3980 data_1d = data.ravel() 3981 en_min = np.min(fact * data_1d) 3982 en_max = np.max(fact * data_1d) 3983 3984 if energy_levels is None: 3985 energy_levels = [en_min + 0.01] if cbm else [en_max - 0.01] 3986 print("Energy level set to: " + str(energy_levels[0]) + " eV") 3987 3988 else: 3989 for e in energy_levels: 3990 if e > en_max or e < en_min: 3991 raise BoltztrapError( 3992 "energy level " 3993 + str(e) 3994 + " not in the range of possible energies: [" 3995 + str(en_min) 3996 + ", " 3997 + str(en_max) 3998 + "]" 3999 ) 4000 4001 n_surfaces = len(energy_levels) 4002 if colors is None: 4003 colors = [(0, 0, 1)] * n_surfaces 4004 4005 if transparency_factor is None: 4006 transparency_factor = [1] * n_surfaces 4007 4008 if mlab_figure: 4009 fig = mlab_figure 4010 4011 if kpoints_dict is None: 4012 kpoints_dict = {} 4013 4014 if mlab_figure is None and not multiple_figure: 4015 fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1)) 4016 for iface in range(len(bz)): # pylint: disable=C0200 4017 for line in itertools.combinations(bz[iface], 2): 4018 for jface in range(len(bz)): # pylint: disable=C0200 4019 if ( 4020 iface < jface 4021 and any(np.all(line[0] == x) for x in bz[jface]) 4022 and any(np.all(line[1] == x) for x in bz[jface]) 4023 ): 4024 mlab.plot3d( 4025 *zip(line[0], line[1]), 4026 color=(0, 0, 0), 4027 tube_radius=None, 4028 figure=fig, 4029 ) 4030 for label, coords in kpoints_dict.items(): 4031 label_coords = structure.lattice.reciprocal_lattice.get_cartesian_coords(coords) 4032 mlab.points3d( 4033 *label_coords, 4034 scale_factor=points_scale_factor, 4035 color=(0, 0, 0), 4036 figure=fig, 4037 ) 4038 mlab.text3d( 4039 *label_coords, 4040 text=label, 4041 scale=labels_scale_factor, 4042 color=(0, 0, 0), 4043 figure=fig, 4044 ) 4045 4046 for i, isolevel in enumerate(energy_levels): 4047 alpha = transparency_factor[i] 4048 color = colors[i] 4049 if multiple_figure: 4050 fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1)) 4051 4052 for iface in range(len(bz)): # pylint: disable=C0200 4053 for line in itertools.combinations(bz[iface], 2): 4054 for jface in range(len(bz)): 4055 if ( 4056 iface < jface 4057 and any(np.all(line[0] == x) for x in bz[jface]) 4058 and any(np.all(line[1] == x) for x in bz[jface]) 4059 ): 4060 mlab.plot3d( 4061 *zip(line[0], line[1]), 4062 color=(0, 0, 0), 4063 tube_radius=None, 4064 figure=fig, 4065 ) 4066 4067 for label, coords in kpoints_dict.items(): 4068 label_coords = structure.lattice.reciprocal_lattice.get_cartesian_coords(coords) 4069 mlab.points3d( 4070 *label_coords, 4071 scale_factor=points_scale_factor, 4072 color=(0, 0, 0), 4073 figure=fig, 4074 ) 4075 mlab.text3d( 4076 *label_coords, 4077 text=label, 4078 scale=labels_scale_factor, 4079 color=(0, 0, 0), 4080 figure=fig, 4081 ) 4082 4083 cp = mlab.contour3d( 4084 fact * data, 4085 contours=[isolevel], 4086 transparent=True, 4087 colormap="hot", 4088 color=color, 4089 opacity=alpha, 4090 figure=fig, 4091 ) 4092 4093 polydata = cp.actor.actors[0].mapper.input 4094 pts = np.array(polydata.points) # - 1 4095 polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis]) 4096 4097 cx, cy, cz = [np.mean(np.array(polydata.points)[:, i]) for i in range(3)] 4098 4099 polydata.points = (np.array(polydata.points) - [cx, cy, cz]) * 2 4100 4101 # mlab.view(distance='auto') 4102 fig.scene.isometric_view() 4103 4104 if interative: 4105 mlab.show() 4106 4107 return fig, mlab 4108 4109 4110def plot_wigner_seitz(lattice, ax=None, **kwargs): 4111 """ 4112 Adds the skeleton of the Wigner-Seitz cell of the lattice to a matplotlib Axes 4113 4114 Args: 4115 lattice: Lattice object 4116 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4117 kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black 4118 and linewidth to 1. 4119 4120 Returns: 4121 matplotlib figure and matplotlib ax 4122 """ 4123 ax, fig, plt = get_ax3d_fig_plt(ax) 4124 4125 if "color" not in kwargs: 4126 kwargs["color"] = "k" 4127 if "linewidth" not in kwargs: 4128 kwargs["linewidth"] = 1 4129 4130 bz = lattice.get_wigner_seitz_cell() 4131 ax, fig, plt = get_ax3d_fig_plt(ax) 4132 for iface in range(len(bz)): # pylint: disable=C0200 4133 for line in itertools.combinations(bz[iface], 2): 4134 for jface in range(len(bz)): 4135 if ( 4136 iface < jface 4137 and any(np.all(line[0] == x) for x in bz[jface]) 4138 and any(np.all(line[1] == x) for x in bz[jface]) 4139 ): 4140 ax.plot(*zip(line[0], line[1]), **kwargs) 4141 4142 return fig, ax 4143 4144 4145def plot_lattice_vectors(lattice, ax=None, **kwargs): 4146 """ 4147 Adds the basis vectors of the lattice provided to a matplotlib Axes 4148 4149 Args: 4150 lattice: Lattice object 4151 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4152 kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to green 4153 and linewidth to 3. 4154 4155 Returns: 4156 matplotlib figure and matplotlib ax 4157 """ 4158 ax, fig, plt = get_ax3d_fig_plt(ax) 4159 4160 if "color" not in kwargs: 4161 kwargs["color"] = "g" 4162 if "linewidth" not in kwargs: 4163 kwargs["linewidth"] = 3 4164 4165 vertex1 = lattice.get_cartesian_coords([0.0, 0.0, 0.0]) 4166 vertex2 = lattice.get_cartesian_coords([1.0, 0.0, 0.0]) 4167 ax.plot(*zip(vertex1, vertex2), **kwargs) 4168 vertex2 = lattice.get_cartesian_coords([0.0, 1.0, 0.0]) 4169 ax.plot(*zip(vertex1, vertex2), **kwargs) 4170 vertex2 = lattice.get_cartesian_coords([0.0, 0.0, 1.0]) 4171 ax.plot(*zip(vertex1, vertex2), **kwargs) 4172 4173 return fig, ax 4174 4175 4176def plot_path(line, lattice=None, coords_are_cartesian=False, ax=None, **kwargs): 4177 """ 4178 Adds a line passing through the coordinates listed in 'line' to a matplotlib Axes 4179 4180 Args: 4181 line: list of coordinates. 4182 lattice: Lattice object used to convert from reciprocal to cartesian coordinates 4183 coords_are_cartesian: Set to True if you are providing 4184 coordinates in cartesian coordinates. Defaults to False. 4185 Requires lattice if False. 4186 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4187 kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to red 4188 and linewidth to 3. 4189 4190 Returns: 4191 matplotlib figure and matplotlib ax 4192 """ 4193 4194 ax, fig, plt = get_ax3d_fig_plt(ax) 4195 4196 if "color" not in kwargs: 4197 kwargs["color"] = "r" 4198 if "linewidth" not in kwargs: 4199 kwargs["linewidth"] = 3 4200 4201 for k in range(1, len(line)): 4202 vertex1 = line[k - 1] 4203 vertex2 = line[k] 4204 if not coords_are_cartesian: 4205 if lattice is None: 4206 raise ValueError("coords_are_cartesian False requires the lattice") 4207 vertex1 = lattice.get_cartesian_coords(vertex1) 4208 vertex2 = lattice.get_cartesian_coords(vertex2) 4209 ax.plot(*zip(vertex1, vertex2), **kwargs) 4210 4211 return fig, ax 4212 4213 4214def plot_labels(labels, lattice=None, coords_are_cartesian=False, ax=None, **kwargs): 4215 """ 4216 Adds labels to a matplotlib Axes 4217 4218 Args: 4219 labels: dict containing the label as a key and the coordinates as value. 4220 lattice: Lattice object used to convert from reciprocal to cartesian coordinates 4221 coords_are_cartesian: Set to True if you are providing. 4222 coordinates in cartesian coordinates. Defaults to False. 4223 Requires lattice if False. 4224 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4225 kwargs: kwargs passed to the matplotlib function 'text'. Color defaults to blue 4226 and size to 25. 4227 4228 Returns: 4229 matplotlib figure and matplotlib ax 4230 """ 4231 ax, fig, plt = get_ax3d_fig_plt(ax) 4232 4233 if "color" not in kwargs: 4234 kwargs["color"] = "b" 4235 if "size" not in kwargs: 4236 kwargs["size"] = 25 4237 4238 for k, coords in labels.items(): 4239 label = k 4240 if k.startswith("\\") or k.find("_") != -1: 4241 label = "$" + k + "$" 4242 off = 0.01 4243 if coords_are_cartesian: 4244 coords = np.array(coords) 4245 else: 4246 if lattice is None: 4247 raise ValueError("coords_are_cartesian False requires the lattice") 4248 coords = lattice.get_cartesian_coords(coords) 4249 ax.text(*(coords + off), s=label, **kwargs) 4250 4251 return fig, ax 4252 4253 4254def fold_point(p, lattice, coords_are_cartesian=False): 4255 """ 4256 Folds a point with coordinates p inside the first Brillouin zone of the lattice. 4257 4258 Args: 4259 p: coordinates of one point 4260 lattice: Lattice object used to convert from reciprocal to cartesian coordinates 4261 coords_are_cartesian: Set to True if you are providing 4262 coordinates in cartesian coordinates. Defaults to False. 4263 4264 Returns: 4265 The cartesian coordinates folded inside the first Brillouin zone 4266 """ 4267 4268 if coords_are_cartesian: 4269 p = lattice.get_fractional_coords(p) 4270 else: 4271 p = np.array(p) 4272 4273 p = np.mod(p + 0.5 - 1e-10, 1) - 0.5 + 1e-10 4274 p = lattice.get_cartesian_coords(p) 4275 4276 closest_lattice_point = None 4277 smallest_distance = 10000 4278 for i in (-1, 0, 1): 4279 for j in (-1, 0, 1): 4280 for k in (-1, 0, 1): 4281 lattice_point = np.dot((i, j, k), lattice.matrix) 4282 dist = np.linalg.norm(p - lattice_point) 4283 if closest_lattice_point is None or dist < smallest_distance: 4284 closest_lattice_point = lattice_point 4285 smallest_distance = dist 4286 4287 if not np.allclose(closest_lattice_point, (0, 0, 0)): 4288 p = p - closest_lattice_point 4289 4290 return p 4291 4292 4293def plot_points(points, lattice=None, coords_are_cartesian=False, fold=False, ax=None, **kwargs): 4294 """ 4295 Adds Points to a matplotlib Axes 4296 4297 Args: 4298 points: list of coordinates 4299 lattice: Lattice object used to convert from reciprocal to cartesian coordinates 4300 coords_are_cartesian: Set to True if you are providing 4301 coordinates in cartesian coordinates. Defaults to False. 4302 Requires lattice if False. 4303 fold: whether the points should be folded inside the first Brillouin Zone. 4304 Defaults to False. Requires lattice if True. 4305 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4306 kwargs: kwargs passed to the matplotlib function 'scatter'. Color defaults to blue 4307 4308 Returns: 4309 matplotlib figure and matplotlib ax 4310 """ 4311 ax, fig, plt = get_ax3d_fig_plt(ax) 4312 4313 if "color" not in kwargs: 4314 kwargs["color"] = "b" 4315 4316 if (not coords_are_cartesian or fold) and lattice is None: 4317 raise ValueError("coords_are_cartesian False or fold True require the lattice") 4318 4319 for p in points: 4320 4321 if fold: 4322 p = fold_point(p, lattice, coords_are_cartesian=coords_are_cartesian) 4323 4324 elif not coords_are_cartesian: 4325 p = lattice.get_cartesian_coords(p) 4326 4327 ax.scatter(*p, **kwargs) 4328 4329 return fig, ax 4330 4331 4332@add_fig_kwargs 4333def plot_brillouin_zone_from_kpath(kpath, ax=None, **kwargs): 4334 """ 4335 Gives the plot (as a matplotlib object) of the symmetry line path in 4336 the Brillouin Zone. 4337 4338 Args: 4339 kpath (HighSymmKpath): a HighSymmKPath object 4340 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4341 **kwargs: provided by add_fig_kwargs decorator 4342 4343 Returns: 4344 matplotlib figure 4345 4346 """ 4347 lines = [[kpath.kpath["kpoints"][k] for k in p] for p in kpath.kpath["path"]] 4348 return plot_brillouin_zone( 4349 bz_lattice=kpath.prim_rec, 4350 lines=lines, 4351 ax=ax, 4352 labels=kpath.kpath["kpoints"], 4353 **kwargs, 4354 ) 4355 4356 4357@add_fig_kwargs 4358def plot_brillouin_zone( 4359 bz_lattice, 4360 lines=None, 4361 labels=None, 4362 kpoints=None, 4363 fold=False, 4364 coords_are_cartesian=False, 4365 ax=None, 4366 **kwargs, 4367): 4368 """ 4369 Plots a 3D representation of the Brillouin zone of the structure. 4370 Can add to the plot paths, labels and kpoints 4371 4372 Args: 4373 bz_lattice: Lattice object of the Brillouin zone 4374 lines: list of lists of coordinates. Each list represent a different path 4375 labels: dict containing the label as a key and the coordinates as value. 4376 kpoints: list of coordinates 4377 fold: whether the points should be folded inside the first Brillouin Zone. 4378 Defaults to False. Requires lattice if True. 4379 coords_are_cartesian: Set to True if you are providing 4380 coordinates in cartesian coordinates. Defaults to False. 4381 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4382 kwargs: provided by add_fig_kwargs decorator 4383 4384 Returns: 4385 matplotlib figure 4386 """ 4387 4388 fig, ax = plot_lattice_vectors(bz_lattice, ax=ax) 4389 plot_wigner_seitz(bz_lattice, ax=ax) 4390 if lines is not None: 4391 for line in lines: 4392 plot_path(line, bz_lattice, coords_are_cartesian=coords_are_cartesian, ax=ax) 4393 4394 if labels is not None: 4395 plot_labels(labels, bz_lattice, coords_are_cartesian=coords_are_cartesian, ax=ax) 4396 plot_points( 4397 labels.values(), 4398 bz_lattice, 4399 coords_are_cartesian=coords_are_cartesian, 4400 fold=False, 4401 ax=ax, 4402 ) 4403 4404 if kpoints is not None: 4405 plot_points( 4406 kpoints, 4407 bz_lattice, 4408 coords_are_cartesian=coords_are_cartesian, 4409 ax=ax, 4410 fold=fold, 4411 ) 4412 4413 ax.set_xlim3d(-1, 1) 4414 ax.set_ylim3d(-1, 1) 4415 ax.set_zlim3d(-1, 1) 4416 4417 # ax.set_aspect('equal') 4418 ax.axis("off") 4419 4420 return fig 4421 4422 4423def plot_ellipsoid( 4424 hessian, 4425 center, 4426 lattice=None, 4427 rescale=1.0, 4428 ax=None, 4429 coords_are_cartesian=False, 4430 arrows=False, 4431 **kwargs, 4432): 4433 """ 4434 Plots a 3D ellipsoid rappresenting the Hessian matrix in input. 4435 Useful to get a graphical visualization of the effective mass 4436 of a band in a single k-point. 4437 4438 Args: 4439 hessian: the Hessian matrix 4440 center: the center of the ellipsoid in reciprocal coords (Default) 4441 lattice: Lattice object of the Brillouin zone 4442 rescale: factor for size scaling of the ellipsoid 4443 ax: matplotlib :class:`Axes` or None if a new figure should be created. 4444 coords_are_cartesian: Set to True if you are providing a center in 4445 cartesian coordinates. Defaults to False. 4446 kwargs: kwargs passed to the matplotlib function 'plot_wireframe'. 4447 Color defaults to blue, rstride and cstride 4448 default to 4, alpha defaults to 0.2. 4449 Returns: 4450 matplotlib figure and matplotlib ax 4451 Example of use: 4452 fig,ax=plot_wigner_seitz(struct.reciprocal_lattice) 4453 plot_ellipsoid(hessian,[0.0,0.0,0.0], struct.reciprocal_lattice,ax=ax) 4454 """ 4455 4456 if (not coords_are_cartesian) and lattice is None: 4457 raise ValueError("coords_are_cartesian False or fold True require the lattice") 4458 4459 if not coords_are_cartesian: 4460 center = lattice.get_cartesian_coords(center) 4461 4462 if "color" not in kwargs: 4463 kwargs["color"] = "b" 4464 if "rstride" not in kwargs: 4465 kwargs["rstride"] = 4 4466 if "cstride" not in kwargs: 4467 kwargs["cstride"] = 4 4468 if "alpha" not in kwargs: 4469 kwargs["alpha"] = 0.2 4470 4471 # calculate the ellipsoid 4472 # find the rotation matrix and radii of the axes 4473 U, s, rotation = np.linalg.svd(hessian) 4474 radii = 1.0 / np.sqrt(s) 4475 4476 # from polar coordinates 4477 u = np.linspace(0.0, 2.0 * np.pi, 100) 4478 v = np.linspace(0.0, np.pi, 100) 4479 x = radii[0] * np.outer(np.cos(u), np.sin(v)) 4480 y = radii[1] * np.outer(np.sin(u), np.sin(v)) 4481 z = radii[2] * np.outer(np.ones_like(u), np.cos(v)) 4482 for i in range(len(x)): 4483 for j in range(len(x)): 4484 [x[i, j], y[i, j], z[i, j]] = np.dot([x[i, j], y[i, j], z[i, j]], rotation) * rescale + center 4485 4486 # add the ellipsoid to the current axes 4487 ax, fig, plt = get_ax3d_fig_plt(ax) 4488 ax.plot_wireframe(x, y, z, **kwargs) 4489 4490 if arrows: 4491 color = ("b", "g", "r") 4492 em = np.zeros((3, 3)) 4493 for i in range(3): 4494 em[i, :] = rotation[i, :] / np.linalg.norm(rotation[i, :]) 4495 for i in range(3): 4496 ax.quiver3D( 4497 center[0], 4498 center[1], 4499 center[2], 4500 em[i, 0], 4501 em[i, 1], 4502 em[i, 2], 4503 pivot="tail", 4504 arrow_length_ratio=0.2, 4505 length=radii[i] * rescale, 4506 color=color[i], 4507 ) 4508 4509 return fig, ax 4510