"""Symmetry bookkeeping for pairing k-points on real-space grids."""
from math import pi
from typing import Dict, Tuple
from collections import defaultdict

import numpy as np

from gpaw.kpt_descriptor import KPointDescriptor
from .kpts import RSKPoint, to_real_space


def _grid_permutation(U_cc, N_c):
    """Return flat gather indices that apply ``U_cc`` to a grid.

    The indices of a grid of shape ``N_c`` are multiplied by ``U_cc.T``
    and wrapped back into the grid, so that ``a_R.ravel()[i].reshape(N_c)``
    is the transformed array.
    """
    i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1)))
    return np.ravel_multi_index(i_cr, N_c, 'wrap')


def _band_slice(nbands, comm):
    """Return the half-open band range ``(na, nb)`` for ``comm.rank``.

    Bands are split into ``comm.size`` contiguous blocks of equal
    (rounded-up) size; trailing ranks may get a short or empty range.
    """
    blocksize = (nbands + comm.size - 1) // comm.size
    na = min(blocksize * comm.rank, nbands)
    nb = min(na + blocksize, nbands)
    return na, nb


def create_symmetry_map(kd: KPointDescriptor) -> np.ndarray:
    """Build the table relating pairs of symmetry operations.

    The operations are ``kd.symmetry.op_scc``; when time reversal is
    enabled and there is no inversion symmetry, the negated matrices are
    appended as a second half (indices ``nsym`` and up).

    Returns an integer array ``map_ss`` where ``map_ss[s1, s2]`` is the
    unique index ``s`` satisfying ``U[s1] @ U[s] == U[s2]``.

    Raises ``AssertionError`` if that index is missing or not unique, or
    if the second-half (negated) flags of ``s1``, ``s2`` and ``s`` are
    inconsistent.
    """
    sym = kd.symmetry
    U_scc = sym.op_scc
    nsym = len(U_scc)
    compconj_s = np.zeros(nsym, bool)
    if sym.time_reversal and not sym.has_inversion:
        U_scc = np.concatenate([U_scc, -U_scc])
        compconj_s = np.zeros(nsym * 2, bool)
        compconj_s[nsym:] = True
        nsym *= 2

    map_ss = np.zeros((nsym, nsym), int)
    for s1 in range(nsym):
        # prod_scc[s] == U_scc[s1] @ U_scc[s]; independent of s2, so it is
        # computed once per s1.
        prod_scc = U_scc[s1].dot(U_scc).transpose((1, 0, 2))
        for s2 in range(nsym):
            diff_s = abs(prod_scc - U_scc[s2]).sum(2).sum(1)
            indices = (diff_s == 0).nonzero()[0]
            assert len(indices) == 1
            s = indices[0]
            assert (compconj_s[s1] ^ compconj_s[s2]) == compconj_s[s]
            map_ss[s1, s2] = s

    return map_ss


class Symmetry:
    """Apply the symmetry operations of a k-point descriptor.

    Symmetry indices ``s`` at or above ``len(kd.symmetry.op_scc)`` denote
    operation ``s % nsym`` combined with time reversal (complex
    conjugation).
    """

    def __init__(self, kd: KPointDescriptor):
        """Store ``kd`` and precompute the symmetry map and inverses."""
        self.kd = kd
        self.symmetry_map_ss = create_symmetry_map(kd)

        U_scc = kd.symmetry.op_scc
        is_identity_s = (U_scc == np.eye(3, dtype=int)).all(2).all(1)
        # Index of the identity operation.
        self.s0 = is_identity_s.nonzero()[0][0]
        # inverse_s[s] is the index s' with U[s] @ U[s'] == identity.
        self.inverse_s = self.symmetry_map_ss[:, self.s0]

    def symmetry_operation(self, s: int, wfs, inverse=False):
        """Describe symmetry operation ``s`` as reusable transformations.

        Parameters:
            s: symmetry index (time-reversed if ``>= nsym``).
            wfs: object providing ``gd.N_c``, ``setups`` and ``spos_ac``.
            inverse: if true, use ``self.inverse_s[s]`` instead of ``s``.

        Returns a tuple ``(T, T_a, time_reversal)``:
            T: function mapping an array on the ``wfs.gd.N_c`` grid to
                its transformed (and, for time reversal, conjugated)
                counterpart.
            T_a: one ``(b, S_c, U_ii)`` tuple per atom ``a``, with
                ``b = a_sa[s, a]``,
                ``S_c = spos_ac[a] @ U_cc - spos_ac[b]`` and
                ``U_ii = setups[a].R_sii[s].T``.
            time_reversal: whether the operation includes time reversal.
        """
        if inverse:
            s = self.inverse_s[s]
        U_scc = self.kd.symmetry.op_scc
        nsym = len(U_scc)
        time_reversal = s >= nsym
        s %= nsym
        U_cc = U_scc[s]

        if (U_cc == np.eye(3, dtype=int)).all():
            def T0(a_R):
                return a_R
        else:
            N_c = wfs.gd.N_c
            i = _grid_permutation(U_cc, N_c)

            def T0(a_R):
                return a_R.ravel()[i].reshape(N_c)

        if time_reversal:
            def T(a_R):
                return T0(a_R).conj()
        else:
            T = T0

        T_a = []
        for a, _ in enumerate(wfs.setups.id_a):
            b = self.kd.symmetry.a_sa[s, a]
            S_c = np.dot(wfs.spos_ac[a], U_cc) - wfs.spos_ac[b]
            U_ii = wfs.setups[a].R_sii[s].T
            T_a.append((b, S_c, U_ii))

        return T, T_a, time_reversal

    def apply_symmetry(self, s: int, rsk, wfs, spos_ac):
        """Return ``rsk`` transformed by symmetry operation ``s``.

        Parameters:
            s: symmetry index (time-reversed if ``>= nsym``).
            rsk: real-space k-point with ``u_nR``, ``proj``, ``f_n``,
                ``k_c`` and ``weight`` attributes.
            wfs: object providing ``setups``.
            spos_ac: scaled atomic positions, indexed by atom.

        The input object itself is returned for the plain identity;
        otherwise a new ``RSKPoint`` is built with transformed
        wave functions and projections and with
        ``k_c = +/- U_cc @ rsk.k_c`` (minus sign for time reversal).
        """
        U_scc = self.kd.symmetry.op_scc
        nsym = len(U_scc)
        time_reversal = s >= nsym
        s %= nsym
        sign = 1 - 2 * int(time_reversal)
        U_cc = U_scc[s]

        if (U_cc == np.eye(3)).all() and not time_reversal:
            return rsk

        u1_nR = rsk.u_nR
        proj1 = rsk.proj
        k1_c = rsk.k_c

        u2_nR = np.empty_like(u1_nR)
        proj2 = proj1.new()

        k2_c = sign * U_cc.dot(k1_c)

        N_c = u1_nR.shape[1:]
        i = _grid_permutation(U_cc, N_c)
        for u1_R, u2_R in zip(u1_nR, u2_nR):
            u2_R[:] = u1_R.ravel()[i].reshape(N_c)

        for a, _ in enumerate(wfs.setups.id_a):
            b = self.kd.symmetry.a_sa[s, a]
            S_c = np.dot(spos_ac[a], U_cc) - spos_ac[b]
            # Phase factor exp(2 pi i k1 . S_c) for the shift S_c.
            x = np.exp(2j * pi * np.dot(k1_c, S_c))
            U_ii = wfs.setups[a].R_sii[s].T * x
            proj2[a][:] = proj1[b].dot(U_ii)

        if time_reversal:
            np.conj(u2_nR, out=u2_nR)
            np.conj(proj2.array, out=proj2.array)

        return RSKPoint(u2_nR, proj2, rsk.f_n, k2_c, rsk.weight)

    def pairs(self, kpts, wfs, spos_ac):
        """Generate all pairs of IBZ k-points with relative symmetries.

        Parameters:
            kpts: one k-point object per IBZ k-point
                (``len(kpts) == kd.nibzkpts``).
            wfs: object providing ``world`` (communicator) and ``setups``.
            spos_ac: scaled atomic positions, passed on to
                ``apply_symmetry``.

        Yields ``(i1, i2, s, rsk1, rsk2, count)`` in sorted
        ``(i1, i2, s)`` order, where ``rsk1`` is k-point ``i1`` in real
        space, ``rsk2`` is k-point ``i2`` with symmetry ``s`` applied and
        ``count`` is the number of ``(s1, s2)`` combinations that map to
        ``s``. Except for the ``i1 == i2`` identity case, ``rsk2`` holds
        only this rank's slice of the bands.
        """
        kd = self.kd
        nsym = len(kd.symmetry.op_scc)

        assert len(kpts) == kd.nibzkpts

        # Symmetry indices (offset by nsym when time-reversed) of all BZ
        # k-points that map to each IBZ k-point.
        symmetries_k = []
        for k in range(kd.nibzkpts):
            indices = np.where(kd.bz2ibz_k == k)[0]
            sindices = (kd.sym_k[indices] +
                        kd.time_reversal_k[indices] * nsym)
            symmetries_k.append(sindices)

        # NOTE: merging of entries that share i1 and the rotated BZ
        # k-point kd.bz2bz_ks[kd.ibz2bz_k[i2], s] was disabled in the
        # original code (its lookup table was never filled), so every
        # distinct (i1, i2, s) is kept as its own entry.
        pair_counts: Dict[Tuple[int, int, int], int] = defaultdict(int)
        for i1, symmetries1 in enumerate(symmetries_k):
            for s1 in symmetries1:
                for i2, symmetries2 in enumerate(symmetries_k):
                    for s2 in symmetries2:
                        s3 = self.symmetry_map_ss[s1, s2]
                        pair_counts[(i1, i2, s3)] += 1

        comm = wfs.world
        lasti1 = -1
        lasti2 = -1
        for (i1, i2, s), count in sorted(pair_counts.items()):
            if i1 != lasti1:
                # New first k-point: transform all of its bands once.
                k1 = kpts[i1]
                u1_nR = to_real_space(k1.psit)
                rsk1 = RSKPoint(u1_nR, k1.proj.broadcast(),
                                k1.f_n, k1.k_c,
                                k1.weight, k1.dPdR_aniv)
                lasti1 = i1
            if i2 == i1:
                if s == self.s0:
                    rsk2 = rsk1
                else:
                    # Reuse rsk1's data, restricted to this rank's bands.
                    na, nb = _band_slice(len(rsk1.u_nR), comm)
                    rsk2 = RSKPoint(rsk1.u_nR[na:nb],
                                    rsk1.proj.view(na, nb),
                                    rsk1.f_n[na:nb],
                                    rsk1.k_c,
                                    rsk1.weight)
                lasti2 = i2
            elif i2 != lasti2:
                k2 = kpts[i2]
                na, nb = _band_slice(len(k2.psit.array), comm)
                u2_nR = to_real_space(k2.psit, na, nb)
                rsk2 = RSKPoint(u2_nR, k2.proj.broadcast().view(na, nb),
                                k2.f_n[na:nb], k2.k_c,
                                k2.weight)
                lasti2 = i2

            yield (i1, i2, s, rsk1,
                   self.apply_symmetry(s, rsk2, wfs, spos_ac),
                   count)