1 2import functools 3import numbers 4import sys 5from math import pi 6 7import numpy as np 8from scipy.spatial import Delaunay, cKDTree 9 10from ase.units import Ha 11from gpaw.utilities import convert_string_to_fd 12from ase.utils.timing import timer, Timer 13 14import gpaw.mpi as mpi 15from gpaw import GPAW, disable_dry_run 16from gpaw.fd_operators import Gradient 17from gpaw.kpt_descriptor import KPointDescriptor 18from gpaw.response.math_func import (two_phi_planewave_integrals, 19 two_phi_nabla_planewave_integrals) 20from gpaw.utilities.blas import gemm 21from gpaw.utilities.progressbar import ProgressBar 22from gpaw.wavefunctions.pw import PWLFC 23from gpaw.bztools import get_reduced_bz, unique_rows 24 25 26class KPoint: 27 def __init__(self, s, K, n1, n2, blocksize, na, nb, 28 ut_nR, eps_n, f_n, P_ani, shift_c): 29 self.s = s # spin index 30 self.K = K # BZ k-point index 31 self.n1 = n1 # first band 32 self.n2 = n2 # first band not included 33 self.blocksize = blocksize 34 self.na = na # first band of block 35 self.nb = nb # first band of block not included 36 self.ut_nR = ut_nR # periodic part of wave functions in real-space 37 self.eps_n = eps_n # eigenvalues 38 self.f_n = f_n # occupation numbers 39 self.P_ani = P_ani # PAW projections 40 self.shift_c = shift_c # long story - see the 41 # PairDensity.construct_symmetry_operators() method 42 43 44class KPointPair: 45 """This class defines the kpoint-pair container object. 46 47 Used for calculating pair quantities it contains two kpoints, 48 and an associated set of Fourier components.""" 49 def __init__(self, kpt1, kpt2, Q_G): 50 self.kpt1 = kpt1 51 self.kpt2 = kpt2 52 self.Q_G = Q_G 53 54 def get_k1(self): 55 """ Return KPoint object 1.""" 56 return self.kpt1 57 58 def get_k2(self): 59 """ Return KPoint object 2.""" 60 return self.kpt2 61 62 def get_planewave_indices(self): 63 """ Return the planewave indices associated with this pair.""" 64 return self.Q_G 65 66 def get_transition_energies(self, n_n, m_m): 67 """Return the energy difference for specified bands.""" 68 n_n = np.array(n_n) 69 m_m = np.array(m_m) 70 kpt1 = self.kpt1 71 kpt2 = self.kpt2 72 deps_nm = (kpt1.eps_n[n_n - self.kpt1.n1][:, np.newaxis] - 73 kpt2.eps_n[m_m - self.kpt2.n1]) 74 return deps_nm 75 76 def get_occupation_differences(self, n_n, m_m): 77 """Get difference in occupation factor between specified bands.""" 78 n_n = np.array(n_n) 79 m_m = np.array(m_m) 80 kpt1 = self.kpt1 81 kpt2 = self.kpt2 82 df_nm = (kpt1.f_n[n_n - self.kpt1.n1][:, np.newaxis] - 83 kpt2.f_n[m_m - self.kpt2.n1]) 84 return df_nm 85 86 87class PWSymmetryAnalyzer: 88 """Class for handling planewave symmetries.""" 89 def __init__(self, kd, pd, txt=sys.stdout, 90 disable_point_group=False, 91 disable_non_symmorphic=True, 92 disable_time_reversal=False, 93 timer=None): 94 """Creates a PWSymmetryAnalyzer object. 95 96 Determines which of the symmetries of the atomic structure 97 that is compatible with the reciprocal lattice. Contains the 98 necessary functions for mapping quantities between kpoints, 99 and or symmetrizing arrays. 100 101 kd: KPointDescriptor 102 The kpoint descriptor containing the 103 information about symmetries and kpoints. 104 pd: PWDescriptor 105 Plane wave descriptor that contains the reciprocal 106 lattice . 107 txt: str 108 Output file. 109 disable_point_group: bool 110 Switch for disabling point group symmetries. 111 disable_non_symmorphic: 112 Switch for disabling non symmorphic symmetries. 113 disable_time_reversal: 114 Switch for disabling time reversal. 115 """ 116 self.pd = pd 117 self.kd = kd 118 self.fd = txt 119 120 # Caveats 121 assert disable_non_symmorphic, \ 122 print('You are not allowed to use non symmorphic syms, sorry. ', 123 file=self.fd) 124 125 # Settings 126 self.disable_point_group = disable_point_group 127 self.disable_time_reversal = disable_time_reversal 128 self.disable_non_symmorphic = disable_non_symmorphic 129 if (kd.symmetry.has_inversion or not kd.symmetry.time_reversal) and \ 130 not self.disable_time_reversal: 131 print('\nThe ground calculation does not support time-reversal ' + 132 'symmetry possibly because it has an inversion center ' + 133 'or that it has been manually deactivated. \n', file=self.fd) 134 self.disable_time_reversal = True 135 136 self.disable_symmetries = (self.disable_point_group and 137 self.disable_time_reversal and 138 self.disable_non_symmorphic) 139 140 # Number of symmetries 141 U_scc = kd.symmetry.op_scc 142 self.nU = len(U_scc) 143 144 self.nsym = 2 * self.nU 145 self.use_time_reversal = not self.disable_time_reversal 146 147 # Which timer to use 148 self.timer = timer or Timer() 149 150 self.KDTree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1)) 151 152 # Initialize 153 self.initialize() 154 155 @timer('Initialize') 156 def initialize(self): 157 """Initialize relevant quantities.""" 158 self.infostring = '' 159 if self.disable_point_group: 160 self.infostring += 'Point group not included. ' 161 else: 162 self.infostring += 'Point group included. ' 163 164 if self.disable_time_reversal: 165 self.infostring += 'Time reversal not included. ' 166 else: 167 self.infostring += 'Time reversal included. ' 168 169 if self.disable_non_symmorphic: 170 self.infostring += 'Disabled non symmorphic symmetries. ' 171 else: 172 self.infostring += 'Time reversal included. ' 173 174 if self.disable_symmetries: 175 self.infostring += 'All symmetries have been disabled. ' 176 177 # Do the work 178 self.analyze_symmetries() 179 self.analyze_kpoints() 180 self.initialize_G_maps() 181 182 # Print info 183 print(self.infostring, file=self.fd) 184 self.print_symmetries() 185 186 def find_kpoint(self, k_c): 187 return self.KDTree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1] 188 189 def print_symmetries(self): 190 """Handsome print function for symmetry operations.""" 191 192 p = functools.partial(print, file=self.fd) 193 194 p() 195 nx = 6 if self.disable_non_symmorphic else 3 196 ns = len(self.s_s) 197 y = 0 198 for y in range((ns + nx - 1) // nx): 199 for c in range(3): 200 for x in range(nx): 201 s = x + y * nx 202 if s == ns: 203 break 204 tmp = self.get_symmetry_operator(self.s_s[s]) 205 op_cc, sign, TR, shift_c, ft_c = tmp 206 op_c = sign * op_cc[c] 207 p(' (%2d %2d %2d)' % tuple(op_c), end='') 208 p() 209 p() 210 211 @timer('Analyze') 212 def analyze_kpoints(self): 213 """Calculate the reduction in the number of kpoints.""" 214 K_gK = self.group_kpoints() 215 ng = len(K_gK) 216 self.infostring += '{0} groups of equivalent kpoints. '.format(ng) 217 percent = (1. - (ng + 0.) / self.kd.nbzkpts) * 100 218 self.infostring += '{0}% reduction. '.format(percent) 219 220 @timer('Analyze symmetries.') 221 def analyze_symmetries(self): 222 r"""Determine allowed symmetries. 223 224 An direct symmetry U must fulfill:: 225 226 U \mathbf{q} = q + \Delta 227 228 Under time-reversal (indirect) it must fulfill:: 229 230 -U \mathbf{q} = q + \Delta 231 232 where :math:`\Delta` is a reciprocal lattice vector. 233 """ 234 pd = self.pd 235 236 # Shortcuts 237 q_c = pd.kd.bzk_kc[0] 238 kd = self.kd 239 240 U_scc = kd.symmetry.op_scc 241 nU = self.nU 242 nsym = self.nsym 243 244 shift_sc = np.zeros((nsym, 3), int) 245 conserveq_s = np.zeros(nsym, bool) 246 247 newq_sc = np.dot(U_scc, q_c) 248 249 # Direct symmetries 250 dshift_sc = (newq_sc - q_c[np.newaxis]).round().astype(int) 251 inds_s = np.argwhere((newq_sc == q_c[np.newaxis] + dshift_sc).all(1)) 252 conserveq_s[inds_s] = True 253 254 shift_sc[:nU] = dshift_sc 255 256 # Time reversal 257 trshift_sc = (-newq_sc - q_c[np.newaxis]).round().astype(int) 258 trinds_s = np.argwhere((-newq_sc == q_c[np.newaxis] + 259 trshift_sc).all(1)) + nU 260 conserveq_s[trinds_s] = True 261 shift_sc[nU:nsym] = trshift_sc 262 263 # The indices of the allowed symmetries 264 s_s = conserveq_s.nonzero()[0] 265 266 # Filter out disabled symmetries 267 if self.disable_point_group: 268 s_s = list(filter(self.is_not_point_group, s_s)) 269 270 if self.disable_time_reversal: 271 s_s = list(filter(self.is_not_time_reversal, s_s)) 272 273 if self.disable_non_symmorphic: 274 s_s = list(filter(self.is_not_non_symmorphic, s_s)) 275 276# stmp_s = [] 277# for s in s_s: 278# if self.kd.bz2bz_ks[0, s] == -1: 279# assert (self.kd.bz2bz_ks[:, s] == -1).all() 280# else: 281# stmp_s.append(s) 282 283# s_s = stmp_s 284 285 self.infostring += 'Found {} allowed symmetries. '.format(len(s_s)) 286 self.s_s = s_s 287 self.shift_sc = shift_sc 288 289 def is_not_point_group(self, s): 290 U_scc = self.kd.symmetry.op_scc 291 nU = self.nU 292 return (U_scc[s % nU] == np.eye(3)).all() 293 294 def is_not_time_reversal(self, s): 295 nU = self.nU 296 return not bool(s // nU) 297 298 def is_not_non_symmorphic(self, s): 299 ft_sc = self.kd.symmetry.ft_sc 300 nU = self.nU 301 return not bool(ft_sc[s % nU].any()) 302 303 def how_many_symmetries(self): 304 """Return number of symmetries.""" 305 return len(self.s_s) 306 307 @timer('Group kpoints') 308 def group_kpoints(self, K_k=None): 309 """Group kpoints according to the reduced symmetries""" 310 if K_k is None: 311 K_k = np.arange(self.kd.nbzkpts) 312 s_s = self.s_s 313 bz2bz_ks = self.kd.bz2bz_ks 314 nk = len(bz2bz_ks) 315 sbz2sbz_ks = bz2bz_ks[K_k][:, s_s] # Reduced number of symmetries 316 # Avoid -1 (see documentation in gpaw.symmetry) 317 sbz2sbz_ks[sbz2sbz_ks == -1] = nk 318 319 smallestk_k = np.sort(sbz2sbz_ks)[:, 0] 320 k2g_g = np.unique(smallestk_k, return_index=True)[1] 321 322 K_gs = sbz2sbz_ks[k2g_g] 323 K_gk = [np.unique(K_s[K_s != nk]) for K_s in K_gs] 324 325 return K_gk 326 327 def get_BZ(self): 328 # Get the little group of q 329 U_scc = [] 330 for s in self.s_s: 331 U_cc, sign, _, _, _ = self.get_symmetry_operator(s) 332 U_scc.append(sign * U_cc) 333 U_scc = np.array(U_scc) 334 335 # Determine the irreducible BZ 336 bzk_kc, ibzk_kc = get_reduced_bz(self.pd.gd.cell_cv, 337 U_scc, 338 False) 339 340 return bzk_kc 341 342 def get_reduced_kd(self, pbc_c=np.ones(3, bool)): 343 # Get the little group of q 344 U_scc = [] 345 for s in self.s_s: 346 U_cc, sign, _, _, _ = self.get_symmetry_operator(s) 347 U_scc.append(sign * U_cc) 348 U_scc = np.array(U_scc) 349 350 # Determine the irreducible BZ 351 bzk_kc, ibzk_kc = get_reduced_bz(self.pd.gd.cell_cv, 352 U_scc, 353 False, 354 pbc_c=pbc_c) 355 356 n = 3 357 N_xc = np.indices((n, n, n)).reshape((3, n**3)).T - n // 2 358 359 # Find the irreducible kpoints 360 tess = Delaunay(ibzk_kc) 361 ik_kc = [] 362 for N_c in N_xc: 363 k_kc = self.kd.bzk_kc + N_c 364 k_kc = k_kc[tess.find_simplex(k_kc) >= 0] 365 if not len(ik_kc) and len(k_kc): 366 ik_kc = unique_rows(k_kc) 367 elif len(k_kc): 368 ik_kc = unique_rows(np.append(k_kc, ik_kc, axis=0)) 369 370 return KPointDescriptor(ik_kc) 371 372 def unfold_kpoints(self, points_pv, tol=1e-8, mod=None): 373 points_pc = np.dot(points_pv, self.pd.gd.cell_cv.T) / (2 * np.pi) 374 375 # Get the little group of q 376 U_scc = [] 377 for s in self.s_s: 378 U_cc, sign, _, _, _ = self.get_symmetry_operator(s) 379 U_scc.append(sign * U_cc) 380 U_scc = np.array(U_scc) 381 382 points = np.concatenate(np.dot(points_pc, U_scc.transpose(0, 2, 1))) 383 points = unique_rows(points, tol=tol, mod=mod) 384 points = np.dot(points, self.pd.gd.icell_cv) * (2 * np.pi) 385 return points 386 387 def get_kpoint_weight(self, k_c): 388 K = self.find_kpoint(k_c) 389 iK = self.kd.bz2ibz_k[K] 390 K_k = self.unfold_ibz_kpoint(iK) 391 K_gK = self.group_kpoints(K_k) 392 393 for K_k in K_gK: 394 if K in K_k: 395 if self.kd.refine_info is not None: 396 weight = sum(self.kd.refine_info.weight_k[K_k]) 397 return weight 398 else: 399 return len(K_k) 400 401 def get_kpoint_mapping(self, K1, K2): 402 """Get index of symmetry for mapping between K1 and K2""" 403 s_s = self.s_s 404 bz2bz_ks = self.kd.bz2bz_ks 405 bzk2rbz_s = bz2bz_ks[K1][s_s] 406 try: 407 s = np.argwhere(bzk2rbz_s == K2)[0][0] 408 except IndexError: 409 print('K = {0} cannot be mapped into K = {1}'.format(K1, K2), 410 file=self.fd) 411 raise 412 return s_s[s] 413 414 def get_shift(self, K1, K2, U_cc, sign): 415 """Get shift for mapping between K1 and K2.""" 416 kd = self.kd 417 k1_c = kd.bzk_kc[K1] 418 k2_c = kd.bzk_kc[K2] 419 420 shift_c = np.dot(U_cc, k1_c) - k2_c * sign 421 assert np.allclose(shift_c.round(), shift_c) 422 shift_c = shift_c.round().astype(int) 423 424 return shift_c 425 426 @timer('map_G') 427 def map_G(self, K1, K2, a_MG): 428 """Map a function of G from K1 to K2. """ 429 if len(a_MG) == 0: 430 return [] 431 432 if K1 == K2: 433 return a_MG 434 435 G_G, sign = self.map_G_vectors(K1, K2) 436 437 s = self.get_kpoint_mapping(K1, K2) 438 U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s) 439 440 return TR(a_MG[..., G_G]) 441 442 def symmetrize_wGG(self, A_wGG): 443 """Symmetrize an array in GG'.""" 444 445 for A_GG in A_wGG: 446 tmp_GG = np.zeros_like(A_GG) 447 448 for s in self.s_s: 449 G_G, sign, _ = self.G_sG[s] 450 if sign == 1: 451 tmp_GG += A_GG[G_G, :][:, G_G] 452 if sign == -1: 453 tmp_GG += A_GG[G_G, :][:, G_G].T 454 455 A_GG[:] = tmp_GG / self.how_many_symmetries() 456 457 def symmetrize_wxx(self, A_wxx, optical_limit=False): 458 """Symmetrize an array in xx'.""" 459 tmp_wxx = np.zeros_like(A_wxx) 460 461 A_cv = self.pd.gd.cell_cv 462 iA_cv = self.pd.gd.icell_cv 463 464 if self.use_time_reversal: 465 AT_wxx = np.transpose(A_wxx, (0, 2, 1)) 466 467 for s in self.s_s: 468 G_G, sign, shift_c = self.G_sG[s] 469 if optical_limit: 470 G_G = np.array(G_G) + 2 471 G_G = np.insert(G_G, 0, [0, 1]) 472 U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s) 473 M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv) 474 475 if sign == 1: 476 tmp = A_wxx[:, G_G, :][:, :, G_G] 477 if optical_limit: 478 tmp[:, 0:3, :] = np.transpose(np.dot(M_vv.T, 479 tmp[:, 0:3, :]), 480 (1, 0, 2)) 481 tmp[:, :, 0:3] = np.dot(tmp[..., 0:3], M_vv) 482 tmp_wxx += tmp 483 elif sign == -1: 484 tmp = AT_wxx[:, G_G, :][:, :, G_G] 485 if optical_limit: 486 tmp[:, 0:3, :] = np.transpose(np.dot(M_vv.T, 487 tmp[:, 0:3, :]), 488 (1, 0, 2)) * sign 489 tmp[:, :, 0:3] = np.dot(tmp[:, :, 0:3], M_vv) * sign 490 tmp_wxx += tmp 491 492 # Inplace overwriting 493 A_wxx[:] = tmp_wxx / self.how_many_symmetries() 494 495 def symmetrize_wxvG(self, A_wxvG): 496 """Symmetrize chi0_wxvG""" 497 A_cv = self.pd.gd.cell_cv 498 iA_cv = self.pd.gd.icell_cv 499 500 if self.use_time_reversal: 501 # ::-1 corresponds to transpose in wing indices 502 AT_wxvG = A_wxvG[:, ::-1] 503 504 tmp_wxvG = np.zeros_like(A_wxvG) 505 for s in self.s_s: 506 G_G, sign, shift_c = self.G_sG[s] 507 U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s) 508 M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv) 509 if sign == 1: 510 tmp = sign * np.dot(M_vv.T, A_wxvG[..., G_G]) 511 elif sign == -1: 512 tmp = sign * np.dot(M_vv.T, AT_wxvG[..., G_G]) 513 tmp_wxvG += np.transpose(tmp, (1, 2, 0, 3)) 514 515 # Overwrite the input 516 A_wxvG[:] = tmp_wxvG / self.how_many_symmetries() 517 518 def symmetrize_wvv(self, A_wvv): 519 """Symmetrize chi_wvv.""" 520 A_cv = self.pd.gd.cell_cv 521 iA_cv = self.pd.gd.icell_cv 522 tmp_wvv = np.zeros_like(A_wvv) 523 if self.use_time_reversal: 524 AT_wvv = np.transpose(A_wvv, (0, 2, 1)) 525 526 for s in self.s_s: 527 G_G, sign, shift_c = self.G_sG[s] 528 U_cc, _, TR, shift_c, ft_c = self.get_symmetry_operator(s) 529 M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv) 530 if sign == 1: 531 tmp = np.dot(np.dot(M_vv.T, A_wvv), M_vv) 532 elif sign == -1: 533 tmp = np.dot(np.dot(M_vv.T, AT_wvv), M_vv) 534 tmp_wvv += np.transpose(tmp, (1, 0, 2)) 535 536 # Overwrite the input 537 A_wvv[:] = tmp_wvv / self.how_many_symmetries() 538 539 @timer('map_v') 540 def map_v(self, K1, K2, a_Mv): 541 """Map a function of v (cartesian component) from K1 to K2.""" 542 543 if len(a_Mv) == 0: 544 return [] 545 546 if K1 == K2: 547 return a_Mv 548 549 A_cv = self.pd.gd.cell_cv 550 iA_cv = self.pd.gd.icell_cv 551 552 # Get symmetry 553 s = self.get_kpoint_mapping(K1, K2) 554 U_cc, sign, TR, _, ft_c = self.get_symmetry_operator(s) 555 556 # Create cartesian operator 557 M_vv = np.dot(np.dot(A_cv.T, U_cc.T), iA_cv) 558 return sign * np.dot(TR(a_Mv), M_vv) 559 560 def timereversal(self, s): 561 """Is this a time-reversal symmetry?""" 562 tr = bool(s // self.nU) 563 return tr 564 565 def get_symmetry_operator(self, s): 566 """Return symmetry operator s.""" 567 U_scc = self.kd.symmetry.op_scc 568 ft_sc = self.kd.symmetry.op_scc 569 570 reds = s % self.nU 571 if self.timereversal(s): 572 TR = np.conj 573 sign = -1 574 else: 575 sign = 1 576 577 def TR(x): 578 return x 579 580 return U_scc[reds], sign, TR, self.shift_sc[s], ft_sc[reds] 581 582 @timer('map_G_vectors') 583 def map_G_vectors(self, K1, K2): 584 """Return G vector mapping.""" 585 s = self.get_kpoint_mapping(K1, K2) 586 G_G, sign, shift_c = self.G_sG[s] 587 588 return G_G, sign 589 590 def initialize_G_maps(self): 591 """Calculate the Gvector mappings.""" 592 pd = self.pd 593 B_cv = 2.0 * np.pi * pd.gd.icell_cv 594 G_Gv = pd.get_reciprocal_vectors(add_q=False) 595 G_Gc = np.dot(G_Gv, np.linalg.inv(B_cv)) 596 Q_G = pd.Q_qG[0] 597 598 G_sG = [None] * self.nsym 599 UG_sGc = [None] * self.nsym 600 Q_sG = [None] * self.nsym 601 for s in self.s_s: 602 U_cc, sign, TR, shift_c, ft_c = self.get_symmetry_operator(s) 603 iU_cc = np.linalg.inv(U_cc).T 604 UG_Gc = np.dot(G_Gc - shift_c, sign * iU_cc) 605 606 assert np.allclose(UG_Gc.round(), UG_Gc) 607 UQ_G = np.ravel_multi_index(UG_Gc.round().astype(int).T, 608 pd.gd.N_c, 'wrap') 609 610 G_G = len(Q_G) * [None] 611 for G, UQ in enumerate(UQ_G): 612 try: 613 G_G[G] = np.argwhere(Q_G == UQ)[0][0] 614 except IndexError: 615 print('This should not be possible but' + 616 'a G-vector was mapped outside the sphere') 617 raise IndexError 618 UG_sGc[s] = UG_Gc 619 Q_sG[s] = UQ_G 620 G_sG[s] = [G_G, sign, shift_c] 621 self.G_Gc = G_Gc 622 self.UG_sGc = UG_sGc 623 self.Q_sG = Q_sG 624 self.G_sG = G_sG 625 626 def unfold_ibz_kpoint(self, ik): 627 """Return kpoints related to irreducible kpoint.""" 628 kd = self.kd 629 K_k = np.unique(kd.bz2bz_ks[kd.ibz2bz_k[ik]]) 630 K_k = K_k[K_k != -1] 631 return K_k 632 633 634class PairDensity: 635 def __init__(self, gs, ecut=50, response='density', 636 ftol=1e-6, threshold=1, 637 real_space_derivatives=False, 638 world=mpi.world, txt='-', timer=None, 639 nblocks=1, gate_voltage=None, 640 paw_correction='brute-force', **unused): 641 """Density matrix elements 642 643 Parameters 644 ---------- 645 ftol : float 646 Threshold determining whether a band is completely filled 647 (f > 1 - ftol) or completely empty (f < ftol). 648 threshold : float 649 Numerical threshold for the optical limit k dot p perturbation 650 theory expansion. 651 real_space_derivatives : bool 652 Calculate nabla matrix elements (in the optical limit) 653 using a real space finite difference approximation. 654 gate_voltage : float 655 Shift the fermi level by gate_voltage [Hartree]. 656 """ 657 self.world = world 658 self.fd = convert_string_to_fd(txt, world) 659 self.timer = timer or Timer() 660 661 with self.timer('Read ground state'): 662 if not isinstance(gs, GPAW): 663 print('Reading ground state calculation:\n %s' % gs, 664 file=self.fd) 665 with disable_dry_run(): 666 calc = GPAW(gs, communicator=mpi.serial_comm) 667 else: 668 calc = gs 669 assert calc.wfs.world.size == 1 670 671 assert calc.wfs.kd.symmetry.symmorphic 672 self.calc = calc 673 674 if ecut is not None: 675 ecut /= Ha 676 677 if gate_voltage is not None: 678 gate_voltage = gate_voltage / Ha 679 680 self.response = response 681 self.ecut = ecut 682 self.ftol = ftol 683 self.threshold = threshold 684 self.real_space_derivatives = real_space_derivatives 685 self.gate_voltage = gate_voltage 686 687 if nblocks == 1: 688 self.blockcomm = world.new_communicator([world.rank]) 689 self.kncomm = world 690 else: 691 assert world.size % nblocks == 0, world.size 692 rank1 = world.rank // nblocks * nblocks 693 rank2 = rank1 + nblocks 694 self.blockcomm = self.world.new_communicator(range(rank1, rank2)) 695 ranks = range(world.rank % nblocks, world.size, nblocks) 696 self.kncomm = self.world.new_communicator(ranks) 697 698 self.fermi_level = self.calc.wfs.fermi_level 699 700 if gate_voltage is not None: 701 self.add_gate_voltage(gate_voltage) 702 703 self.spos_ac = calc.spos_ac 704 705 self.nocc1 = None # number of completely filled bands 706 self.nocc2 = None # number of non-empty bands 707 self.count_occupied_bands() 708 709 self.ut_sKnvR = None # gradient of wave functions for optical limit 710 711 self.vol = abs(np.linalg.det(calc.wfs.gd.cell_cv)) 712 713 kd = self.calc.wfs.kd 714 self.KDTree = cKDTree(np.mod(np.mod(kd.bzk_kc, 1).round(6), 1)) 715 print('Number of blocks:', nblocks, file=self.fd) 716 717 self.paw_correction = paw_correction 718 719 def find_kpoint(self, k_c): 720 return self.KDTree.query(np.mod(np.mod(k_c, 1).round(6), 1))[1] 721 722 def add_gate_voltage(self, gate_voltage=0): 723 """Shifts the Fermi-level by e * Vg. By definition e = 1.""" 724 assert self.calc.wfs.occupations.name in {'fermi-dirac', 'zero-width'} 725 print('Shifting Fermi-level by %.2f eV' % (gate_voltage * Ha), 726 file=self.fd) 727 self.fermi_level += gate_voltage 728 for kpt in self.calc.wfs.kpt_u: 729 kpt.f_n = (self.shift_occupations(kpt.eps_n, gate_voltage) * 730 kpt.weight) 731 732 def shift_occupations(self, eps_n, gate_voltage): 733 """Shift fermilevel.""" 734 fermi = self.fermi_level 735 width = getattr(self.calc.wfs.occupations, '_width', 0.0) / Ha 736 if width < 1e-9: 737 return (eps_n < fermi).astype(float) 738 else: 739 tmp = (eps_n - fermi) / width 740 f_n = np.zeros_like(eps_n) 741 f_n[tmp <= 100] = 1 / (1 + np.exp(tmp[tmp <= 100])) 742 f_n[tmp > 100] = 0.0 743 return f_n 744 745 def count_occupied_bands(self): 746 self.nocc1 = 9999999 747 self.nocc2 = 0 748 for kpt in self.calc.wfs.kpt_u: 749 f_n = kpt.f_n / kpt.weight 750 self.nocc1 = min((f_n > 1 - self.ftol).sum(), self.nocc1) 751 self.nocc2 = max((f_n > self.ftol).sum(), self.nocc2) 752 print('Number of completely filled bands:', self.nocc1, file=self.fd) 753 print('Number of partially filled bands:', self.nocc2, file=self.fd) 754 print('Total number of bands:', self.calc.wfs.bd.nbands, 755 file=self.fd) 756 757 def distribute_k_points_and_bands(self, band1, band2, kpts=None): 758 """Distribute spins, k-points and bands. 759 760 nbands: int 761 Number of bands for each spin/k-point combination. 762 763 The attribute self.mysKn1n2 will be set to a list of (s, K, n1, n2) 764 tuples that this process handles. 765 """ 766 767 wfs = self.calc.wfs 768 769 if kpts is None: 770 kpts = np.arange(wfs.kd.nbzkpts) 771 772 nbands = band2 - band1 773 size = self.kncomm.size 774 rank = self.kncomm.rank 775 ns = wfs.nspins 776 nk = len(kpts) 777 n = (ns * nk * nbands + size - 1) // size 778 i1 = rank * n 779 i2 = min(i1 + n, ns * nk * nbands) 780 781 self.mysKn1n2 = [] 782 i = 0 783 for s in range(ns): 784 for K in kpts: 785 n1 = min(max(0, i1 - i), nbands) 786 n2 = min(max(0, i2 - i), nbands) 787 if n1 != n2: 788 self.mysKn1n2.append((s, K, n1 + band1, n2 + band1)) 789 i += nbands 790 791 print('BZ k-points:', self.calc.wfs.kd, file=self.fd) 792 print('Distributing spins, k-points and bands (%d x %d x %d)' % 793 (ns, nk, nbands), 794 'over %d process%s' % 795 (self.kncomm.size, ['es', ''][self.kncomm.size == 1]), 796 file=self.fd) 797 print('Number of blocks:', self.blockcomm.size, file=self.fd) 798 799 @timer('Get a k-point') 800 def get_k_point(self, s, k_c, n1, n2, load_wfs=True, block=False): 801 """Return wave functions for a specific k-point and spin. 802 803 s: int 804 Spin index (0 or 1). 805 K: int 806 BZ k-point index. 807 n1, n2: int 808 Range of bands to include. 809 """ 810 811 wfs = self.calc.wfs 812 kd = wfs.kd 813 814 # Parse kpoint: is k_c an index or a vector 815 if not isinstance(k_c, numbers.Integral): 816 K = self.find_kpoint(k_c) 817 shift0_c = (kd.bzk_kc[K] - k_c).round().astype(int) 818 else: 819 # Fall back to index 820 K = k_c 821 shift0_c = np.array([0, 0, 0]) 822 k_c = None 823 824 if block: 825 nblocks = self.blockcomm.size 826 rank = self.blockcomm.rank 827 else: 828 nblocks = 1 829 rank = 0 830 831 blocksize = (n2 - n1 + nblocks - 1) // nblocks 832 na = min(n1 + rank * blocksize, n2) 833 nb = min(na + blocksize, n2) 834 835 U_cc, T, a_a, U_aii, shift_c, time_reversal = \ 836 self.construct_symmetry_operators(K, k_c=k_c) 837 838 shift_c += -shift0_c 839 ik = wfs.kd.bz2ibz_k[K] 840 assert wfs.kd.comm.size == 1 841 kpt = wfs.kpt_qs[ik][s] 842 843 assert n2 <= len(kpt.eps_n), \ 844 'Increase GS-nbands or decrease chi0-nbands!' 845 eps_n = kpt.eps_n[n1:n2] 846 f_n = kpt.f_n[n1:n2] / kpt.weight 847 848 if not load_wfs: 849 return KPoint(s, K, n1, n2, blocksize, na, nb, 850 None, eps_n, f_n, None, shift_c) 851 852 with self.timer('load wfs'): 853 psit_nG = kpt.psit_nG 854 ut_nR = wfs.gd.empty(nb - na, wfs.dtype) 855 for n in range(na, nb): 856 ut_nR[n - na] = T(wfs.pd.ifft(psit_nG[n], ik)) 857 858 with self.timer('Load projections'): 859 P_ani = [] 860 for b, U_ii in zip(a_a, U_aii): 861 P_ni = np.dot(kpt.P_ani[b][na:nb], U_ii) 862 if time_reversal: 863 P_ni = P_ni.conj() 864 P_ani.append(P_ni) 865 866 return KPoint(s, K, n1, n2, blocksize, na, nb, 867 ut_nR, eps_n, f_n, P_ani, shift_c) 868 869 def generate_pair_densities(self, pd, m1, m2, spins, intraband=True, 870 PWSA=None, disable_optical_limit=False, 871 unsymmetrized=False, use_more_memory=1): 872 """Generator for returning pair densities. 873 874 Returns the pair densities between the occupied and 875 the states in range(m1, m2). 876 877 pd: PWDescriptor 878 Plane-wave descriptor for a single q-point. 879 m1: int 880 Index of first unoccupied band. 881 m2: int 882 Index of last unoccupied band. 883 spins: list 884 List of spin indices included. 885 intraband: bool 886 Include intraband transitions in optical limit. 887 PWSA: PlanewaveSymmetryAnalyzer 888 If supplied uses this object to determine the symmetries 889 of the pair-densities. 890 disable_optical_limit: bool 891 Disable optical limit. 892 unsymmetrized: bool 893 Only return pair-densities from one kpoint in each 894 group of equivalent kpoints. 895 use_more_memory: float 896 Group more pair densities for several occupied bands 897 together before returning. Here 0 <= use_more_memory <= 1, 898 where zero is the minimal amount of memory, and 1 is the maximal. 899 """ 900 assert 0 <= use_more_memory <= 1 901 902 q_c = pd.kd.bzk_kc[0] 903 optical_limit = np.allclose(q_c, 0.0) and self.response == 'density' 904 optical_limit = not disable_optical_limit and optical_limit 905 906 Q_aGii = self.initialize_paw_corrections(pd) 907 self.Q_aGii = Q_aGii # This is used in g0w0 908 909 if PWSA is None: 910 with self.timer('Symmetry analyzer'): 911 PWSA = PWSymmetryAnalyzer # Line too long otherwise 912 PWSA = PWSA(self.calc.wfs.kd, pd, 913 timer=self.timer, txt=self.fd) 914 915 pb = ProgressBar(self.fd) 916 for kn, (s, ik, n1, n2) in pb.enumerate(self.mysKn1n2): 917 Kstar_k = PWSA.unfold_ibz_kpoint(ik) 918 for K_k in PWSA.group_kpoints(Kstar_k): 919 # Let the first kpoint of the group represent 920 # the rest of the kpoints 921 K1 = K_k[0] 922 # In this way wavefunctions are only loaded into 923 # memory for this particular set of kpoints 924 kptpair = self.get_kpoint_pair(pd, s, K1, n1, n2, m1, m2) 925 kpt1 = kptpair.get_k1() # kpt1 = k 926 927 if kpt1.s not in spins: 928 continue 929 kpt2 = kptpair.get_k2() # kpt2 = k + q 930 931 if unsymmetrized: 932 # Number of times kpoints are mapped into themselves 933 weight = np.sqrt(PWSA.how_many_symmetries() / len(K_k)) 934 935 # Use kpt2 to compute intraband transitions 936 # These conditions are sufficient to make sure 937 # that it still works in parallel 938 if kpt1.n1 == 0 and self.blockcomm.rank == 0 and \ 939 optical_limit and intraband: 940 assert self.nocc2 <= kpt2.nb, \ 941 print('Error: Too few unoccupied bands') 942 vel0_mv = self.intraband_pair_density(kpt2) 943 f_m = kpt2.f_n[kpt2.na - kpt2.n1:kpt2.nb - kpt2.n1] 944 with self.timer('intraband'): 945 if vel0_mv is not None: 946 if unsymmetrized: 947 yield (f_m, None, None, 948 None, None, vel0_mv / weight) 949 else: 950 for K2 in K_k: 951 vel_mv = PWSA.map_v(K1, K2, vel0_mv) 952 yield (f_m, None, None, 953 None, None, vel_mv) 954 955 # Divide the occupied bands into chunks 956 n_n = np.arange(n2 - n1) 957 if use_more_memory == 0: 958 chunksize = 1 959 else: 960 chunksize = np.ceil(len(n_n) * 961 use_more_memory).astype(int) 962 963 no_n = [] 964 for i in range(len(n_n) // chunksize): 965 i1 = i * chunksize 966 i2 = min((i + 1) * chunksize, len(n_n)) 967 no_n.append(n_n[i1:i2]) 968 969 # n runs over occupied bands 970 for n_n in no_n: # n_n is a list of occupied band indices 971 # m over unoccupied bands 972 m_m = np.arange(0, kpt2.n2 - kpt2.n1) 973 deps_nm = kptpair.get_transition_energies(n_n, m_m) 974 df_nm = kptpair.get_occupation_differences(n_n, m_m) 975 976 # This is not quite right for 977 # degenerate partially occupied 978 # bands, but good enough for now: 979 df_nm[df_nm <= 1e-20] = 0.0 980 981 # Get pair density for representative kpoint 982 ol = optical_limit 983 n0_nmG, n0_nmv, _ = self.get_pair_density(pd, kptpair, 984 n_n, m_m, 985 optical_limit=ol, 986 intraband=False, 987 Q_aGii=Q_aGii) 988 989 n0_nmG[deps_nm >= 0.0] = 0.0 990 if optical_limit: 991 n0_nmv[deps_nm >= 0.0] = 0.0 992 993 # Reshape nm -> m 994 nG = pd.ngmax 995 deps_m = deps_nm.reshape(-1) 996 df_m = df_nm.reshape(-1) 997 n0_mG = n0_nmG.reshape((-1, nG)) 998 if optical_limit: 999 n0_mv = n0_nmv.reshape((-1, 3)) 1000 1001 if unsymmetrized: 1002 if optical_limit: 1003 yield (None, df_m, deps_m, 1004 n0_mG / weight, n0_mv / weight, None) 1005 else: 1006 yield (None, df_m, deps_m, 1007 n0_mG / weight, None, None) 1008 continue 1009 1010 # Collect pair densities in a single array 1011 # and return them 1012 nm = n0_mG.shape[0] 1013 nG = n0_mG.shape[1] 1014 nk = len(K_k) 1015 1016 n_MG = np.empty((nm * nk, nG), complex) 1017 if optical_limit: 1018 n_Mv = np.empty((nm * nk, 3), complex) 1019 deps_M = np.tile(deps_m, nk) 1020 df_M = np.tile(df_m, nk) 1021 1022 for i, K2 in enumerate(K_k): 1023 i1 = i * nm 1024 i2 = (i + 1) * nm 1025 n_mG = PWSA.map_G(K1, K2, n0_mG) 1026 1027 if optical_limit: 1028 n_mv = PWSA.map_v(K1, K2, n0_mv) 1029 n_mG[:, 0] = n_mv[:, 0] 1030 n_Mv[i1:i2, :] = n_mv 1031 1032 n_MG[i1:i2, :] = n_mG 1033 1034 if optical_limit: 1035 yield (None, df_M, deps_M, n_MG, n_Mv, None) 1036 else: 1037 yield (None, df_M, deps_M, n_MG, None, None) 1038 1039 pb.finish() 1040 1041 @timer('Get kpoint pair') 1042 def get_kpoint_pair(self, pd, s, Kork_c, n1, n2, m1, m2, 1043 load_wfs=True, block=False): 1044 # wfs = self.calc.wfs 1045 # bzk_kc = wfs.kd.bzk_kc 1046 1047 if isinstance(Kork_c, int): 1048 # If k_c is an integer then it refers to 1049 # the index of the kpoint in the BZ 1050 k_c = self.calc.wfs.kd.bzk_kc[Kork_c] 1051 else: 1052 k_c = Kork_c 1053 1054 q_c = pd.kd.bzk_kc[0] 1055 with self.timer('get k-points'): 1056 kpt1 = self.get_k_point(s, k_c, n1, n2, load_wfs=load_wfs) 1057 # K2 = wfs.kd.find_k_plus_q(q_c, [kpt1.K])[0] 1058 if self.response in ['+-', '-+']: 1059 s2 = 1 - s 1060 else: 1061 s2 = s 1062 kpt2 = self.get_k_point(s2, k_c + q_c, m1, m2, 1063 load_wfs=load_wfs, block=block) 1064 1065 with self.timer('fft indices'): 1066 Q_G = self.get_fft_indices(kpt1.K, kpt2.K, q_c, pd, 1067 kpt1.shift_c - kpt2.shift_c) 1068 1069 return KPointPair(kpt1, kpt2, Q_G) 1070 1071 @timer('get_pair_density') 1072 def get_pair_density(self, pd, kptpair, n_n, m_m, 1073 optical_limit=False, intraband=False, 1074 Q_aGii=None, block=False, direction=2, 1075 extend_head=True): 1076 """Get pair density for a kpoint pair.""" 1077 ol = optical_limit = np.allclose(pd.kd.bzk_kc[0], 0.0) and \ 1078 self.response == 'density' 1079 eh = extend_head 1080 cpd = self.calculate_pair_densities # General pair densities 1081 opd = self.optical_pair_density # Interband pair densities / q 1082 1083 if Q_aGii is None: 1084 Q_aGii = self.initialize_paw_corrections(pd) 1085 1086 kpt1 = kptpair.kpt1 1087 kpt2 = kptpair.kpt2 1088 Q_G = kptpair.Q_G # Fourier components of kpoint pair 1089 nG = len(Q_G) 1090 1091 if extend_head: 1092 n_nmG = np.zeros((len(n_n), len(m_m), nG + 2 * ol), pd.dtype) 1093 else: 1094 n_nmG = np.zeros((len(n_n), len(m_m), nG), pd.dtype) 1095 1096 for j, n in enumerate(n_n): 1097 Q_G = kptpair.Q_G 1098 with self.timer('conj'): 1099 ut1cc_R = kpt1.ut_nR[n - kpt1.na].conj() 1100 with self.timer('paw'): 1101 C1_aGi = [np.dot(Q_Gii, P1_ni[n - kpt1.na].conj()) 1102 for Q_Gii, P1_ni in zip(Q_aGii, kpt1.P_ani)] 1103 n_nmG[j, :, 2 * ol * eh:] = cpd(ut1cc_R, C1_aGi, kpt2, pd, Q_G, 1104 block=block) 1105 if optical_limit: 1106 if extend_head: 1107 n_nmG[j, :, 0:3] = opd(n, m_m, kpt1, kpt2, 1108 block=block) 1109 else: 1110 n_nmG[j, :, 0] = opd(n, m_m, kpt1, kpt2, 1111 block=block)[:, direction] 1112 return n_nmG 1113 1114 @timer('get_pair_momentum') 1115 def get_pair_momentum(self, pd, kptpair, n_n, m_m, Q_avGii=None): 1116 r"""Calculate matrix elements of the momentum operator. 1117 1118 Calculates:: 1119 1120 n_{nm\mathrm{k}}\int_{\Omega_{\mathrm{cell}}}\mathrm{d}\mathbf{r} 1121 \psi_{n\mathrm{k}}^*(\mathbf{r}) 1122 e^{-i\,(\mathrm{q} + \mathrm{G})\cdot\mathbf{r}} 1123 \nabla\psi_{m\mathrm{k} + \mathrm{q}}(\mathbf{r}) 1124 1125 pd: PlaneWaveDescriptor 1126 Plane wave descriptor of a single q_c. 1127 kptpair: KPointPair 1128 KpointPair object containing the two kpoints. 1129 n_n: list 1130 List of left-band indices (n). 1131 m_m: 1132 List of right-band indices (m). 1133 """ 1134 wfs = self.calc.wfs 1135 1136 kpt1 = kptpair.kpt1 1137 kpt2 = kptpair.kpt2 1138 Q_G = kptpair.Q_G # Fourier components of kpoint pair 1139 1140 # For the same band we 1141 kd = wfs.kd 1142 gd = wfs.gd 1143 k_c = kd.bzk_kc[kpt1.K] + kpt1.shift_c 1144 k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T) 1145 1146 # Calculate k + G 1147 G_Gv = pd.get_reciprocal_vectors(add_q=True) 1148 kqG_Gv = k_v[np.newaxis] + G_Gv 1149 1150 # Pair velocities 1151 n_nmvG = pd.zeros((len(n_n), len(m_m), 3)) 1152 1153 # Calculate derivatives of left-wavefunction 1154 # (there will typically be fewer of these) 1155 ut_nvR = self.make_derivative(kpt1.s, kpt1.K, kpt1.n1, kpt1.n2) 1156 1157 # PAW-corrections 1158 if Q_avGii is None: 1159 Q_avGii = self.initialize_paw_nabla_corrections(pd) 1160 1161 # Iterate over occupied bands 1162 for j, n in enumerate(n_n): 1163 ut1cc_R = kpt1.ut_nR[n].conj() 1164 1165 n_mG = self.calculate_pair_densities(ut1cc_R, 1166 [], kpt2, 1167 pd, Q_G) 1168 1169 n_nmvG[j] = 1j * kqG_Gv.T[np.newaxis] * n_mG[:, np.newaxis] 1170 1171 # Treat each cartesian component at a time 1172 for v in range(3): 1173 # Minus from integration by parts 1174 utvcc_R = -ut_nvR[n, v].conj() 1175 Cv1_aGi = [np.dot(P1_ni[n].conj(), Q_vGii[v]) 1176 for Q_vGii, P1_ni in zip(Q_avGii, kpt1.P_ani)] 1177 1178 nv_mG = self.calculate_pair_densities(utvcc_R, 1179 Cv1_aGi, kpt2, 1180 pd, Q_G) 1181 1182 n_nmvG[j, :, v] += nv_mG 1183 1184 # We want the momentum operator 1185 n_nmvG *= -1j 1186 1187 return n_nmvG 1188 1189 @timer('Calculate pair-densities') 1190 def calculate_pair_densities(self, ut1cc_R, C1_aGi, kpt2, pd, Q_G, 1191 block=True): 1192 """Calculate FFT of pair-densities and add PAW corrections. 1193 1194 ut1cc_R: 3-d complex ndarray 1195 Complex conjugate of the periodic part of the left hand side 1196 wave function. 1197 C1_aGi: list of ndarrays 1198 PAW corrections for all atoms. 1199 kpt2: KPoint object 1200 Right hand side k-point object. 1201 pd: PWDescriptor 1202 Plane-wave descriptor for for q=k2-k1. 1203 Q_G: 1-d int ndarray 1204 Mapping from flattened 3-d FFT grid to 0.5(G+q)^2<ecut sphere. 1205 """ 1206 1207 dv = pd.gd.dv 1208 n_mG = pd.empty(kpt2.blocksize) 1209 myblocksize = kpt2.nb - kpt2.na 1210 1211 for ut_R, n_G in zip(kpt2.ut_nR, n_mG): 1212 n_R = ut1cc_R * ut_R 1213 with self.timer('fft'): 1214 n_G[:] = pd.fft(n_R, 0, Q_G) * dv 1215 # PAW corrections: 1216 with self.timer('gemm'): 1217 for C1_Gi, P2_mi in zip(C1_aGi, kpt2.P_ani): 1218 gemm(1.0, C1_Gi, P2_mi, 1.0, n_mG[:myblocksize], 't') 1219 1220 if not block or self.blockcomm.size == 1: 1221 return n_mG 1222 else: 1223 n_MG = pd.empty(kpt2.blocksize * self.blockcomm.size) 1224 self.blockcomm.all_gather(n_mG, n_MG) 1225 return n_MG[:kpt2.n2 - kpt2.n1] 1226 1227 @timer('Optical limit') 1228 def optical_pair_velocity(self, n, m_m, kpt1, kpt2, block=False): 1229 if self.ut_sKnvR is None or kpt1.K not in self.ut_sKnvR[kpt1.s]: 1230 self.ut_sKnvR = self.calculate_derivatives(kpt1) 1231 1232 kd = self.calc.wfs.kd 1233 gd = self.calc.wfs.gd 1234 k_c = kd.bzk_kc[kpt1.K] + kpt1.shift_c 1235 k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T) 1236 1237 ut_vR = self.ut_sKnvR[kpt1.s][kpt1.K][n - kpt1.n1] 1238 atomdata_a = self.calc.wfs.setups 1239 if self.paw_correction == 'brute-force': 1240 C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[n - kpt1.na]) 1241 for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)] 1242 elif self.paw_correction == 'skip': 1243 C_avi = [np.zeros((3, P_ni.shape[1]), complex) 1244 for atomdata, P_ni in zip(atomdata_a, kpt1.P_ani)] 1245 else: 1246 1 / 0 1247 1248 blockbands = kpt2.nb - kpt2.na 1249 n0_mv = np.empty((kpt2.blocksize, 3), dtype=complex) 1250 nt_m = np.empty(kpt2.blocksize, dtype=complex) 1251 n0_mv[:blockbands] = -self.calc.wfs.gd.integrate(ut_vR, 1252 kpt2.ut_nR).T 1253 nt_m[:blockbands] = self.calc.wfs.gd.integrate(kpt1.ut_nR[n - kpt1.na], 1254 kpt2.ut_nR) 1255 1256 n0_mv[:blockbands] += (1j * nt_m[:blockbands, np.newaxis] * 1257 k_v[np.newaxis, :]) 1258 1259 for C_vi, P_mi in zip(C_avi, kpt2.P_ani): 1260 gemm(1.0, C_vi, P_mi, 1.0, n0_mv[:blockbands], 'c') 1261 1262 if block and self.blockcomm.size > 1: 1263 n0_Mv = np.empty((kpt2.blocksize * self.blockcomm.size, 3), 1264 dtype=complex) 1265 self.blockcomm.all_gather(n0_mv, n0_Mv) 1266 n0_mv = n0_Mv[:kpt2.n2 - kpt2.n1] 1267 1268 return -1j * n0_mv 1269 1270 def optical_pair_density(self, n, m_m, kpt1, kpt2, 1271 block=False): 1272 # Relative threshold for perturbation theory 1273 threshold = self.threshold 1274 1275 eps1 = kpt1.eps_n[n - kpt1.n1] 1276 deps_m = (eps1 - kpt2.eps_n)[m_m - kpt2.n1] 1277 n0_mv = self.optical_pair_velocity(n, m_m, kpt1, kpt2, 1278 block=block) 1279 1280 deps_m = deps_m.copy() 1281 deps_m[deps_m == 0.0] = np.inf 1282 1283 smallness_mv = np.abs(-1e-3 * n0_mv / deps_m[:, np.newaxis]) 1284 inds_mv = (np.logical_and(np.inf > smallness_mv, 1285 smallness_mv > threshold)) 1286 n0_mv *= - 1 / deps_m[:, np.newaxis] 1287 n0_mv[inds_mv] = 0 1288 1289 return n0_mv 1290 1291 @timer('Intraband') 1292 def intraband_pair_density(self, kpt, n_n=None, 1293 only_partially_occupied=False): 1294 """Calculate intraband matrix elements of nabla""" 1295 # Bands and check for block parallelization 1296 na, nb, n1 = kpt.na, kpt.nb, kpt.n1 1297 vel_nv = np.zeros((nb - na, 3), dtype=complex) 1298 if n_n is None: 1299 n_n = np.arange(na, nb) 1300 assert np.max(n_n) < nb, 'This is too many bands' 1301 assert np.min(n_n) >= na, 'This is too few bands' 1302 1303 # Load kpoints 1304 kd = self.calc.wfs.kd 1305 gd = self.calc.wfs.gd 1306 k_c = kd.bzk_kc[kpt.K] + kpt.shift_c 1307 k_v = 2 * np.pi * np.dot(k_c, np.linalg.inv(gd.cell_cv).T) 1308 atomdata_a = self.calc.wfs.setups 1309 f_n = kpt.f_n 1310 1311 # Only works with Fermi-Dirac distribution 1312 assert self.calc.wfs.occupations.name in {'fermi-dirac', 'zero-width'} 1313 1314 # No carriers when T=0 1315 width = getattr(self.calc.wfs.occupations, '_width', 0.0) / Ha 1316 1317 if width > 1e-15: 1318 dfde_n = -1 / width * (f_n - f_n**2.0) # Analytical derivative 1319 partocc_n = np.abs(dfde_n) > 1e-5 # Is part. occupied? 1320 else: 1321 # Just include all bands to be sure 1322 partocc_n = np.ones(len(f_n), dtype=bool) 1323 1324 if only_partially_occupied and not partocc_n.any(): 1325 return None 1326 1327 if only_partially_occupied: 1328 # Check for block par. consistency 1329 assert (partocc_n < nb).all(), \ 1330 print('Include more unoccupied bands ', + 1331 'or less block parr.', file=self.fd) 1332 1333 # Break bands into degenerate chunks 1334 degchunks_cn = [] # indexing c as chunk number 1335 for n in n_n: 1336 inds_n = np.nonzero(np.abs(kpt.eps_n[n - n1] - 1337 kpt.eps_n) < 1e-5)[0] + n1 1338 1339 # Has this chunk already been computed? 1340 oldchunk = any([n in chunk for chunk in degchunks_cn]) 1341 if not oldchunk and \ 1342 (partocc_n[n - n1] or not only_partially_occupied): 1343 assert all([ind in n_n for ind in inds_n]), \ 1344 print('\nYou are cutting over a degenerate band ' + 1345 'using block parallelization.', 1346 inds_n, n_n, file=self.fd) 1347 degchunks_cn.append((inds_n)) 1348 1349 # Calculate matrix elements by diagonalizing each block 1350 for ind_n in degchunks_cn: 1351 deg = len(ind_n) 1352 ut_nvR = self.calc.wfs.gd.zeros((deg, 3), complex) 1353 vel_nnv = np.zeros((deg, deg, 3), dtype=complex) 1354 # States are included starting from kpt.na 1355 ut_nR = kpt.ut_nR[ind_n - na] 1356 1357 # Get derivatives 1358 for ind, ut_vR in zip(ind_n, ut_nvR): 1359 ut_vR[:] = self.make_derivative(kpt.s, kpt.K, 1360 ind, ind + 1)[0] 1361 1362 # Treat the whole degenerate chunk 1363 for n in range(deg): 1364 ut_vR = ut_nvR[n] 1365 C_avi = [np.dot(atomdata.nabla_iiv.T, P_ni[ind_n[n] - na]) 1366 for atomdata, P_ni in zip(atomdata_a, kpt.P_ani)] 1367 1368 nabla0_nv = -self.calc.wfs.gd.integrate(ut_vR, ut_nR).T 1369 nt_n = self.calc.wfs.gd.integrate(ut_nR[n], ut_nR) 1370 nabla0_nv += 1j * nt_n[:, np.newaxis] * k_v[np.newaxis, :] 1371 1372 for C_vi, P_ni in zip(C_avi, kpt.P_ani): 1373 gemm(1.0, C_vi, P_ni[ind_n - na], 1.0, nabla0_nv, 'c') 1374 1375 vel_nnv[n] = -1j * nabla0_nv 1376 1377 for iv in range(3): 1378 vel, _ = np.linalg.eig(vel_nnv[..., iv]) 1379 vel_nv[ind_n - na, iv] = vel # Use eigenvalues 1380 1381 return vel_nv[n_n - na] 1382 1383 def get_fft_indices(self, K1, K2, q_c, pd, shift0_c): 1384 """Get indices for G-vectors inside cutoff sphere.""" 1385 kd = self.calc.wfs.kd 1386 N_G = pd.Q_qG[0] 1387 shift_c = (shift0_c + 1388 (q_c - kd.bzk_kc[K2] + kd.bzk_kc[K1]).round().astype(int)) 1389 if shift_c.any(): 1390 n_cG = np.unravel_index(N_G, pd.gd.N_c) 1391 n_cG = [n_G + shift for n_G, shift in zip(n_cG, shift_c)] 1392 N_G = np.ravel_multi_index(n_cG, pd.gd.N_c, 'wrap') 1393 return N_G 1394 1395 def construct_symmetry_operators(self, K, k_c=None): 1396 """Construct symmetry operators for wave function and PAW projections. 1397 1398 We want to transform a k-point in the irreducible part of the BZ to 1399 the corresponding k-point with index K. 1400 1401 Returns U_cc, T, a_a, U_aii, shift_c and time_reversal, where: 1402 1403 * U_cc is a rotation matrix. 1404 * T() is a function that transforms the periodic part of the wave 1405 function. 1406 * a_a is a list of symmetry related atom indices 1407 * U_aii is a list of rotation matrices for the PAW projections 1408 * shift_c is three integers: see code below. 1409 * time_reversal is a flag - if True, projections should be complex 1410 conjugated. 1411 1412 See the get_k_point() method for how to use these tuples. 1413 """ 1414 1415 wfs = self.calc.wfs 1416 kd = wfs.kd 1417 1418 s = kd.sym_k[K] 1419 U_cc = kd.symmetry.op_scc[s] 1420 time_reversal = kd.time_reversal_k[K] 1421 ik = kd.bz2ibz_k[K] 1422 if k_c is None: 1423 k_c = kd.bzk_kc[K] 1424 ik_c = kd.ibzk_kc[ik] 1425 1426 sign = 1 - 2 * time_reversal 1427 shift_c = np.dot(U_cc, ik_c) - k_c * sign 1428 1429 try: 1430 assert np.allclose(shift_c.round(), shift_c) 1431 except AssertionError: 1432 print('shift_c ' + str(shift_c), file=self.fd) 1433 print('k_c ' + str(k_c), file=self.fd) 1434 print('kd.bzk_kc[K] ' + str(kd.bzk_kc[K]), file=self.fd) 1435 print('ik_c ' + str(ik_c), file=self.fd) 1436 print('U_cc ' + str(U_cc), file=self.fd) 1437 print('sign ' + str(sign), file=self.fd) 1438 raise AssertionError 1439 1440 shift_c = shift_c.round().astype(int) 1441 1442 if (U_cc == np.eye(3)).all(): 1443 def T(f_R): 1444 return f_R 1445 else: 1446 N_c = self.calc.wfs.gd.N_c 1447 i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1))) 1448 i = np.ravel_multi_index(i_cr, N_c, 'wrap') 1449 1450 def T(f_R): 1451 return f_R.ravel()[i].reshape(N_c) 1452 1453 if time_reversal: 1454 T0 = T 1455 1456 def T(f_R): 1457 return T0(f_R).conj() 1458 shift_c *= -1 1459 1460 a_a = [] 1461 U_aii = [] 1462 for a, id in enumerate(self.calc.wfs.setups.id_a): 1463 b = kd.symmetry.a_sa[s, a] 1464 S_c = np.dot(self.spos_ac[a], U_cc) - self.spos_ac[b] 1465 x = np.exp(2j * pi * np.dot(ik_c, S_c)) 1466 U_ii = wfs.setups[a].R_sii[s].T * x 1467 a_a.append(b) 1468 U_aii.append(U_ii) 1469 1470 return U_cc, T, a_a, U_aii, shift_c, time_reversal 1471 1472 @timer('Initialize PAW corrections') 1473 def initialize_paw_corrections(self, pd, soft=False): 1474 wfs = self.calc.wfs 1475 q_v = pd.K_qv[0] 1476 optical_limit = np.allclose(q_v, 0) and self.response == 'density' 1477 1478 G_Gv = pd.get_reciprocal_vectors() 1479 if optical_limit: 1480 G_Gv[0] = 1 1481 1482 pos_av = np.dot(self.spos_ac, pd.gd.cell_cv) 1483 1484 # Collect integrals for all species: 1485 Q_xGii = {} 1486 for id, atomdata in wfs.setups.setups.items(): 1487 if soft: 1488 ghat = PWLFC([atomdata.ghat_l], pd) 1489 ghat.set_positions(np.zeros((1, 3))) 1490 Q_LG = ghat.expand().T 1491 if atomdata.Delta_iiL is None: 1492 ni = atomdata.ni 1493 Q_Gii = np.zeros((Q_LG.shape[1], ni, ni)) 1494 else: 1495 Q_Gii = np.dot(atomdata.Delta_iiL, Q_LG).T 1496 else: 1497 ni = atomdata.ni 1498 if self.paw_correction == 'brute-force': 1499 Q_Gii = two_phi_planewave_integrals(G_Gv, atomdata) 1500 Q_Gii.shape = (-1, ni, ni) 1501 elif self.paw_correction == 'skip': 1502 Q_Gii = np.zeros((len(G_Gv), ni, ni), complex) 1503 else: 1504 1 / 0 1505 1506 Q_xGii[id] = Q_Gii 1507 1508 Q_aGii = [] 1509 for a, atomdata in enumerate(wfs.setups): 1510 id = wfs.setups.id_a[a] 1511 Q_Gii = Q_xGii[id] 1512 x_G = np.exp(-1j * np.dot(G_Gv, pos_av[a])) 1513 Q_aGii.append(x_G[:, np.newaxis, np.newaxis] * Q_Gii) 1514 if optical_limit: 1515 Q_aGii[a][0] = atomdata.dO_ii 1516 1517 return Q_aGii 1518 1519 @timer('Initialize PAW corrections') 1520 def initialize_paw_nabla_corrections(self, pd, soft=False): 1521 print('Initializing nabla PAW Corrections', file=self.fd) 1522 wfs = self.calc.wfs 1523 G_Gv = pd.get_reciprocal_vectors() 1524 pos_av = np.dot(self.spos_ac, pd.gd.cell_cv) 1525 1526 # Collect integrals for all species: 1527 Q_xvGii = {} 1528 for id, atomdata in wfs.setups.setups.items(): 1529 if soft: 1530 raise NotImplementedError 1531 else: 1532 Q_vGii = two_phi_nabla_planewave_integrals(G_Gv, atomdata) 1533 ni = atomdata.ni 1534 Q_vGii.shape = (3, -1, ni, ni) 1535 1536 Q_xvGii[id] = Q_vGii 1537 1538 Q_avGii = [] 1539 for a, atomdata in enumerate(wfs.setups): 1540 id = wfs.setups.id_a[a] 1541 Q_vGii = Q_xvGii[id] 1542 x_G = np.exp(-1j * np.dot(G_Gv, pos_av[a])) 1543 Q_avGii.append(x_G[np.newaxis, :, np.newaxis, np.newaxis] * Q_vGii) 1544 1545 return Q_avGii 1546 1547 def calculate_derivatives(self, kpt): 1548 ut_sKnvR = [{}, {}] 1549 ut_nvR = self.make_derivative(kpt.s, kpt.K, kpt.n1, kpt.n2) 1550 ut_sKnvR[kpt.s][kpt.K] = ut_nvR 1551 1552 return ut_sKnvR 1553 1554 @timer('Derivatives') 1555 def make_derivative(self, s, K, n1, n2): 1556 wfs = self.calc.wfs 1557 if self.real_space_derivatives: 1558 grad_v = [Gradient(wfs.gd, v, 1.0, 4, complex).apply 1559 for v in range(3)] 1560 1561 U_cc, T, a_a, U_aii, shift_c, time_reversal = \ 1562 self.construct_symmetry_operators(K) 1563 A_cv = wfs.gd.cell_cv 1564 M_vv = np.dot(np.dot(A_cv.T, U_cc.T), np.linalg.inv(A_cv).T) 1565 ik = wfs.kd.bz2ibz_k[K] 1566 assert wfs.kd.comm.size == 1 1567 kpt = wfs.kpt_qs[ik][s] 1568 psit_nG = kpt.psit_nG 1569 iG_Gv = 1j * wfs.pd.get_reciprocal_vectors(q=ik, add_q=False) 1570 ut_nvR = wfs.gd.zeros((n2 - n1, 3), complex) 1571 for n in range(n1, n2): 1572 for v in range(3): 1573 if self.real_space_derivatives: 1574 ut_R = T(wfs.pd.ifft(psit_nG[n], ik)) 1575 grad_v[v](ut_R, ut_nvR[n - n1, v], 1576 np.ones((3, 2), complex)) 1577 else: 1578 ut_R = T(wfs.pd.ifft(iG_Gv[:, v] * psit_nG[n], ik)) 1579 for v2 in range(3): 1580 ut_nvR[n - n1, v2] += ut_R * M_vv[v, v2] 1581 1582 return ut_nvR 1583