# encoding: utf-8
"""Two-center integrals between atom-centered functions.

Provides overlap and kinetic matrix elements between basis functions,
and overlaps between projector functions and basis functions, for pairs
of atoms (optionally including periodic images and k-point phases).
"""
import numpy as np
import scipy.sparse as sparse
from ase.neighborlist import PrimitiveNeighborList
# from ase.utils.timing import timer
from gpaw.utilities.tools import tri2full

# from gpaw import debug
from gpaw.lcao.overlap import (FourierTransformer, TwoSiteOverlapCalculator,
                               ManySiteOverlapCalculator,
                               AtomicDisplacement, NullPhases, BlochPhases,
                               DerivativeAtomicDisplacement)


def get_cutoffs(f_Ij):
    """Return the largest cutoff of each species' functions.

    Parameters:
        f_Ij: one sequence of function objects per species I; each object
            is queried through ``get_cutoff()``.

    Returns:
        List with one float per species.  The value never drops below
        0.001 (a 'paranoid zero'), so species without functions still get
        a small positive radius.
    """
    return [max([0.001] + [f.get_cutoff() for f in f_j]) for f_j in f_Ij]


def get_lvalues(f_Ij):
    """Return the angular momentum numbers of each species' functions.

    The result is nested like the input: one list per species, one entry
    per function, taken from ``get_angular_momentum_number()``.
    """
    return [[f.get_angular_momentum_number() for f in f_j] for f_j in f_Ij]


class AtomPairRegistry:
    """Registry of atom pairs that lie within each other's cutoff range.

    For every ordered pair ``(a1, a2)`` found by the neighbor list
    (including an atom paired with itself), a list of ``(R_c, offset)``
    tuples is stored.  ``offset`` is the lattice-image offset of ``a2`` and
    ``R_c`` is the Cartesian vector from ``a1`` to that image of ``a2``
    (scaled-position difference multiplied by the cell).
    """

    def __init__(self, cutoff_a, pbc_c, cell_cv, spos_ac):
        nl = PrimitiveNeighborList(cutoff_a, skin=0, sorted=True,
                                   self_interaction=True,
                                   use_scaled_positions=True)

        nl.update(pbc=pbc_c, cell=cell_cv, coordinates=spos_ac)
        r_and_offset_aao = {}

        def add(a1, a2, R_c, offset):
            r_and_offset_aao.setdefault((a1, a2), []).append((R_c, offset))

        for a1, spos1_c in enumerate(spos_ac):
            a2_a, offsets = nl.get_neighbors(a1)
            for a2, offset in zip(a2_a, offsets):
                spos2_c = spos_ac[a2] + offset

                R_c = np.dot(spos2_c - spos1_c, cell_cv)
                add(a1, a2, R_c, offset)
                # 'bothways' is not requested from the neighbor list, so
                # each pair is reported only once; register the reverse
                # direction explicitly.  The on-site self pair (a, a, 0)
                # is its own reverse and must not be stored twice.
                if a1 != a2 or offset.any():
                    add(a2, a1, -R_c, -offset)
        self.r_and_offset_aao = r_and_offset_aao

    def get(self, a1, a2):
        """Return the ``(R_c, offset)`` list for ``(a1, a2)``, or None."""
        return self.r_and_offset_aao.get((a1, a2))

    def get_atompairs(self):
        """Return all registered ``(a1, a2)`` pairs in sorted order."""
        return sorted(self.r_and_offset_aao)


class TCIExpansions:
    """Precomputed two-center expansions for every pair of species.

    Holds the overlap (``O_expansions``), kinetic (``T_expansions``) and
    projector/basis (``P_expansions``) expansions, together with the
    species index of each atom (``I_a``) and per-species cutoffs.
    """

    def __init__(self, phit_Ij, pt_Ij, I_a):
        assert len(pt_Ij) == len(phit_Ij)

        # Cutoffs by species:
        pt_rcmax_I = get_cutoffs(pt_Ij)
        phit_rcmax_I = get_cutoffs(phit_Ij)
        rcmax_I = [max(rc1, rc2) for rc1, rc2
                   in zip(pt_rcmax_I, phit_rcmax_I)]

        # The extra 1e-3 keeps max() well defined when there are no species.
        transformer = FourierTransformer(rcmax=max(rcmax_I + [1e-3]),
                                         ng=2**10)
        tsoc = TwoSiteOverlapCalculator(transformer)
        msoc = ManySiteOverlapCalculator(tsoc, I_a, I_a)
        phit_Ijq = msoc.transform(phit_Ij)
        pt_Ijq = msoc.transform(pt_Ij)
        pt_l_Ij = get_lvalues(pt_Ij)
        phit_l_Ij = get_lvalues(phit_Ij)
        self.O_expansions = msoc.calculate_expansions(phit_l_Ij, phit_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.T_expansions = msoc.calculate_kinetic_expansions(phit_l_Ij,
                                                              phit_Ijq)
        self.P_expansions = msoc.calculate_expansions(pt_l_Ij, pt_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.I_a = I_a  # Actually I_a belongs outside, like spos_ac.
        self.rcmax_I = rcmax_I
        self.phit_rcmax_I = phit_rcmax_I
        self.pt_rcmax_I = pt_rcmax_I

    @classmethod
    def new_from_setups(cls, setups):
        """Build expansions from a setups collection.

        Each distinct setup in ``setups.setups`` becomes one species I.
        Every atom is mapped to the species of its setup.
        """
        setups_I = list(setups.setups.values())
        I_setup = {setup: I for I, setup in enumerate(setups_I)}
        I_a = [I_setup[setup] for setup in setups]

        # Use cls so that subclasses get instances of their own type.
        return cls([s.phit_j for s in setups_I],
                   [s.pt_j for s in setups_I],
                   I_a)

    def get_tci_calculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype):
        """Return a TCICalculator for the given geometry and k-points."""
        return TCICalculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype)

    def get_manytci_calculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                               timer):
        """Return a ManyTCICalculator assembling full-system matrices."""
        return ManyTCICalculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                                 timer)


class TCICalculator:
    """High-level two-center integral calculator.

    This object is not aware of parallelization.  It works with any
    pair of atoms a1, a2.

    Create the object and calculate any interatomic overlap matrix as below.

        tci = tciexpansions.get_tci_calculator(cell_cv, spos_ac, pbc_c,
                                               ibzk_qc, dtype)

    Projector/basis overlap <pt_i^a1|phi_mu> between atoms a1, a2:

        P_qim = tci.P(a1, a2)

    Derivatives of the above with respect to movement of a2:

        dPdR_qvim = tci.dPdR(a1, a2)

    Basis/basis overlap and kinetic matrix elements between atoms a1, a2:

        O_qmm, T_qmm = tci.O_T(a1, a2)

    Derivative of the above wrt. position of a2:

        dOdR_qvmm, dTdR_qvmm = tci.dOdR_dTdR(a1, a2)

    ``P``/``dPdR`` return None, and ``O_T``/``dOdR_dTdR`` return
    ``(None, None)``, when the two atoms do not overlap.
    """

    def __init__(self, tciexpansions, cell_cv, spos_ac, pbc_c, ibzk_qc,
                 dtype):

        self.tciexpansions = tciexpansions
        self.dtype = dtype

        # XXX It is somewhat nasty that rcmax depends on how long our
        # longest orbital happens to be
        # Cutoffs by atom:
        I_a = tciexpansions.I_a
        cutoff_a = [tciexpansions.rcmax_I[I] for I in I_a]
        self.pt_rcmax_a = np.array([tciexpansions.pt_rcmax_I[I] for I in I_a])
        self.phit_rcmax_a = np.array([tciexpansions.phit_rcmax_I[I]
                                      for I in I_a])

        self.a1a2 = AtomPairRegistry(cutoff_a, pbc_c, cell_cv, spos_ac)

        self.ibzk_qc = ibzk_qc
        # When every k-point is zero no Bloch phases are needed.
        if ibzk_qc.any():
            self.get_phases = BlochPhases
        else:
            self.get_phases = NullPhases

        self.O_T = self._tci_shortcut(False, False)
        self.P = self._tci_shortcut(True, False)
        self.dOdR_dTdR = self._tci_shortcut(False, True)
        self.dPdR = self._tci_shortcut(True, True)

    def _tci_shortcut(self, P, derivative):
        """Return a two-argument function calling _calculate with fixed flags."""
        def calculate(a1, a2):
            return self._calculate(a1, a2, P, derivative)
        return calculate

    def _calculate(self, a1, a2, P=False, derivative=False):
        """Calculate overlap of functions between atoms a1 and a2.

        Parameters:
            a1, a2: atom indices.
            P: if true, compute projector/basis overlaps; otherwise compute
                basis/basis overlap and kinetic matrix elements.
            derivative: if true, compute derivatives; the leading shape of
                the arrays becomes ``(nq, 3)`` instead of ``(nq,)``.

        Returns:
            The projector array (or None) when ``P`` is true; otherwise the
            tuple ``(O, T)`` (or ``(None, None)``) when the atoms do not
            overlap.
        """
        no_overlap = None if P else (None, None)

        # We want to see quickly if there is no overlap because distance
        # outside bounding spheres.
        R_c_and_offset_a = self.a1a2.get(a1, a2)
        if R_c_and_offset_a is None:
            return no_overlap

        # The registry used the largest cutoff of each species; tighten the
        # test with the cutoffs of the functions actually involved here.
        rcut1 = self.pt_rcmax_a[a1] if P else self.phit_rcmax_a[a1]
        rcut2 = self.phit_rcmax_a[a2]
        maxdist = rcut1 + rcut2

        # Filter out displacements larger than maxdist:
        R_c_and_offset_a = [(R_c, offset) for R_c, offset in R_c_and_offset_a
                            if np.linalg.norm(R_c) < maxdist]
        if not R_c_and_offset_a:  # There was no overlap after all
            return no_overlap

        dtype = self.dtype
        ibzk_qc = self.ibzk_qc
        nq = len(ibzk_qc)
        shape = (nq, 3) if derivative else (nq,)
        expansions = self.tciexpansions

        if P:
            P_expansion = expansions.P_expansions.get(a1, a2)
            P_qim = P_expansion.zeros(shape, dtype=dtype)
            result = P_qim
        else:
            O_expansion = expansions.O_expansions.get(a1, a2)
            T_expansion = expansions.T_expansions.get(a1, a2)
            O_qmm = O_expansion.zeros(shape, dtype=dtype)
            T_qmm = T_expansion.zeros(shape, dtype=dtype)
            result = O_qmm, T_qmm

        displacement = (DerivativeAtomicDisplacement
                        if derivative
                        else AtomicDisplacement)

        # One evaluation per lattice image of a2 within range, all written
        # into the same zero-initialised arrays (presumably accumulating).
        for R_c, offset in R_c_and_offset_a:
            phases = self.get_phases(ibzk_qc, offset)
            disp = displacement(None, a1, a2, R_c, offset, phases)

            if P:
                disp.evaluate_overlap(P_expansion, P_qim)
            else:
                disp.evaluate_overlap(O_expansion, O_qmm)
                disp.evaluate_overlap(T_expansion, T_qmm)

        return result


class ManyTCICalculator:
    """Assemble whole-system matrices from atom-pair TCICalculator blocks.

    Atom-local index ranges are taken from ``setups.basis_indices()``
    (basis functions M) and ``setups.projector_indices()`` (projectors I).
    """

    def __init__(self, tciexpansions, setups, gd, spos_ac, ibzk_qc, dtype,
                 timer):
        self.tci = tciexpansions.get_tci_calculator(gd.cell_cv, spos_ac,
                                                    gd.pbc_c,
                                                    ibzk_qc, dtype)

        self.setups = setups
        self.dtype = dtype
        self.Pindices = setups.projector_indices()
        self.Mindices = setups.basis_indices()
        self.natoms = len(setups)
        self.nq = len(ibzk_qc)
        self.nao = self.Mindices.max
        self.timer = timer

    # @timer('tci-projectors')
    def P_aqMi(self, my_atom_indices, derivative=False):
        """Return dense projector/basis overlaps for the given atoms.

        Returns:
            Dict mapping each atom a1 in ``my_atom_indices`` to an array of
            shape ``(nq, nao, ni)`` (``(nq, 3, nao, ni)`` if
            ``derivative``).  Each atom-pair block is the conjugate
            transpose of ``tci.P(a1, a2)``; blocks without overlap are zero.
            Derivative results have their sign flipped (NOTE(review):
            presumably converting from movement of a2 to movement of the
            projector atom a1 -- confirm).
        """
        P_axMi = {}
        if derivative:
            P = self.tci.dPdR

            def empty(nI):
                return np.empty((self.nq, 3, self.nao, nI), self.dtype)
        else:
            P = self.tci.P

            def empty(nI):
                return np.empty((self.nq, self.nao, nI), self.dtype)

        Mindices = self.Mindices

        for a1 in my_atom_indices:
            P_xMi = empty(self.setups[a1].ni)

            # Every atom's basis range is written below (zero or data),
            # so np.empty is safe here.
            for a2 in range(self.natoms):
                N1, N2 = Mindices[a2]
                P_xmi = P_xMi[..., N1:N2, :]
                P_xim = P(a1, a2)
                if P_xim is None:
                    P_xmi[:] = 0.0
                else:
                    P_xmi[:] = P_xim.swapaxes(-2, -1).conj()
            P_axMi[a1] = P_xMi

        if derivative:
            for a in P_axMi:
                P_axMi[a] *= -1.0
        return P_axMi

    # @timer('tci-sparseprojectors')
    def P_qIM(self, my_atom_indices):
        """Return sparse projector/basis overlaps, one CSR matrix per k-point.

        Each matrix has shape ``(Pindices.max, Mindices.max)``; only rows
        belonging to atoms in ``my_atom_indices`` are filled.
        """
        nq = self.nq
        P = self.tci.P
        P_qIM = [sparse.dok_matrix((self.Pindices.max, self.Mindices.max),
                                   dtype=self.dtype)
                 for _ in range(nq)]

        for a1 in my_atom_indices:
            I1, I2 = self.Pindices[a1]

            # We can stride a2 over e.g. bd.comm and then do bd.comm.sum().
            # How should we do comm.sum() on a sparse matrix though?
            for a2 in range(self.natoms):
                M1, M2 = self.Mindices[a2]
                P_qim = P(a1, a2)
                if P_qim is not None:
                    for q in range(nq):
                        P_qIM[q][I1:I2, M1:M2] = P_qim[q]
        P_qIM = [P_IM.tocsr() for P_IM in P_qIM]
        return P_qIM

    # @timer('tci-bfs')
    def O_qMM_T_qMM(self, gdcomm, Mstart, Mstop, ignore_upper=False,
                    derivative=False):
        """Return basis overlap and kinetic matrices for rows Mstart:Mstop.

        Only atom pairs with ``a2 <= a1`` are computed, and the a2 loop is
        strided over ``gdcomm`` ranks, so each rank holds a partial result
        (NOTE(review): presumably summed over gdcomm by the caller).

        Parameters:
            gdcomm: communicator providing ``rank`` and ``size``.
            Mstart, Mstop: row range of basis functions to fill.
            ignore_upper: if false, the upper triangle is filled from the
                lower one via ``tri2full``; this requires all rows
                (``Mstop - Mstart == nao``).
            derivative: compute derivatives; the upper triangle is then
                filled with the negated conjugate instead of the conjugate.

        Returns:
            ``(O_xMM, T_xMM)`` of shape ``(nq, Mstop - Mstart, nao)``, or
            ``(nq, 3, Mstop - Mstart, nao)`` if ``derivative``.
        """
        mynao = Mstop - Mstart
        Mindices = self.Mindices

        if derivative:
            O_T = self.tci.dOdR_dTdR
            shape = (self.nq, 3, mynao, self.nao)
        else:
            O_T = self.tci.O_T
            shape = (self.nq, mynao, self.nao)

        O_xMM = np.zeros(shape, self.dtype)
        T_xMM = np.zeros(shape, self.dtype)

        # XXX the a1/a2 loops are not yet well load balanced.
        for a1 in range(self.natoms):
            M1, M2 = Mindices[a1]
            # Skip atoms whose basis functions lie outside Mstart:Mstop.
            if M2 <= Mstart or M1 >= Mstop:
                continue

            myM1 = max(M1 - Mstart, 0)
            myM2 = min(M2 - Mstart, mynao)
            nM = myM2 - myM1

            assert nM > 0, nM

            a2max = a1 + 1  # if not derivative else self.natoms

            for a2 in range(gdcomm.rank, a2max, gdcomm.size):
                O_xmm, T_xmm = O_T(a1, a2)
                if O_xmm is None:
                    continue

                N1, N2 = Mindices[a2]
                # Rows of this atom-pair block that fall into our range.
                m1 = max(Mstart - M1, 0)
                m2 = m1 + nM  # (Slice may go beyond end of matrix but OK)
                O_xmm = O_xmm[..., m1:m2, :]
                T_xmm = T_xmm[..., m1:m2, :]
                O_xMM[..., myM1:myM2, N1:N2] = O_xmm
                T_xMM[..., myM1:myM2, N1:N2] = T_xmm

        if not ignore_upper and O_xMM.size:  # reshape() fails on size-0 arrays
            assert mynao == self.nao
            assert O_xMM.shape[-2:] == (self.nao, self.nao)
            if derivative:
                def lumap(arr, out):
                    np.conj(arr, out)
                    out *= -1.0
            else:
                lumap = np.conj

            # Fill the upper triangle of every (nao, nao) slice from its
            # lower triangle, mapped through lumap.
            for arr_xMM in [O_xMM, T_xMM]:
                for tmp_MM in arr_xMM.reshape(-1, self.nao, self.nao):
                    tri2full(tmp_MM, UL='L', map=lumap)

        return O_xMM, T_xMM