1from math import pi
2from typing import Dict, Tuple
3from collections import defaultdict
4
5import numpy as np
6
7from gpaw.kpt_descriptor import KPointDescriptor
8from .kpts import RSKPoint, to_real_space
9
10
def create_symmetry_map(kd: KPointDescriptor):  # -> np.ndarray
    """Return the multiplication table of the symmetry operations of *kd*.

    The operations are the integer matrices ``kd.symmetry.op_scc``.  When
    ``kd.symmetry.time_reversal`` is set and ``kd.symmetry.has_inversion``
    is not, the list is extended with ``-U`` for every ``U``.  Those extra
    entries (indices ``nsym..2*nsym-1`` of the extended list) are flagged
    as involving complex conjugation.

    Returns
    -------
    map_ss : np.ndarray of int, shape (nsym, nsym)
        ``map_ss[s1, s2]`` is the unique index ``s`` with
        ``U[s1] @ U[s] == U[s2]``.  Here ``nsym`` is the length of the
        (possibly extended) operation list.

    Raises
    ------
    AssertionError
        If a product has no unique match in the list, or if the conjugation
        flags do not combine as ``c[s1] ^ c[s2] == c[s]``.
    """
    sym = kd.symmetry
    U_scc = sym.op_scc
    nsym = len(U_scc)
    compconj_s = np.zeros(nsym, bool)
    if sym.time_reversal and not sym.has_inversion:
        # Append -U for every U; those operations carry a complex conjugation.
        U_scc = np.concatenate([U_scc, -U_scc])
        compconj_s = np.zeros(nsym * 2, bool)
        compconj_s[nsym:] = True
        nsym *= 2

    map_ss = np.zeros((nsym, nsym), int)
    for s1 in range(nsym):
        # product_scc[s] = U[s1] @ U[s].  It does not depend on s2, so it is
        # computed once per s1 instead of once per (s1, s2) pair.
        product_scc = U_scc[s1].dot(U_scc).transpose((1, 0, 2))
        for s2 in range(nsym):
            diff_s = abs(product_scc - U_scc[s2]).sum(2).sum(1)
            indices = (diff_s == 0).nonzero()[0]
            assert len(indices) == 1
            s = indices[0]
            assert compconj_s[s1] ^ compconj_s[s2] == compconj_s[s]
            map_ss[s1, s2] = s

    return map_ss
34
35
class Symmetry:
    """Symmetry operations of a k-point descriptor and helpers to apply them.

    Symmetry indices used by the methods are "extended" indices.  A value
    ``s < nsym`` (``nsym = len(kd.symmetry.op_scc)``) denotes
    ``op_scc[s]``.  A value ``s >= nsym`` denotes ``op_scc[s - nsym]``
    combined with time reversal (complex conjugation).
    """

    def __init__(self, kd: KPointDescriptor):
        self.kd = kd
        # symmetry_map_ss[s1, s2] = s with U[s1] @ U[s] == U[s2].
        self.symmetry_map_ss = create_symmetry_map(kd)

        U_scc = kd.symmetry.op_scc
        is_identity_s = (U_scc == np.eye(3, dtype=int)).all(2).all(1)
        # Index of the identity operation.
        self.s0 = is_identity_s.nonzero()[0][0]
        # inverse_s[s] is the s' with U[s] @ U[s'] == identity.
        self.inverse_s = self.symmetry_map_ss[:, self.s0]

    def _split_index(self, s):
        """Split an extended index into (plain index, time-reversal flag)."""
        nsym = len(self.kd.symmetry.op_scc)
        return s % nsym, s >= nsym

    @staticmethod
    def _rotation_index(U_cc, N_c):
        """Return flat indices ``i`` so that ``a.ravel()[i]`` is rotated.

        Grid point ``r`` of the result is read from grid point
        ``U_cc.T @ r`` of the input.  Indices are wrapped periodically
        into the grid of shape ``N_c``.
        """
        i_cr = np.dot(U_cc.T, np.indices(N_c).reshape((3, -1)))
        return np.ravel_multi_index(i_cr, N_c, 'wrap')

    @staticmethod
    def _band_range(nbands, comm):
        """Return ``(na, nb)``, the contiguous block of bands for comm.rank.

        Bands are split into ``ceil(nbands / comm.size)``-sized blocks.
        Trailing ranks may get an empty range.
        """
        blocksize = (nbands + comm.size - 1) // comm.size
        na = min(blocksize * comm.rank, nbands)
        nb = min(na + blocksize, nbands)
        return na, nb

    def symmetry_operation(self, s: int, wfs, inverse=False):
        """Return real-space and atomic transformations for operation *s*.

        Parameters
        ----------
        s:
            Extended symmetry index (see class docstring).
        wfs:
            Wave-function object.  Its ``gd.N_c``, ``setups`` and
            ``spos_ac`` are read.
        inverse:
            If true, use the inverse operation ``self.inverse_s[s]``.

        Returns
        -------
        (T, T_a, time_reversal)
            ``T(a_R)`` returns the array *a_R* on the ``wfs.gd.N_c`` grid
            permuted by the rotation, and complex-conjugated if
            time reversal is involved.  ``T_a`` holds one ``(b, S_c, U_ii)``
            tuple per atom ``a``, where ``b = kd.symmetry.a_sa[s, a]``,
            ``S_c = spos_ac[a] @ U_cc - spos_ac[b]`` and
            ``U_ii = setups[a].R_sii[s].T``.
        """
        if inverse:
            s = self.inverse_s[s]
        s, time_reversal = self._split_index(s)
        U_cc = self.kd.symmetry.op_scc[s]

        if (U_cc == np.eye(3, dtype=int)).all():
            def T0(a_R):
                return a_R
        else:
            N_c = wfs.gd.N_c
            i = self._rotation_index(U_cc, N_c)

            def T0(a_R):
                return a_R.ravel()[i].reshape(N_c)

        if time_reversal:
            def T(a_R):
                return T0(a_R).conj()
        else:
            T = T0

        a_sa = self.kd.symmetry.a_sa
        T_a = []
        for a, _ in enumerate(wfs.setups.id_a):
            b = a_sa[s, a]
            S_c = np.dot(wfs.spos_ac[a], U_cc) - wfs.spos_ac[b]
            U_ii = wfs.setups[a].R_sii[s].T
            T_a.append((b, S_c, U_ii))

        return T, T_a, time_reversal

    def apply_symmetry(self, s: int, rsk, wfs, spos_ac):
        """Return a copy of the real-space k-point *rsk* transformed by *s*.

        The wave functions are permuted on the grid.  The projections of
        atom ``a`` are built from those of ``b = kd.symmetry.a_sa[s, a]``
        using the rotation matrix ``R_sii[s].T`` and the phase factor
        ``exp(2j*pi * k_c . S_c)``.  The k-vector becomes
        ``sign * U_cc @ k_c``, where ``sign`` is -1 under time reversal.
        Under time reversal the wave functions and projections are also
        complex-conjugated.  *rsk* itself is returned unchanged for the
        plain identity operation.
        """
        s, time_reversal = self._split_index(s)
        U_cc = self.kd.symmetry.op_scc[s]

        if (U_cc == np.eye(3)).all() and not time_reversal:
            return rsk

        u1_nR = rsk.u_nR
        proj1 = rsk.proj
        k1_c = rsk.k_c

        # Time reversal additionally maps k -> -k.
        sign = 1 - 2 * int(time_reversal)
        k2_c = sign * U_cc.dot(k1_c)

        # Permute every wave function on the real-space grid.
        N_c = u1_nR.shape[1:]
        i = self._rotation_index(U_cc, N_c)
        u2_nR = np.empty_like(u1_nR)
        for u1_R, u2_R in zip(u1_nR, u2_nR):
            u2_R[:] = u1_R.ravel()[i].reshape(N_c)

        # Rotate projections atom by atom, including the Bloch phase picked
        # up by the shift S_c between atom a's image and atom b.
        a_sa = self.kd.symmetry.a_sa
        proj2 = proj1.new()
        for a, _ in enumerate(wfs.setups.id_a):
            b = a_sa[s, a]
            S_c = np.dot(spos_ac[a], U_cc) - spos_ac[b]
            x = np.exp(2j * pi * np.dot(k1_c, S_c))
            U_ii = wfs.setups[a].R_sii[s].T * x
            proj2[a][:] = proj1[b].dot(U_ii)

        if time_reversal:
            np.conj(u2_nR, out=u2_nR)
            np.conj(proj2.array, out=proj2.array)

        return RSKPoint(u2_nR, proj2, rsk.f_n, k2_c, rsk.weight)

    def pairs(self, kpts, wfs, spos_ac):
        """Yield symmetry-reduced pairs of IBZ k-points.

        For every pair of BZ images ``(i1, s1)`` and ``(i2, s2)`` of IBZ
        points ``i1`` and ``i2``, the reduced operation is
        ``s = symmetry_map_ss[s1, s2]``.  ``count`` is the number of BZ
        pairs that reduce to the same ``(i1, i2, s)``.

        Yields
        ------
        (i1, i2, s, rsk1, rsk2, count)
            The tuples come in sorted ``(i1, i2, s)`` order.  ``rsk1`` holds
            all bands of ``kpts[i1]``.  ``rsk2`` is ``kpts[i2]`` transformed
            by ``s`` via :meth:`apply_symmetry`.  It holds only this rank's
            block of bands (see :meth:`_band_range` with ``wfs.world``),
            except when ``i1 == i2`` and ``s`` is the identity, where it is
            ``rsk1`` itself.
        """
        kd = self.kd
        nsym = len(kd.symmetry.op_scc)

        assert len(kpts) == kd.nibzkpts

        # Extended symmetry indices of the BZ images of each IBZ point.
        symmetries_k = []
        for k in range(kd.nibzkpts):
            indices = np.where(kd.bz2ibz_k == k)[0]
            symmetries_k.append(kd.sym_k[indices] +
                                kd.time_reversal_k[indices] * nsym)

        pairs: Dict[Tuple[int, int, int], int] = defaultdict(int)
        for i1 in range(kd.nibzkpts):
            for s1 in symmetries_k[i1]:
                for i2 in range(kd.nibzkpts):
                    for s2 in symmetries_k[i2]:
                        s3 = self.symmetry_map_ss[s1, s2]
                        pairs[(i1, i2, s3)] += 1

        comm = wfs.world
        lasti1 = -1
        lasti2 = -1
        for (i1, i2, s), count in sorted(pairs.items()):
            if i1 != lasti1:
                k1 = kpts[i1]
                u1_nR = to_real_space(k1.psit)
                rsk1 = RSKPoint(u1_nR, k1.proj.broadcast(),
                                k1.f_n, k1.k_c,
                                k1.weight, k1.dPdR_aniv)
                lasti1 = i1
            if i2 == i1:
                if s == self.s0:
                    rsk2 = rsk1
                else:
                    na, nb = self._band_range(len(rsk1.u_nR), comm)
                    rsk2 = RSKPoint(rsk1.u_nR[na:nb],
                                    rsk1.proj.view(na, nb),
                                    rsk1.f_n[na:nb],
                                    rsk1.k_c,
                                    rsk1.weight)
                lasti2 = i2
            elif i2 != lasti2:
                k2 = kpts[i2]
                na, nb = self._band_range(len(k2.psit.array), comm)
                u2_nR = to_real_space(k2.psit, na, nb)
                rsk2 = RSKPoint(u2_nR, k2.proj.broadcast().view(na, nb),
                                k2.f_n[na:nb], k2.k_c,
                                k2.weight)
                lasti2 = i2

            yield (i1, i2, s, rsk1,
                   self.apply_symmetry(s, rsk2, wfs, spos_ac),
                   count)
205