from pathlib import Path

import numpy as np
from scipy.spatial import cKDTree

from gpaw.utilities import convert_string_to_fd
from ase.utils.timing import Timer, timer

from gpaw import GPAW, disable_dry_run
import gpaw.mpi as mpi
from gpaw.response.math_func import two_phi_planewave_integrals


class KohnShamKPoint:
    """Kohn-Sham orbital information for a given k-point."""
    def __init__(self, n_t, s_t, K,
                 eps_t, f_t, ut_tR, projections, shift_c):
        """K-point data is indexed by a joint spin and band index h, which is
        directly related to the transition index t."""
        self.K = K  # BZ k-point index
        self.n_t = n_t  # Band index
        self.s_t = s_t  # Spin index
        self.eps_t = eps_t  # Eigenvalues
        self.f_t = f_t  # Occupation numbers
        self.ut_tR = ut_tR  # Periodic part of wave functions
        self.projections = projections  # PAW projections

        self.shift_c = shift_c  # long story - see the
        # KohnShamPair.construct_symmetry_operators() method


class KohnShamKPointPair:
    """Object containing all transitions between Kohn-Sham orbitals from a
    specified k-point to another."""

    def __init__(self, kpt1, kpt2, mynt, nt, ta, tb, comm=None):
        self.kpt1 = kpt1
        self.kpt2 = kpt2

        self.mynt = mynt  # Number of transitions in this block
        self.nt = nt  # Total number of transitions between all blocks
        self.ta = ta  # First transition index of this block
        self.tb = tb  # First transition index of this block not included
        self.comm = comm  # MPI communicator between blocks of transitions

    def transition_distribution(self):
        """Get the distribution of transitions."""
        return self.mynt, self.nt, self.ta, self.tb

    def get_transitions(self):
        """Get the band and spin indices (n1_t, n2_t, s1_t, s2_t) of all
        transitions."""
        return self.n1_t, self.n2_t, self.s1_t, self.s2_t

    def get_all(self, A_mytx):
        """Get a certain data array with all transitions.

        Without a block communicator (or without data) the input is returned
        unchanged. Otherwise the blocks of all ranks are gathered and the
        padding beyond the nt actual transitions is cut off."""
        if self.comm is None or A_mytx is None:
            return A_mytx

        A_tx = np.empty((self.mynt * self.comm.size,) + A_mytx.shape[1:],
                        dtype=A_mytx.dtype)

        self.comm.all_gather(A_mytx, A_tx)

        return A_tx[:self.nt]

    @property
    def n1_t(self):
        return self.get_all(self.kpt1.n_t)

    @property
    def n2_t(self):
        return self.get_all(self.kpt2.n_t)

    @property
    def s1_t(self):
        return self.get_all(self.kpt1.s_t)

    @property
    def s2_t(self):
        return self.get_all(self.kpt2.s_t)

    @property
    def deps_t(self):
        return self.get_all(self.kpt2.eps_t) - self.get_all(self.kpt1.eps_t)

    @property
    def df_t(self):
        return self.get_all(self.kpt2.f_t) - self.get_all(self.kpt1.f_t)

    @classmethod
    def add_mytransitions_array(cls, _key, key):
        """Add a A_tx data array class attribute.
        Handles the fact, that the transitions are distributed in blocks.

        Parameters
        ----------
        _key : str
            attribute name for the A_mytx data array
        key : str
            attribute name for the A_tx data array
        """
        # In general, the data array has not been specified to instances of
        # the class. As a result, set the _key to None
        setattr(cls, _key, None)
        # self.key should return full data array. Use getattr() rather than
        # the instance __dict__, so that instances without attached data fall
        # back to the class level None instead of raising a KeyError.
        setattr(cls, key,
                property(lambda self: self.get_all(getattr(self, _key))))

    def attach(self, _key, key, A_mytx):
        """Attach a data array to the k-point pair.
        Used by PairMatrixElement to attach matrix elements calculated
        between the k-points for the different transitions."""
        self.add_mytransitions_array(_key, key)
        setattr(self, _key, A_mytx)


class KohnShamPair:
    """Class for extracting pairs of Kohn-Sham orbitals from a ground
    state calculation."""

    def __init__(self, gs, world=mpi.world, transitionblockscomm=None,
                 kptblockcomm=None, txt='-', timer=None):
        """
        Parameters
        ----------
        gs : GPAW instance or path
            Ground state calculation, handed on to get_calc()
        world : gpaw.mpi.Communicator
            Communicator used for exchanging k-point data
        transitionblockscomm : gpaw.mpi.Communicator
            Communicator for distributing the transitions among processes
        kptblockcomm : gpaw.mpi.Communicator
            Communicator for distributing k-points among processes
        txt : str
            Output specification, handed on to convert_string_to_fd()
        timer : Timer
            Timer to use. A new Timer is created, if none is given.
        """
        self.world = world
        self.fd = convert_string_to_fd(txt, world)
        self.timer = timer or Timer()
        self.calc = get_calc(gs, fd=self.fd, timer=self.timer)
        self.calc_parallel = self.check_calc_parallelisation()

        self.transitionblockscomm = transitionblockscomm
        self.kptblockcomm = kptblockcomm

        # Prepare to distribute transitions
        self.mynt = None
        self.nt = None
        self.ta = None
        self.tb = None

        # Prepare to find k-point data from vector
        kd = self.calc.wfs.kd
        self.kdtree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1))

        # Prepare to use other processes' k-points
        self._pd0 = None

        # Prepare to redistribute kptdata
        self.rrequests = []
        self.srequests = []

        # Count bands so it is possible to remove null transitions
        self.nocc1 = None  # number of completely filled bands
        self.nocc2 = None  # number of non-empty bands
        self.count_occupied_bands()

    def check_calc_parallelisation(self):
        """Check how ground state calculation is distributed in memory"""
        if self.calc.world.size == 1:
            return False
        else:
            assert self.world.rank == self.calc.wfs.world.rank
            assert self.calc.wfs.gd.comm.size == 1
            return True

    def count_occupied_bands(self):
        """Count number of occupied and unoccupied bands in ground state
        calculation. Can be used to omit null-transitions between two occupied
        bands or between two unoccupied bands."""
        ftol = 1.e-9  # Could be given as input
        nocc1 = 9999999
        nocc2 = 0
        for kpt in self.calc.wfs.kpt_u:
            f_n = kpt.f_n / kpt.weight
            nocc1 = min((f_n > 1 - ftol).sum(), nocc1)
            nocc2 = max((f_n > ftol).sum(), nocc2)
        nocc1 = int(nocc1)
        nocc2 = int(nocc2)

        # Collect nocc for all k-points
        nocc1 = self.calc.wfs.kd.comm.min(nocc1)
        nocc2 = self.calc.wfs.kd.comm.max(nocc2)

        # Sum over band distribution
        nocc1 = self.calc.wfs.bd.comm.sum(nocc1)
        nocc2 = self.calc.wfs.bd.comm.sum(nocc2)

        self.nocc1 = int(nocc1)
        self.nocc2 = int(nocc2)
        print('Number of completely filled bands:', self.nocc1, file=self.fd)
        print('Number of partially filled bands:', self.nocc2, file=self.fd)
        print('Total number of bands:', self.calc.wfs.bd.nbands,
              file=self.fd)

    @property
    def pd0(self):
        """Get a PWDescriptor that includes all k-points"""
        if self._pd0 is None:
            from gpaw.wavefunctions.pw import PWDescriptor
            wfs = self.calc.wfs
            assert wfs.gd.comm.size == 1

            kd0 = wfs.kd.copy()
            pd, gd = wfs.pd, wfs.gd

            # Extract stuff from self.calc.wfs.pd
            ecut, dtype = pd.ecut, pd.dtype
            fftwflags, gammacentered = pd.fftwflags, pd.gammacentered

            # Initiate _pd0 with kd0
            self._pd0 = PWDescriptor(ecut, gd, dtype=dtype,
                                     kd=kd0, fftwflags=fftwflags,
                                     gammacentered=gammacentered)
        return self._pd0

    @timer('Get Kohn-Sham pairs')
    def get_kpoint_pairs(self, n1_t, n2_t, k1_pc, k2_pc, s1_t, s2_t):
        """Get all pairs of Kohn-Sham orbitals for transitions:
        (n1_t, k1_p, s1_t) -> (n2_t, k2_p, s2_t)
        Here, t is a composite band and spin transition index
        and p is indexing the different k-points to be distributed.

        Returns a KohnShamKPointPair, or None if the process has no k-point
        pair to evaluate."""

        # Distribute transitions and extract data for transitions in
        # this process' block
        nt = len(n1_t)
        assert nt == len(n2_t)
        self.distribute_transitions(nt)

        kpt1 = self.get_kpoints(k1_pc, n1_t, s1_t)
        kpt2 = self.get_kpoints(k2_pc, n2_t, s2_t)

        # The process might not have any k-point pairs to evaluate, as
        # they are distributed in the kptblockcomm
        if kpt1 is None:
            assert kpt2 is None
            return None
        assert kpt2 is not None

        return KohnShamKPointPair(kpt1, kpt2,
                                  self.mynt, nt, self.ta, self.tb,
                                  comm=self.transitionblockscomm)

    def distribute_transitions(self, nt):
        """Distribute transitions between processes in block communicator"""
        if self.transitionblockscomm is None:
            mynt = nt
            ta = 0
            tb = nt
        else:
            nblocks = self.transitionblockscomm.size
            rank = self.transitionblockscomm.rank

            # Equally sized blocks (rounded up); the last blocks may be
            # partially filled or empty
            mynt = (nt + nblocks - 1) // nblocks
            ta = min(rank * mynt, nt)
            tb = min(ta + mynt, nt)

        self.mynt = mynt
        self.nt = nt
        self.ta = ta
        self.tb = tb

    def get_kpoints(self, k_pc, n_t, s_t):
        """Get KohnShamKPoint and help other processes extract theirs"""
        assert len(n_t) == len(s_t)
        assert len(k_pc) <= self.kptblockcomm.size
        kpt = None

        # Use the data extraction factory to extract the kptdata
        _extract_kptdata = self.create_extract_kptdata()
        kptdata = _extract_kptdata(k_pc, n_t, s_t)

        # Make local n and s arrays for the KohnShamKPoint object.
        # Only the first tb - ta entries are filled, the rest is padding.
        n_myt = np.empty(self.mynt, dtype=n_t.dtype)
        n_myt[:self.tb - self.ta] = n_t[self.ta:self.tb]
        s_myt = np.empty(self.mynt, dtype=s_t.dtype)
        s_myt[:self.tb - self.ta] = s_t[self.ta:self.tb]

        # Initiate k-point object.
        if self.kptblockcomm.rank in range(len(k_pc)):
            assert kptdata is not None
            kpt = KohnShamKPoint(n_myt, s_myt, *kptdata)

        return kpt

    def create_extract_kptdata(self):
        """Creator component of the data extraction factory."""
        # For debugging, it can be useful to always return
        # self.parallel_extract_kptdata
        if self.calc_parallel:
            return self.parallel_extract_kptdata
        else:
            return self.serial_extract_kptdata

    def parallel_extract_kptdata(self, k_pc, n_t, s_t):
        """Returns the input to KohnShamKPoint:
        K, eps_myt, f_myt, ut_mytR, projections, shift_c
        if a k-point in the given list, k_pc, belongs to the process.
        """
        # Extract the data from the ground state calculator object
        data, h_myt, myt_myt = self._parallel_extract_kptdata(k_pc, n_t, s_t)

        # If the process has a k-point to return, symmetrize and unfold
        if self.kptblockcomm.rank in range(len(k_pc)):
            assert data is not None
            # Unpack data, apply FT and symmetrization
            K, k_c, eps_h, f_h, Ph, psit_hG = data
            Ph, ut_hR, shift_c = self.transform_and_symmetrize(K, k_c, Ph,
                                                               psit_hG)

            (eps_myt, f_myt,
             P, ut_mytR) = self.unfold_arrays(eps_h, f_h, Ph, ut_hR,
                                              h_myt, myt_myt)

            data = (K, eps_myt, f_myt, ut_mytR, P, shift_c)

        # Wait for communication to finish
        with self.timer('Waiting to complete mpi.send'):
            while self.srequests:
                self.world.wait(self.srequests.pop(0))

        return data

    @timer('Extracting data from the ground state calculator object')
    def _parallel_extract_kptdata(self, k_pc, n_t, s_t):
        """In-place kptdata extraction."""
        (data, myu_eu,
         myn_eueh, ik_r2,
         nrh_r2, eh_eur2reh,
         rh_eur2reh, h_r1rh,
         h_myt, myt_myt) = self.get_extraction_protocol(k_pc, n_t, s_t)

        (eps_r1rh, f_r1rh,
         P_r1rhI, psit_r1rhG,
         eps_r2rh, f_r2rh,
         P_r2rhI, psit_r2rhG) = self.allocate_transfer_arrays(data, nrh_r2,
                                                              ik_r2, h_r1rh)

        # Do actual extraction
        for myu, myn_eh, eh_r2reh, rh_r2reh in zip(myu_eu, myn_eueh,
                                                   eh_eur2reh, rh_eur2reh):

            eps_eh, f_eh, P_ehI = self.extract_wfs_data(myu, myn_eh)

            for r2, (eh_reh, rh_reh) in enumerate(zip(eh_r2reh, rh_r2reh)):
                if eh_reh:
                    eps_r2rh[r2][rh_reh] = eps_eh[eh_reh]
                    f_r2rh[r2][rh_reh] = f_eh[eh_reh]
                    P_r2rhI[r2][rh_reh] = P_ehI[eh_reh]

            # Wavefunctions are heavy objects which can only be extracted
            # for one band index at a time, handle them separately
            self.add_wave_function(myu, myn_eh, eh_r2reh,
                                   rh_r2reh, psit_r2rhG)

        self.distribute_extracted_data(eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                                       eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)

        data = self.collect_kptdata(data, h_r1rh, eps_r1rh,
                                    f_r1rh, P_r1rhI, psit_r1rhG)

        return data, h_myt, myt_myt

    @timer('Create data extraction protocol')
    def get_extraction_protocol(self, k_pc, n_t, s_t):
        """Figure out how to extract data efficiently.
        For the serial communicator, all processes can access all data,
        and resultantly, there is no need to send any data.

        Index notation used below: r1 is the world rank extracting data,
        r2 the world rank receiving it, eh enumerates extracted (n, s)
        indices, rh the indices in a transfer array and h the composite
        (n, s) index of the receiving process.
        """
        wfs = self.calc.wfs
        get_extraction_info = self.create_get_extraction_info()

        # Kpoint data
        data = (None, None, None)

        # Extraction protocol
        myu_eu = []
        myn_eueh = []

        # Data distribution protocol
        nrh_r2 = np.zeros(self.world.size, dtype=int)
        ik_r2 = [None for _ in range(self.world.size)]
        eh_eur2reh = []
        rh_eur2reh = []
        h_r1rh = [[] for _ in range(self.world.size)]

        # h to t index mapping
        myt_myt = np.arange(self.tb - self.ta)
        t_myt = range(self.ta, self.tb)
        n_myt, s_myt = n_t[t_myt], s_t[t_myt]
        h_myt = np.empty(self.tb - self.ta, dtype=int)

        nt = len(n_t)
        assert nt == len(s_t)
        t_t = np.arange(nt)
        nh = 0
        for p, k_c in enumerate(k_pc):  # p indicates the receiving process
            K = self.find_kpoint(k_c)
            ik = wfs.kd.bz2ibz_k[K]
            for r2 in range(p * self.transitionblockscomm.size,
                            min((p + 1) * self.transitionblockscomm.size,
                                self.world.size)):
                ik_r2[r2] = ik

            if p == self.kptblockcomm.rank:
                data = (K, k_c, ik)

            # Find out who should store the data in KSKPpoint
            r2_t, _ = self.map_who_has(p, t_t)

            # Find out how to extract data
            # In the ground state, kpts are indexed by u=(s, k)
            for s in set(s_t):
                thiss_myt = s_myt == s
                thiss_t = s_t == s
                t_ct = t_t[thiss_t]
                n_ct = n_t[thiss_t]
                r2_ct = r2_t[t_ct]

                # Find out where data is in wfs
                u = ik * wfs.nspins + s
                myu, r1_ct, myn_ct = get_extraction_info(u, n_ct, r2_ct)

                # If the process is extracting or receiving data,
                # figure out how to do so
                if self.world.rank in np.append(r1_ct, r2_ct):
                    # Does this process have anything to send?
                    thisr1_ct = r1_ct == self.world.rank
                    if np.any(thisr1_ct):
                        eh_r2reh = [[] for _ in range(self.world.size)]
                        rh_r2reh = [[] for _ in range(self.world.size)]
                        # Find composite indices h = (n, s)
                        n_et = n_ct[thisr1_ct]
                        n_eh = np.unique(n_et)
                        # Find composite local band indices
                        myn_eh = np.unique(myn_ct[thisr1_ct])

                        # Where to send the data
                        r2_et = r2_ct[thisr1_ct]
                        for r2 in np.unique(r2_et):
                            thisr2_et = r2_et == r2
                            # What ns are the process sending?
                            n_reh = np.unique(n_et[thisr2_et])
                            # Keep this a list, its truth value is used to
                            # detect whether there is anything to send
                            eh_reh = [np.where(n_eh == n)[0][0]
                                      for n in n_reh]
                            # How to send it
                            eh_r2reh[r2] = eh_reh
                            nreh = len(eh_reh)
                            rh_r2reh[r2] = np.arange(nreh) + nrh_r2[r2]
                            nrh_r2[r2] += nreh

                        myu_eu.append(myu)
                        myn_eueh.append(myn_eh)
                        eh_eur2reh.append(eh_r2reh)
                        rh_eur2reh.append(rh_r2reh)

                    # Does this process have anything to receive?
                    thisr2_ct = r2_ct == self.world.rank
                    if np.any(thisr2_ct):
                        # Find unique composite indices h = (n, s)
                        n_rt = n_ct[thisr2_ct]
                        n_rn = np.unique(n_rt)
                        nrn = len(n_rn)
                        h_rn = np.arange(nrn) + nh
                        nh += nrn

                        # Where to get the data from
                        r1_rt = r1_ct[thisr2_ct]
                        for r1 in np.unique(r1_rt):
                            thisr1_rt = r1_rt == r1
                            # What ns are the process getting?
                            n_reh = np.unique(n_rt[thisr1_rt])
                            # Where to put them
                            for n in n_reh:
                                h = h_rn[np.where(n_rn == n)[0][0]]
                                h_r1rh[r1].append(h)

                                # h to t mapping
                                thisn_myt = n_myt == n
                                thish_myt = np.logical_and(thisn_myt,
                                                           thiss_myt)
                                h_myt[thish_myt] = h

        return (data, myu_eu, myn_eueh, ik_r2, nrh_r2,
                eh_eur2reh, rh_eur2reh, h_r1rh, h_myt, myt_myt)

    def create_get_extraction_info(self):
        """Creator component of the extraction information factory."""
        if self.calc_parallel:
            return self.get_parallel_extraction_info
        else:
            return self.get_serial_extraction_info

    @staticmethod
    def get_serial_extraction_info(u, n_ct, r2_ct):
        """Figure out where to extract the data from in the gs calc"""
        # Let the process extract its own data
        myu = u  # The process has access to all data
        r1_ct = r2_ct
        myn_ct = n_ct

        return myu, r1_ct, myn_ct

    def get_parallel_extraction_info(self, u, n_ct, *unused):
        """Figure out where to extract the data from in the gs calc"""
        wfs = self.calc.wfs
        # Find out where data is in wfs
        k, s = divmod(u, wfs.nspins)
        kptrank, q = wfs.kd.who_has(k)
        myu = q * wfs.nspins + s
        r1_ct, myn_ct = [], []
        for n in n_ct:
            bandrank, myn = wfs.bd.who_has(n)
            # XXX this will fail when using non-standard nesting
            # of communicators.
            r1 = (kptrank * wfs.gd.comm.size * wfs.bd.comm.size
                  + bandrank * wfs.gd.comm.size)
            r1_ct.append(r1)
            myn_ct.append(myn)

        return myu, np.array(r1_ct), np.array(myn_ct)

    @timer('Allocate transfer arrays')
    def allocate_transfer_arrays(self, data, nrh_r2, ik_r2, h_r1rh):
        """Allocate arrays for intermediate storage of data."""
        wfs = self.calc.wfs
        kptex = wfs.kpt_u[0]
        Pshape = kptex.projections.array.shape
        Pdtype = kptex.projections.matrix.dtype
        psitdtype = kptex.psit.array.dtype

        # Number of h-indices to receive
        nrh_r1 = [len(h_rh) for h_rh in h_r1rh]

        # Only processes with a k-point (data[2] is its IBZ index) need
        # arrays to receive data in
        if data[2] is not None:
            ik = data[2]
            ng = self.pd0.ng_q[ik]
            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = [], [], [], []
            for nrh in nrh_r1:
                if nrh >= 1:
                    eps_r1rh.append(np.empty(nrh))
                    f_r1rh.append(np.empty(nrh))
                    P_r1rhI.append(np.empty((nrh,) + Pshape[1:],
                                            dtype=Pdtype))
                    psit_r1rhG.append(np.empty((nrh, ng), dtype=psitdtype))
                else:
                    eps_r1rh.append(None)
                    f_r1rh.append(None)
                    P_r1rhI.append(None)
                    psit_r1rhG.append(None)
        else:
            eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG = None, None, None, None

        eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG = [], [], [], []
        for nrh, ik in zip(nrh_r2, ik_r2):
            if nrh:
                eps_r2rh.append(np.empty(nrh))
                f_r2rh.append(np.empty(nrh))
                P_r2rhI.append(np.empty((nrh,) + Pshape[1:], dtype=Pdtype))
                ng = self.pd0.ng_q[ik]
                psit_r2rhG.append(np.empty((nrh, ng), dtype=psitdtype))
            else:
                eps_r2rh.append(None)
                f_r2rh.append(None)
                P_r2rhI.append(None)
                psit_r2rhG.append(None)

        return (eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG)

    def map_who_has(self, p, t_t):
        """Convert k-point and transition index to global world rank
        and local transition index"""
        trank_t, myt_t = np.divmod(t_t, self.mynt)
        return p * self.transitionblockscomm.size + trank_t, myt_t

    @timer('Extracting eps, f and P_I from wfs')
    def extract_wfs_data(self, myu, myn_eh):
        """Get eigenvalues, occupation numbers and PAW projections for the
        local bands myn_eh of the local k-point myu."""
        wfs = self.calc.wfs
        kpt = wfs.kpt_u[myu]
        # Get eig and occ
        eps_eh, f_eh = kpt.eps_n[myn_eh], kpt.f_n[myn_eh] / kpt.weight

        # Get projections
        assert kpt.projections.atom_partition.comm.size == 1
        P_ehI = kpt.projections.array[myn_eh]

        return eps_eh, f_eh, P_ehI

    @timer('Extracting wave function from wfs')
    def add_wave_function(self, myu, myn_eh,
                          eh_r2reh, rh_r2reh, psit_r2rhG):
        """Add the plane wave coefficients of the smooth part of
        the wave function to the psit_r2rhG arrays."""
        wfs = self.calc.wfs
        kpt = wfs.kpt_u[myu]

        for eh_reh, rh_reh, psit_rhG in zip(eh_r2reh, rh_r2reh, psit_r2rhG):
            if eh_reh:
                for eh, rh in zip(eh_reh, rh_reh):
                    psit_rhG[rh] = kpt.psit_nG[myn_eh[eh]]

    @timer('Distributing kptdata')
    def distribute_extracted_data(self, eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG,
                                  eps_r2rh, f_r2rh, P_r2rhI, psit_r2rhG):
        """Send the extracted data to appropriate destinations"""
        # Store the data extracted by the process itself
        rank = self.world.rank
        # Check if there is actually some data to store
        if eps_r2rh[rank] is not None:
            eps_r1rh[rank] = eps_r2rh[rank]
            f_r1rh[rank] = f_r2rh[rank]
            P_r1rhI[rank] = P_r2rhI[rank]
            psit_r1rhG[rank] = psit_r2rhG[rank]

        # Receive data
        if eps_r1rh is not None:  # The process may not be receiving anything
            for r1, (eps_rh, f_rh,
                     P_rhI, psit_rhG) in enumerate(zip(eps_r1rh, f_r1rh,
                                                       P_r1rhI, psit_r1rhG)):
                # Check if there is any data to receive
                if r1 != rank and eps_rh is not None:
                    rreq1 = self.world.receive(eps_rh, r1,
                                               tag=201, block=False)
                    rreq2 = self.world.receive(f_rh, r1,
                                               tag=202, block=False)
                    rreq3 = self.world.receive(P_rhI, r1,
                                               tag=203, block=False)
                    rreq4 = self.world.receive(psit_rhG, r1,
                                               tag=204, block=False)
                    self.rrequests += [rreq1, rreq2, rreq3, rreq4]

        # Send data
        for r2, (eps_rh, f_rh,
                 P_rhI, psit_rhG) in enumerate(zip(eps_r2rh, f_r2rh,
                                                   P_r2rhI, psit_r2rhG)):
            # Check if there is any data to send
            if r2 != rank and eps_rh is not None:
                sreq1 = self.world.send(eps_rh, r2, tag=201, block=False)
                sreq2 = self.world.send(f_rh, r2, tag=202, block=False)
                sreq3 = self.world.send(P_rhI, r2, tag=203, block=False)
                sreq4 = self.world.send(psit_rhG, r2, tag=204, block=False)
                self.srequests += [sreq1, sreq2, sreq3, sreq4]

        # The send requests are completed by parallel_extract_kptdata()
        with self.timer('Waiting to complete mpi.receive'):
            while self.rrequests:
                self.world.wait(self.rrequests.pop(0))

    @timer('Collecting kptdata')
    def collect_kptdata(self, data, h_r1rh,
                        eps_r1rh, f_r1rh, P_r1rhI, psit_r1rhG):
        """From the extracted data, collect the KohnShamKPoint data arrays"""

        # Some processes may not have to return a k-point
        if data[0] is None:
            return None
        K, k_c, ik = data

        # Allocate data arrays
        wfs = self.calc.wfs
        maxh_r1 = [max(h_rh) for h_rh in h_r1rh if h_rh]
        if maxh_r1:
            nh = max(maxh_r1) + 1
        else:  # Carry around empty array
            assert self.ta == self.tb
            nh = 1
        eps_h = np.empty(nh)
        f_h = np.empty(nh)
        Ph = wfs.kpt_u[0].projections.new(nbands=nh, bcomm=None)
        psit_hG = np.empty((nh, self.pd0.ng_q[ik]),
                           dtype=wfs.kpt_u[0].psit.array.dtype)

        # Store extracted data in the arrays
        for (h_rh, eps_rh,
             f_rh, P_rhI, psit_rhG) in zip(h_r1rh, eps_r1rh,
                                           f_r1rh, P_r1rhI, psit_r1rhG):
            if h_rh:
                eps_h[h_rh] = eps_rh
                f_h[h_rh] = f_rh
                Ph.array[h_rh] = P_rhI
                psit_hG[h_rh] = psit_rhG

        return (K, k_c, eps_h, f_h, Ph, psit_hG)

    @timer('Unfolding arrays')
    def unfold_arrays(self, eps_h, f_h, Ph, ut_hR, h_myt, myt_myt):
        """Create transition data arrays from the composite h = (n, s) index"""

        wfs = self.calc.wfs
        # Allocate data arrays for the k-point
        mynt = self.mynt
        eps_myt = np.empty(mynt)
        f_myt = np.empty(mynt)
        P = wfs.kpt_u[0].projections.new(nbands=mynt, bcomm=None)
        ut_mytR = wfs.gd.empty(self.mynt, wfs.dtype)

        # Unfold k-point data
        eps_myt[myt_myt] = eps_h[h_myt]
        f_myt[myt_myt] = f_h[h_myt]
        P.array[myt_myt] = Ph.array[h_myt]
        ut_mytR[myt_myt] = ut_hR[h_myt]

        return eps_myt, f_myt, P, ut_mytR

    @timer('Extracting data from the ground state calculator object')
    def serial_extract_kptdata(self, k_pc, n_t, s_t):
        """Returns the input to KohnShamKPoint:
        K, eps_myt, f_myt, ut_mytR, projections, shift_c
        if a k-point in the given list, k_pc, belongs to the process
        and None otherwise.
        """
        # All processes can access all data. Each process extracts it own data.
        wfs = self.calc.wfs

        # Only processes, which have data to extract, do data extraction
        if self.kptblockcomm.rank not in range(len(k_pc)):
            return None

        # Find k-point indices
        k_c = k_pc[self.kptblockcomm.rank]
        K = self.find_kpoint(k_c)
        ik = wfs.kd.bz2ibz_k[K]
        # Construct symmetry operators
        (_, T, a_a, U_aii, shift_c,
         time_reversal) = self.construct_symmetry_operators(K, k_c=k_c)

        (myu_eu, myn_eurn, nh, h_eurn, h_myt,
         myt_myt) = self.get_serial_extraction_protocol(ik, n_t, s_t)

        # Allocate transfer arrays
        eps_h = np.empty(nh)
        f_h = np.empty(nh)
        Ph = wfs.kpt_u[0].projections.new(nbands=nh, bcomm=None)
        ut_hR = wfs.gd.empty(nh, wfs.dtype)

        # Extract data from the ground state
        for myu, myn_rn, h_rn in zip(myu_eu, myn_eurn, h_eurn):
            kpt = wfs.kpt_u[myu]
            with self.timer('Extracting eps, f and P_I from wfs'):
                eps_h[h_rn] = kpt.eps_n[myn_rn]
                f_h[h_rn] = kpt.f_n[myn_rn] / kpt.weight
                Ph.array[h_rn] = kpt.projections.array[myn_rn]

            with self.timer('Extracting, fourier transforming and '
                            'symmetrizing wave function'):
                for myn, h in zip(myn_rn, h_rn):
                    ut_hR[h] = T(wfs.pd.ifft(kpt.psit_nG[myn], kpt.q))

        # Symmetrize projections
        with self.timer('Apply symmetry operations'):
            self._symmetrize_projections(Ph, a_a, U_aii, time_reversal)

        (eps_myt, f_myt,
         P, ut_mytR) = self.unfold_arrays(eps_h, f_h, Ph, ut_hR,
                                          h_myt, myt_myt)

        return (K, eps_myt, f_myt, ut_mytR, P, shift_c)

    @timer('Create data extraction protocol')
    def get_serial_extraction_protocol(self, ik, n_t, s_t):
        """Figure out how to extract data efficiently.
        For the serial communicator, all processes can access all data,
        and resultantly, there is no need to send any data.
        """
        wfs = self.calc.wfs

        # Only extract the transitions handled by the process itself
        myt_myt = np.arange(self.tb - self.ta)
        t_myt = range(self.ta, self.tb)
        n_myt = n_t[t_myt]
        s_myt = s_t[t_myt]

        # In the ground state, kpts are indexed by u=(s, k)
        myu_eu = []
        myn_eurn = []
        nh = 0
        h_eurn = []
        h_myt = np.empty(self.tb - self.ta, dtype=int)
        for s in set(s_myt):
            thiss_myt = s_myt == s
            n_ct = n_myt[thiss_myt]

            # Find unique composite h = (n, u) indices
            n_rn = np.unique(n_ct)
            nrn = len(n_rn)
            h_eurn.append(np.arange(nrn) + nh)
            nh += nrn

            # Find mapping between h and the transition index
            for n, h in zip(n_rn, h_eurn[-1]):
                thisn_myt = n_myt == n
                thish_myt = np.logical_and(thisn_myt, thiss_myt)
                h_myt[thish_myt] = h

            # Find out where data is in wfs
            u = ik * wfs.nspins + s
            # The process has access to all data
            myu = u
            myn_rn = n_rn

            myu_eu.append(myu)
            myn_eurn.append(myn_rn)

        return myu_eu, myn_eurn, nh, h_eurn, h_myt, myt_myt

    @timer('Identifying k-points')
    def find_kpoint(self, k_c):
        """Get the index K of the BZ k-point closest to k_c, after folding
        k_c into the unit cell of the reciprocal lattice."""
        return self.kdtree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1]

    @staticmethod
    def _symmetrize_projections(Ph, a_a, U_aii, time_reversal):
        """Apply the symmetry operations U_aii to the PAW projections Ph.

        The symmetrized projections are stored in Ph itself, which is also
        returned."""
        P_ahi = []
        for a1, U_ii in zip(a_a, U_aii):
            P_hi = np.ascontiguousarray(Ph[a1])
            # Apply symmetry operations. This will map a1 onto a2
            np.dot(P_hi, U_ii, out=P_hi)
            if time_reversal:
                np.conj(P_hi, out=P_hi)
            P_ahi.append(P_hi)

        # Store symmetrized projectors
        for a2, P_hi in enumerate(P_ahi):
            I1, I2 = Ph.map[a2]
            Ph.array[..., I1:I2] = P_hi

        return Ph

    @timer('Apply symmetry operations')
    def transform_and_symmetrize(self, K, k_c, Ph, psit_hG):
        """Get wave function on a real space grid and symmetrize it
        along with the corresponding PAW projections."""
        (_, T, a_a, U_aii, shift_c,
         time_reversal) = self.construct_symmetry_operators(K, k_c=k_c)

        # Symmetrize wave functions
        wfs = self.calc.wfs
        ik = wfs.kd.bz2ibz_k[K]
        ut_hR = wfs.gd.empty(len(psit_hG), wfs.dtype)
        with self.timer('Fourier transform and symmetrize wave functions'):
            for h, psit_G in enumerate(psit_hG):
                ut_hR[h] = T(self.pd0.ifft(psit_G, ik))

        # Symmetrize projections
        Ph = self._symmetrize_projections(Ph, a_a, U_aii, time_reversal)

        return Ph, ut_hR, shift_c

    @timer('Construct symmetry operators')
    def construct_symmetry_operators(self, K, k_c=None):
        """Construct symmetry operators for wave function and PAW projections.

        We want to transform a k-point in the irreducible part of the BZ to
        the corresponding k-point with index K.

        Returns U_cc, T, a_a, U_aii, shift_c and time_reversal, where:

        * U_cc is a rotation matrix.
        * T() is a function that transforms the periodic part of the wave
          function.
        * a_a is a list of symmetry related atom indices
        * U_aii is a list of rotation matrices for the PAW projections
        * shift_c is three integers: see code below.
        * time_reversal is a flag - if True, projections should be complex
          conjugated.

        See the transform_and_symmetrize() method for how to use these
        tuples.

        Raises an AssertionError, if shift_c does not consist of integers.
        """

        wfs = self.calc.wfs
        kd = wfs.kd

        s = kd.sym_k[K]
        U_cc = kd.symmetry.op_scc[s]
        time_reversal = kd.time_reversal_k[K]
        ik = kd.bz2ibz_k[K]
        if k_c is None:
            k_c = kd.bzk_kc[K]
        ik_c = kd.ibzk_kc[ik]

        sign = 1 - 2 * time_reversal
        shift_c = np.dot(U_cc, ik_c) - k_c * sign

        # Explicit check instead of an assert statement, so that the
        # validation is not removed when running python with -O
        if not np.allclose(shift_c.round(), shift_c):
            print('shift_c ' + str(shift_c), file=self.fd)
            print('k_c ' + str(k_c), file=self.fd)
            print('kd.bzk_kc[K] ' + str(kd.bzk_kc[K]), file=self.fd)
            print('ik_c ' + str(ik_c), file=self.fd)
            print('U_cc ' + str(U_cc), file=self.fd)
            print('sign ' + str(sign), file=self.fd)
            raise AssertionError('shift_c is not integer valued')

        shift_c = shift_c.round().astype(int)

        if (U_cc == np.eye(3)).all():
            def T(f_R):
                return f_R
        else:
            N_c = self.calc.wfs.gd.N_c
            i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1)))
            i = np.ravel_multi_index(i_cr, N_c, 'wrap')

            def T(f_R):
                return f_R.ravel()[i].reshape(N_c)

        if time_reversal:
            T0 = T

            def T(f_R):
                return T0(f_R).conj()

            shift_c *= -1

        a_a = []
        U_aii = []
        for a in range(len(self.calc.wfs.setups.id_a)):
            b = kd.symmetry.a_sa[s, a]
            S_c = np.dot(self.calc.spos_ac[a], U_cc) - self.calc.spos_ac[b]
            x = np.exp(2j * np.pi * np.dot(ik_c, S_c))
            U_ii = wfs.setups[a].R_sii[s].T * x
            a_a.append(b)
            U_aii.append(U_ii)

        shift0_c = (kd.bzk_kc[K] - k_c).round().astype(int)
        shift_c += -shift0_c

        return U_cc, T, a_a, U_aii, shift_c, time_reversal


def get_calc(gs, fd=None, timer=None):
    """Get ground state calculation object.

    Parameters
    ----------
    gs : GPAW instance or path
        A GPAW instance is returned as it is. Otherwise, gs is taken as the
        path of a file from which the calculation is read with a serial
        communicator.
    fd : file-like or None
        If given, a message about the file being read is printed to it.
    timer : Timer or None
        Timer used to time the reading. A new Timer is created, if none is
        given.
    """
    if isinstance(gs, GPAW):
        return gs
    else:
        if timer is None:
            # Fall back to a private timer, which supports the
            # "with timer(name)" syntax used below
            timer = Timer()

        with timer('Read ground state'):
            assert Path(gs).is_file()
            if fd is not None:
                print('Reading ground state calculation:\n %s' % gs,
                      file=fd)
            with disable_dry_run():
                return GPAW(gs, txt=None, communicator=mpi.serial_comm)


class PairMatrixElement:
    """Class for calculating matrix elements for transitions in Kohn-Sham
    linear response functions."""
    def __init__(self, kspair):
        """
        Parameters
        ----------
        kspair : KohnShamPair instance
            Supplies the ground state calculation, the output file
            descriptor, the timer and the transitions block communicator.
        """
        self.calc = kspair.calc
        self.fd = kspair.fd
        self.timer = kspair.timer
        self.transitionblockscomm = kspair.transitionblockscomm

    def initialize(self, *args, **kwargs):
        """Initialize e.g. PAW corrections or other operations
        ahead in time of integration."""
        pass

    def __call__(self, kskptpair, *args, **kwargs):
        """Calculate the matrix element for all transitions in kskptpairs."""
        raise NotImplementedError('Define specific matrix element')


class PlaneWavePairDensity(PairMatrixElement):
    """Class for calculating pair densities:

    n_T(q+G) = <s'n'k'| e^(i (q + G) r) |snk>

    in the plane wave mode"""
    def __init__(self, kspair):
        PairMatrixElement.__init__(self, kspair)

        # Save PAW correction for all calls with same q_c
        self.Q_aGii = None
        self.currentq_c = None

    def initialize(self, pd):
        """Initialize PAW corrections ahead in time of integration."""
        self.initialize_paw_corrections(pd)

    @timer('Initialize PAW corrections')
    def initialize_paw_corrections(self, pd):
        """Initialize PAW corrections, if not done already, for the given q"""
        q_c = pd.kd.bzk_kc[0]
        if self.Q_aGii is None or not np.allclose(q_c - self.currentq_c, 0.):
            self.Q_aGii = self._initialize_paw_corrections(pd)
            self.currentq_c = q_c

    def _initialize_paw_corrections(self, pd):
        """Calculate the PAW corrections Q_aGii for all atoms."""
        wfs = self.calc.wfs
        spos_ac = self.calc.spos_ac
        G_Gv = pd.get_reciprocal_vectors()

        pos_av = np.dot(spos_ac, pd.gd.cell_cv)

        # Collect integrals for all species:
        Q_xGii = {}
        for setup_id, atomdata in wfs.setups.setups.items():
            Q_Gii = two_phi_planewave_integrals(G_Gv, atomdata)
            ni = atomdata.ni
            Q_Gii.shape = (-1, ni, ni)

            Q_xGii[setup_id] = Q_Gii

        # Multiply the integrals of each atom by a phase factor from its
        # position
        Q_aGii = []
        for a, setup_id in enumerate(wfs.setups.id_a):
            Q_Gii = Q_xGii[setup_id]
            x_G = np.exp(-1j * np.dot(G_Gv, pos_av[a]))
            Q_aGii.append(x_G[:, np.newaxis, np.newaxis] * Q_Gii)

        return Q_aGii

    @timer('Calculate pair density')
    def __call__(self, kskptpair, pd):
        """Calculate the pair densities for all transitions:
        n_t(q+G) = <s'n'k+q| e^(i (q + G) r) |snk>
                 = <snk| e^(-i (q + G) r) |s'n'k+q>

        The result is not returned, but attached to kskptpair as n_mytG
        (this block of transitions) and n_tG (all transitions).
        """
        Q_aGii = self.get_paw_projectors(pd)
        Q_G = self.get_fft_indices(kskptpair, pd)
        mynt, nt, ta, tb = kskptpair.transition_distribution()

        n_mytG = pd.empty(mynt)

        # Calculate smooth part of the pair densities:
        with self.timer('Calculate smooth part'):
            ut1cc_mytR = kskptpair.kpt1.ut_tR.conj()
            n_mytR = ut1cc_mytR * kskptpair.kpt2.ut_tR
            # Unvectorized
            for myt in range(tb - ta):
                n_mytG[myt] = pd.fft(n_mytR[myt], 0, Q_G) * pd.gd.dv

        # Calculate PAW corrections with numpy
        with self.timer('PAW corrections'):
            P1 = kskptpair.kpt1.projections
            P2 = kskptpair.kpt2.projections
            for (Q_Gii, (a1, P1_myti),
                 (a2, P2_myti)) in zip(Q_aGii, P1.items(), P2.items()):
                P1cc_myti = P1_myti[:tb - ta].conj()
                C1_Gimyt = np.tensordot(Q_Gii, P1cc_myti, axes=([1, 1]))
                P2_imyt = P2_myti.T[:, :tb - ta]
                n_mytG[:tb - ta] += np.sum(C1_Gimyt * P2_imyt[np.newaxis,
                                                              :, :], axis=1).T

        # Attach the calculated pair density to the KohnShamKPointPair object
        kskptpair.attach('n_mytG', 'n_tG', n_mytG)

    def get_paw_projectors(self, pd):
        """Make sure PAW correction has been initialized properly
        and return projectors"""
        self.initialize_paw_corrections(pd)
        return self.Q_aGii

    @timer('Get G-vector indices')
    def get_fft_indices(self, kskptpair, pd):
        """Get indices for G-vectors inside cutoff sphere."""
        kpt1 = kskptpair.kpt1
        kpt2 = kskptpair.kpt2
        kd = self.calc.wfs.kd
        q_c = pd.kd.bzk_kc[0]

        N_G = pd.Q_qG[0]

        shift_c = kpt1.shift_c - kpt2.shift_c
        shift_c += (q_c - kd.bzk_kc[kpt2.K]
                    + kd.bzk_kc[kpt1.K]).round().astype(int)
        if shift_c.any():
            # Shift the grid indices, wrapping around at the grid boundaries
            n_cG = np.unravel_index(N_G, pd.gd.N_c)
            n_cG = [n_G + shift for n_G, shift in zip(n_cG, shift_c)]
            N_G = np.ravel_multi_index(n_cG, pd.gd.N_c, 'wrap')
        return N_G