"""Module defining an eigensolver base-class."""
from functools import partial

import numpy as np
from ase.dft.bandgap import _bandgap
from ase.units import Ha
from ase.utils.timing import timer

from gpaw.matrix import matrix_matrix_multiply as mmm
from gpaw.utilities.blas import axpy
from gpaw.xc.hybrid import HybridXC


def reshape(a_x, shape):
    """Get an ndarray of size shape from a_x buffer.

    The first ``prod(shape)`` elements of the flattened buffer are
    returned with the requested shape.

    Parameters
    ----------
    a_x : ndarray
        Buffer holding at least ``prod(shape)`` elements.
    shape : int or tuple of int
        Target shape.  An empty tuple yields a 0-d array.
    """
    # np.prod(()) is the *float* 1.0, which is not a valid slice bound,
    # so convert the element count to a Python int explicitly.
    size = int(np.prod(shape))
    return a_x.ravel()[:size].reshape(shape)


class Eigensolver:
    """Base class for iterative eigensolvers.

    Subclasses implement :meth:`iterate_one_k_point`.  This class
    provides initialization, convergence weights, residuals and
    subspace diagonalization shared by all eigensolvers.
    """

    def __init__(self, keep_htpsit=True, blocksize=1):
        # Keep H applied to the wave functions in self.Htpsit_nG
        # (switched off in initialize() when bands are distributed).
        self.keep_htpsit = keep_htpsit
        self.initialized = False
        # Buffer for H*psit; allocated in initialize() if keep_htpsit.
        self.Htpsit_nG = None
        # Total error from the most recent call to iterate().
        self.error = np.inf
        self.blocksize = blocksize
        # When True, iterate() orthonormalizes each k-point after
        # calling iterate_one_k_point() on it.
        self.orthonormalization_required = True

    def initialize(self, wfs):
        """Set up communicators, buffers and preconditioner from *wfs*.

        Also allocates ``kpt.eps_n`` for every k-point that has none.
        """
        self.timer = wfs.timer
        self.world = wfs.world
        self.kpt_comm = wfs.kd.comm
        self.band_comm = wfs.bd.comm
        self.dtype = wfs.dtype
        self.bd = wfs.bd
        self.nbands = wfs.bd.nbands
        self.mynbands = wfs.bd.mynbands

        # H*psit is not kept when bands are distributed over ranks.
        if wfs.bd.comm.size > 1:
            self.keep_htpsit = False

        if self.keep_htpsit:
            self.Htpsit_nG = np.empty_like(wfs.work_array)

        # Preconditioner for the electronic gradients:
        self.preconditioner = wfs.make_preconditioner(self.blocksize)

        for kpt in wfs.kpt_u:
            if kpt.eps_n is None:
                kpt.eps_n = np.empty(self.mynbands)

        self.initialized = True

    def reset(self):
        """Force re-initialization on the next call to iterate()."""
        self.initialized = False

    def weights(self, wfs):
        """Calculate convergence weights for all eigenstates.

        The criterion is read from ``self.nbands_converge`` (not set by
        this class), which is one of:

        * an integer N: bands with global index below N get the
          k-point weight,
        * ``'occupied'``: each band is weighted by ``|f_n|``,
        * ``'CBM+<delta>'``: bands with energy below the conduction band
          minimum plus *delta* (given in eV) get the k-point weight.

        Returns
        -------
        ndarray of shape (len(wfs.kpt_u), mynbands)
            Weights.  Entries are ``np.inf`` when no eigenvalues are
            available yet or, for ``'CBM+...'``, when there are too few
            bands to cover the requested energy window.
        """
        weight_un = np.zeros((len(wfs.kpt_u), self.bd.mynbands))

        if isinstance(self.nbands_converge, (int, np.integer)):
            # Converge fixed number of bands:
            n = self.nbands_converge - self.bd.beg
            if n > 0:
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[:n] = kpt.weight
        elif self.nbands_converge == 'occupied':
            # Converge occupied bands:
            for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                if kpt.f_n is None:  # no eigenvalues yet
                    weight_n[:] = np.inf
                else:
                    # Methfessel-Paxton distribution can give negative
                    # occupation numbers - so we take the absolute value:
                    weight_n[:] = np.abs(kpt.f_n)
        else:
            # Converge state with energy up to CBM + delta:
            assert self.nbands_converge.startswith('CBM+')
            # Convert delta from eV to Hartree:
            delta = float(self.nbands_converge[4:]) / Ha

            if wfs.kpt_u[0].f_n is None:
                weight_un[:] = np.inf  # no eigenvalues yet
            else:
                # Collect all eigenvalues and calculate band gap:
                efermi = np.mean(wfs.fermi_levels)
                eps_skn = np.array(
                    [[wfs.collect_eigenvalues(k, spin) - efermi
                      for k in range(wfs.kd.nibzkpts)]
                     for spin in range(wfs.nspins)])
                # Only rank 0 keeps its collected array; the others
                # receive it through the broadcast.
                if wfs.world.rank > 0:
                    eps_skn = np.empty((wfs.nspins,
                                        wfs.kd.nibzkpts,
                                        wfs.bd.nbands))
                wfs.world.broadcast(eps_skn, 0)
                try:
                    # Find bandgap + positions of CBM:
                    gap, _, (s, k, n) = _bandgap(eps_skn,
                                                 spin=None, direct=False)
                except ValueError:
                    gap = 0.0

                if gap == 0.0:
                    cbm = efermi
                else:
                    cbm = efermi + eps_skn[s, k, n]

                ecut = cbm + delta

                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[kpt.eps_n < ecut] = kpt.weight

                # eps_skn is relative to efermi, so compare against the
                # relative cutoff:
                if (eps_skn[:, :, -1] < ecut - efermi).any():
                    # We don't have enough bands!
                    weight_un[:] = np.inf

        return weight_un

    def iterate(self, ham, wfs):
        """Solves eigenvalue problem iteratively

        This method is inherited by the actual eigensolver which should
        implement *iterate_one_k_point* method for a single iteration of
        a single kpoint.  The error summed over all k-points and band
        ranks is stored in ``self.error``.
        """

        if not self.initialized:
            if isinstance(ham.xc, HybridXC):
                self.blocksize = wfs.bd.mynbands
            self.initialize(wfs)

        weight_un = self.weights(wfs)

        error = 0.0
        for kpt, weights in zip(wfs.kpt_u, weight_un):
            if not wfs.orthonormalized:
                wfs.orthonormalize(kpt)
            e = self.iterate_one_k_point(ham, wfs, kpt, weights)
            error += e
            if self.orthonormalization_required:
                wfs.orthonormalize(kpt)

        wfs.orthonormalized = True
        self.error = self.band_comm.sum(self.kpt_comm.sum(error))

    def iterate_one_k_point(self, ham, wfs, kpt, weights):
        """Implemented in subclasses.

        Called by :meth:`iterate` with the Hamiltonian, the wave
        functions object, one k-point and that k-point's convergence
        weights; must return the error for that k-point.
        """
        raise NotImplementedError

    def calculate_residuals(self, kpt, wfs, ham, psit, P, eps_n,
                            R, C, n_x=None, calculate_change=False):
        """Calculate residual.

        From R=Ht*psit calculate R=H*psit-eps*S*psit.

        *R* is updated in place.  *C* is used as work space for the
        projector coefficients (dH*P - eps*dS*P).  *n_x* and
        *calculate_change* are passed on to ``ham.xc.add_correction``.
        """

        # Subtract eps * psit from the pseudo part:
        for R_G, eps, psit_G in zip(R.array, eps_n, psit.array):
            axpy(-eps, psit_G, R_G)

        # Atomic corrections: C = dH P - eps dS P
        ham.dH(P, out=C)
        for a, I1, I2 in P.indices:
            dS_ii = ham.setups[a].dO_ii
            C.array[..., I1:I2] -= np.dot((P.array[..., I1:I2].T * eps_n).T,
                                          dS_ii)

        ham.xc.add_correction(kpt, psit.array, R.array,
                              {a: P_ni for a, P_ni in P.items()},
                              {a: C_ni for a, C_ni in C.items()},
                              n_x,
                              calculate_change)
        wfs.pt.add(R.array, {a: C_ni for a, C_ni in C.items()}, kpt.q)

    @timer('Subspace diag')
    def subspace_diagonalize(self, ham, wfs, kpt):
        """Diagonalize the Hamiltonian in the subspace of kpt.psit_nG

        The pseudo Hamiltonian (``wfs.apply_pseudo_hamiltonian``) is
        applied to the wave functions and the result is stored in
        ``wfs.work_array``.  The *H_nn* matrix is then built (including
        the atomic ``ham.dH`` correction), diagonalized, and the wave
        functions, projections and orbital-dependent XC data are
        rotated to the eigenvectors.  If ``self.keep_htpsit`` is True,
        the rotated H*psit is written into ``self.Htpsit_nG``.

        It is assumed that the wave functions *psit_nG* are orthonormal
        and that the integrals of projector functions and wave functions
        *P_ani* are already calculated.

        Nothing is returned: ``kpt.psit``, ``kpt.projections`` and
        ``kpt.eps_n`` are updated in place.

        Raises NotImplementedError for strided band distributions with
        more than one band rank.
        """

        if self.band_comm.size > 1 and wfs.bd.strided:
            raise NotImplementedError

        psit = kpt.psit
        tmp = psit.new(buf=wfs.work_array)
        H = wfs.work_matrix_nn
        P2 = kpt.projections.new()

        Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

        with self.timer('calc_h_matrix'):
            # We calculate the complex conjugate of H, because
            # that is what is most efficient with BLAS given the layout of
            # our matrices.
            psit.matrix_elements(operator=Ht, result=tmp, out=H,
                                 symmetric=True, cc=True)
            ham.dH(kpt.projections, out=P2)
            mmm(1.0, kpt.projections, 'N', P2, 'C', 1.0, H, symmetric=True)
            ham.xc.correct_hamiltonian_matrix(kpt, H.array)

        with wfs.timer('diagonalize'):
            slcomm, r, c, b = wfs.scalapack_parameters
            if r == c == 1:
                slcomm = None
            # Complex conjugate before diagonalizing:
            eps_n = H.eigh(cc=True, scalapack=(slcomm, r, c, b))
            # H.array[n, :] now contains the n'th eigenvector and eps_n[n]
            # the n'th eigenvalue
            kpt.eps_n = eps_n[wfs.bd.get_slice()]

        with self.timer('rotate_psi'):
            if self.keep_htpsit:
                Htpsit = psit.new(buf=self.Htpsit_nG)
                mmm(1.0, H, 'N', tmp, 'N', 0.0, Htpsit)
            mmm(1.0, H, 'N', psit, 'N', 0.0, tmp)
            psit[:] = tmp
            mmm(1.0, H, 'N', kpt.projections, 'N', 0.0, P2)
            kpt.projections.matrix = P2.matrix
            # Rotate orbital dependent XC stuff:
            ham.xc.rotate(kpt, H.array.T)

    def estimate_memory(self, mem, wfs):
        """Add this eigensolver's memory estimate to the *mem* tree."""
        gridmem = wfs.bytes_per_wave_function()

        keep_htpsit = self.keep_htpsit and (wfs.bd.mynbands == wfs.bd.nbands)

        if keep_htpsit:
            mem.subnode('Htpsit', wfs.bd.nbands * gridmem)
        else:
            mem.subnode('No Htpsit', 0)

        mem.subnode('eps_n', wfs.bd.mynbands * mem.floatsize)
        mem.subnode('eps_N', wfs.bd.nbands * mem.floatsize)
        mem.subnode('Preconditioner', 4 * gridmem)
        mem.subnode('Work', gridmem)