# coding: utf-8
"""Scissors operator."""
import os
import numpy as np
import pickle

from collections import OrderedDict
from monty.collections import AttrDict
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt


__all__ = [
    "Scissors",
    "ScissorsBuilder",
]


class ScissorsError(Exception):
    """Base class for the exceptions raised by :class:`Scissors`"""


class Scissors(object):
    """
    This object represents an energy-dependent scissors operator.
    The operator is defined by a list of domains (energy intervals)
    and a list of functions defined in these domains.
    The domains should fulfill the constraints documented in the main constructor.

    .. note::

        eV units are assumed.

    The standard way to create this object is via the methods provided by the factory class :class:`ScissorsBuilder`.
    Once the instance has been created, one can correct the band structure by calling the `apply` method.
    """
    Error = ScissorsError

    def __init__(self, func_list, domains, residues, bounds=None):
        """
        Args:
            func_list: List of callable objects. Each function takes an eigenvalue and returns
                the corrected value.
            domains: Domains of each function. List of tuples [(emin1, emax1), (emin2, emax2), ...]
            residues: A list of the residues of the fitting per domain.
            bounds: Specify how to handle energies that do not fall inside one of the domains.
                Layout: [[kind_low, value_low], [kind_high, value_high]].
                At present, only constant boundaries (kind "c") are implemented.
                If bounds is None, or a value cannot be converted to float (e.g. None),
                the constant is obtained by evaluating the first (last) function
                at the lower (upper) edge of the domains.

        Raises:
            NotImplementedError: if a boundary kind other than "c" is requested.

        .. note::

            #. Domains should not overlap, should cover e0mesh, and should be given in increasing order.

            #. Holes are permitted but the interpolation will raise an exception if the
               eigenvalue falls inside the hole.

            #. residues contains a list of the fitting errors per domain.
        """
        # TODO Add consistency check.
        self.func_list = func_list
        self.domains = np.atleast_2d(domains)
        self.residues = residues
        assert len(self.func_list) == len(self.domains)

        # Treat the out-of-boundary conditions. func_low and func_high are used to handle energies
        # that are below or above the min/max energy given in domains.
        blow, bhigh = "c", "c"
        if bounds is not None:
            # bounds[0] describes the lower boundary, bounds[1] the upper one.
            blow, bhigh = bounds[0][0], bounds[1][0]

        if blow.lower() == "c":
            # Compute the constant now: a lambda that indexes `bounds` lazily would fail
            # at call time (e.g. bounds is None) instead of using the fallback value.
            fx_low = self._constant_bound(bounds, 0, func_list[0], self.domains[0, 0])
            self.func_low = lambda x: fx_low
        else:
            raise NotImplementedError("Only constant boundaries are implemented")

        if bhigh.lower() == "c":
            # Upper edge of the *last* domain. domains[1, -1] would raise IndexError
            # with a single domain and pick the wrong edge with more than two domains.
            fx_high = self._constant_bound(bounds, 1, func_list[-1], self.domains[-1, 1])
            self.func_high = lambda x: fx_high
        else:
            raise NotImplementedError("Only constant boundaries are implemented")

        # This counter stores the number of points that are out of bounds:
        # [0] below the first domain, [1] above the last domain, [2] inside a hole.
        self.out_bounds = np.zeros(3, dtype=int)

    @staticmethod
    def _constant_bound(bounds, index, func, x_edge):
        """
        Return the constant used in the out-of-bounds region `index` (0: low, 1: high).

        Use float(bounds[index][1]) when available, otherwise fall back to func(x_edge).
        """
        try:
            return float(bounds[index][1])
        except (TypeError, IndexError, ValueError):
            # bounds is None, too short, or the stored value is not a number.
            return func(x_edge)

    def apply(self, eig):
        """
        Apply the scissors operator to the eigenvalue eig (eV units).

        Returns the value of the function associated to the domain containing eig.
        Energies below (above) all the domains are handled by func_low (func_high).

        Raises:
            ScissorsError: if eig falls inside a hole between two domains.
        """
        # Get the list of domains.
        domains = self.domains

        if eig < domains[0, 0]:
            # Eig is below the first point of the first domain.
            # Call func_low
            print("left ", eig, " < ", domains[0, 0])
            self.out_bounds[0] += 1
            return self.func_low(eig)

        if eig > domains[-1, 1]:
            # Eig is above the last point of the last domain.
            # Call func_high
            print("right ", eig, " > ", domains[-1, 1])
            self.out_bounds[1] += 1
            return self.func_high(eig)

        # eig is inside the domains: find the domain
        # and call the corresponding function.
        for idx, dms in enumerate(domains):
            if dms[1] >= eig >= dms[0]:
                return self.func_list[idx](eig)

        self.out_bounds[2] += 1
        raise self.Error("Cannot find location of eigenvalue %s in domains:\n%s" % (eig, domains))


class ScissorsBuilder(object):
    """
    This object facilitates the creation of :class:`Scissors` instances.

    Usage:

        builder = ScissorsBuilder.from_file("out_SIGRES.nc")

        # To plot the QP results as function of the KS energy:
        builder.plot_qpe_vs_e0()

        # To select the domains explicitly (optional but highly recommended)
        builder.build(domains_spin=[[-10, 6.02], [6.1, 20]])

        # To compare the fitted results with the ab-initio data:
        builder.plot_fit()

        # To plot the corrected bands:
        builder.plot_qpbands(abidata.ref_file("si_nscf_WFK.nc"))
    """

    @classmethod
    def from_file(cls, filepath):
        """
        Generate object from (SIGRES.nc) file. Main entry point for client code.
        """
        from abipy.abilab import abiopen
        with abiopen(filepath) as ncfile:
            return cls(qps_spin=ncfile.qplist_spin, sigres_ebands=ncfile.ebands)

    @classmethod
    def pickle_load(cls, filepath):
        """
        Load the object from a pickle file produced by :meth:`pickle_dump`.

        .. warning::

            Unpickling can execute arbitrary code: only load files from trusted sources.
        """
        with open(filepath, "rb") as fh:
            d = AttrDict(pickle.load(fh))

        # Construct the object and compute the scissors.
        new = cls(d.qps_spin, d.sigres_ebands)
        new.build(d.domains_spin, d.bounds_spin)
        return new

    def pickle_dump(self, filepath, protocol=-1):
        """
        Save the object in Pickle format.

        Only the input data (QP lists, bands, domains and bounds) is saved;
        the scissors operators are rebuilt by :meth:`pickle_load`.
        """
        assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), self.bounds_spin.keys()))
        assert all(s1 == s2 for s1, s2 in zip(self.domains_spin.keys(), range(self.nsppol)))

        bounds_spin = None
        if any(v is not None for v in self.bounds_spin.values()):
            # Some spins may have no bounds (None) and bounds may be plain lists:
            # call tolist only on objects that provide it (e.g. numpy arrays).
            bounds_spin = [b.tolist() if hasattr(b, "tolist") else b for b in self.bounds_spin.values()]

        # This trick is needed because we cannot pickle bound methods of the scissors operator.
        d = dict(qps_spin=self._qps_spin,
                 sigres_ebands=self.sigres_ebands,
                 domains_spin=list(self.domains_spin.values()),
                 bounds_spin=bounds_spin)

        with open(filepath, "wb") as fh:
            pickle.dump(d, fh, protocol=protocol)

    def __init__(self, qps_spin, sigres_ebands):
        """
        Args:
            qps_spin: List of :class:`QPlist`, for each spin.
            sigres_ebands: |ElectronBands| obtained from the SIGRES file
        """
        # Sort quasiparticle data by e0.
        self._qps_spin = tuple([qps.sort_by_e0() for qps in qps_spin])

        # Compute the boundaries of the E0 mesh.
        e0min, e0max = np.inf, -np.inf
        for qps in self._qps_spin:
            e0mesh = qps.get_e0mesh()
            e0min = min(e0min, e0mesh[0])
            e0max = max(e0max, e0mesh[-1])

        self._e0min, self._e0max = e0min, e0max

        # The KS bands stored in the sigres file (used to compute automatically the boundaries)
        self.sigres_ebands = sigres_ebands

        # Start with default values for domains.
        self.build()

    @property
    def nsppol(self):
        """Number of spins."""
        return len(self._qps_spin)

    @property
    def e0min(self):
        """Minimum KS energy in eV (takes into account spin)"""
        return self._e0min

    @property
    def e0max(self):
        """Maximum KS energy in eV (takes into account spin)"""
        return self._e0max

    @property
    def scissors_spin(self):
        """Returns a list of :class:`Scissors` indexed by the spin value."""
        try:
            return self._scissors_spin
        except AttributeError:
            raise AttributeError("Call self.build to create the scissors operator")

    def build(self, domains_spin=None, bounds_spin=None, k=3):
        """
        Build the scissors operator.

        Args:
            domains_spin: list of domains in eV for each spin. If domains is None,
                domains are computed automatically from the sigres bands
                (two domains separated by the middle of the gap).
            bounds_spin: Options specifying the boundary conditions (not used at present)
            k: Parameter defining the order of the fit.

        Returns:
            The domains used for each spin.

        Raises:
            ValueError: if the number of spins is not 1 or 2, or if domains_spin/bounds_spin
                do not provide one entry per spin when nsppol == 2.
        """
        nsppol = self.nsppol

        # The parameters defining the scissors operator
        self.domains_spin = OrderedDict()
        self.bounds_spin = OrderedDict()

        if domains_spin is None:
            # Use sigres_ebands and the position of the homo, lumo to compute the domains.
            domains_spin = nsppol * [None]
            e_bands = self.sigres_ebands
            for spin in e_bands.spins:
                gap_mid = (e_bands.homos[spin].eig + e_bands.lumos[spin].eig) / 2
                # Extend the mesh by 20% on both sides so that the KS energies fall inside the domains.
                domains_spin[spin] = [[self.e0min - 0.2 * abs(self.e0min), gap_mid],
                                      [gap_mid, self.e0max + 0.2 * abs(self.e0max)]]
        else:
            if nsppol == 1:
                domains_spin = np.reshape(domains_spin, (1, -1, 2))
            elif nsppol == 2:
                if len(domains_spin) != nsppol:
                    raise ValueError("len(domains_spin) == %s != nsppol %s" % (len(domains_spin), nsppol))
                if bounds_spin is not None and len(bounds_spin) != nsppol:
                    raise ValueError("len(bounds_spin) == %s != nsppol %s" % (len(bounds_spin), nsppol))
            else:
                raise ValueError("Wrong number of spins %d" % nsppol)

        # Construct the scissors operator for each spin.
        scissors_spin = nsppol * [None]
        for spin, qps in enumerate(self._qps_spin):
            bounds = None if not bounds_spin else bounds_spin[spin]
            scissors_spin[spin] = qps.build_scissors(domains_spin[spin], bounds=bounds, k=k, plot=False)

            # Save input so that we can reconstruct Scissors.
            self.domains_spin[spin] = domains_spin[spin]
            self.bounds_spin[spin] = bounds

        self._scissors_spin = scissors_spin
        return domains_spin

    @add_fig_kwargs
    def plot_qpe_vs_e0(self, with_fields="all", **kwargs):
        """
        Plot the quasiparticle corrections as function of the KS energy.

        The data for all spins are drawn on the same axes.

        Return: |matplotlib-Figure|
        """
        ax_list = None
        for spin, qps in enumerate(self._qps_spin):
            kwargs["title"] = "spin %s" % spin
            fig = qps.plot_qps_vs_e0(with_fields=with_fields, ax_list=ax_list, show=False, **kwargs)
            ax_list = fig.axes

        return fig

    @add_fig_kwargs
    def plot_fit(self, ax=None, fontsize=8, **kwargs):
        """
        Compare fit functions with input quasi-particle corrections.

        Args:
            ax: |matplotlib-Axes| or None if a new figure should be created.
            fontsize: fontsize for titles and legend.

        Return: |matplotlib-Figure|
        """
        ax, fig, plt = get_ax_fig_plt(ax=ax)

        for spin in range(self.nsppol):
            qps = self._qps_spin[spin]
            e0mesh, qpcorrs = qps.get_e0mesh(), qps.get_qpeme0().real

            ax.scatter(e0mesh, qpcorrs, label="Input QP corrections, spin %s" % spin)
            scissors = self._scissors_spin[spin]
            intp_qpc = [scissors.apply(e0) for e0 in e0mesh]
            ax.plot(e0mesh, intp_qpc, label="Scissors operator, spin %s" % spin)

        ax.grid(True)
        ax.set_xlabel('KS energy (eV)')
        ax.set_ylabel('QP-KS (eV)')
        ax.legend(loc="best", fontsize=fontsize, shadow=True)

        return fig

    def plot_qpbands(self, bands_filepath, bands_label=None, dos_filepath=None, dos_args=None, **kwargs):
        """
        Correct the energies found in the netcdf file bands_filepath and plot the band energies (both the initial
        and the corrected ones) with matplotlib. The plot contains the KS and the QP DOS if dos_filepath is not None.

        Args:
            bands_filepath: Path to the netcdf file containing the initial KS energies to be corrected.
            bands_label: String used to label the KS bands in the plot.
                Defaults to the basename of bands_filepath.
            dos_filepath: Optional path to a netcdf file with the initial KS energies on a homogeneous k-mesh
                (used to compute the KS and the QP dos)
            dos_args: Dictionary with the arguments passed to get_dos to compute the DOS
                Used if dos_filepath is not None.
            kwargs: Options passed to the plotter.

        Return: |matplotlib-Figure|
        """
        from abipy.abilab import abiopen, ElectronBandsPlotter

        # Read the KS band energies from bands_filepath and apply the scissors operator.
        with abiopen(bands_filepath) as ncfile:
            ks_bands = ncfile.ebands

        qp_bands = ks_bands.apply_scissors(self._scissors_spin)

        # Read the band energies computed on the Monkhorst-Pack (MP) mesh and compute the DOS.
        ks_dos, qp_dos = None, None
        if dos_filepath is not None:
            with abiopen(dos_filepath) as ncfile:
                ks_mpbands = ncfile.ebands

            dos_args = {} if not dos_args else dos_args
            ks_dos = ks_mpbands.get_edos(**dos_args)
            # Compute the DOS with the modified QPState energies.
            qp_mpbands = ks_mpbands.apply_scissors(self._scissors_spin)
            qp_dos = qp_mpbands.get_edos(**dos_args)

        # Plot the LDA and the QPState band structure with matplotlib.
        plotter = ElectronBandsPlotter()

        bands_label = bands_label if bands_label is not None else os.path.basename(bands_filepath)
        plotter.add_ebands(bands_label, ks_bands, edos=ks_dos)
        plotter.add_ebands(bands_label + " + scissors", qp_bands, edos=qp_dos)

        # TODO: display markers for the ab-initio QP energies (qp_marker option).
        # Sketch of the algorithm (not enabled):
        #qp_marker: if int > 0, markers for the ab-initio QP energies are displayed. e.g qp_marker=50
        #qp_marker = 50
        #if qp_marker is not None:
        #    # Compute correspondence between the k-points in qp_list and the k-path in qp_bands.
        #    # WARNING: strictly speaking one should check if qp_kpoint is in the star of k-point.
        #    # but compute_star is too slow if written in pure python.
        #    x, y, s = [], [], []
        #    for ik_path, kpoint in enumerate(qp_bands.kpoints):
        #        #kstar = kpoint.compute_star(structure.fm_symmops)
        #        for spin in range(self.nsppol):
        #            for ik_qp, qp in enumerate(self._qps_spin[spin]):
        #                #if qp.kpoint in kstar:
        #                if qp.kpoint == kpoint:
        #                    x.append(ik_path)
        #                    y.append(np.real(qp.qpe))
        #                    s.append(qp_marker)
        #    plotter.set_marker("ab-initio QP", [x, y, s])

        return plotter.combiplot(**kwargs)