1# encoding: utf-8
2import numpy as np
3import scipy.sparse as sparse
4from ase.neighborlist import PrimitiveNeighborList
5# from ase.utils.timing import timer
6from gpaw.utilities.tools import tri2full
7
8# from gpaw import debug
9from gpaw.lcao.overlap import (FourierTransformer, TwoSiteOverlapCalculator,
10                               ManySiteOverlapCalculator,
11                               AtomicDisplacement, NullPhases, BlochPhases,
12                               DerivativeAtomicDisplacement)
13
14
def get_cutoffs(f_Ij):
    """Return the largest cutoff among the functions of each entry.

    f_Ij is a sequence of sequences of objects providing get_cutoff().
    The returned list has one number per outer entry: the maximum of
    the get_cutoff() values of that entry, but never less than 0.001,
    so that an entry without functions still gets a small positive
    cutoff.
    """
    floor = 0.001  # 'paranoid zero'
    return [max([floor] + [f.get_cutoff() for f in f_j]) for f_j in f_Ij]
23
24
def get_lvalues(f_Ij):
    """Return the angular momentum numbers of all functions in f_Ij.

    The result mirrors the nesting of f_Ij: one list per outer entry,
    holding f.get_angular_momentum_number() for each function f in it.
    """
    l_Ij = []
    for f_j in f_Ij:
        l_Ij.append([f.get_angular_momentum_number() for f in f_j])
    return l_Ij
27
28
class AtomPairRegistry:
    """Table of displacements between atoms that are within range.

    The table maps an atom pair (a1, a2) to a list of (R_c, offset)
    tuples, one per neighbor entry found by the neighbor list.  R_c is
    the difference of the scaled positions (including the cell offset)
    multiplied by cell_cv, and offset is the cell offset reported by
    the neighbor list.
    """

    def __init__(self, cutoff_a, pbc_c, cell_cv, spos_ac):
        """Build the table from per-atom cutoffs and scaled positions.

        cutoff_a, pbc_c and cell_cv are handed to ASE's
        PrimitiveNeighborList; spos_ac holds the scaled positions.
        """
        neighborlist = PrimitiveNeighborList(cutoff_a, skin=0, sorted=True,
                                             self_interaction=True,
                                             use_scaled_positions=True)
        neighborlist.update(pbc=pbc_c, cell=cell_cv, coordinates=spos_ac)

        pairs = {}
        for a1, spos1_c in enumerate(spos_ac):
            neighbor_a, offset_ac = neighborlist.get_neighbors(a1)
            for a2, offset in zip(neighbor_a, offset_ac):
                R_c = np.dot(spos_ac[a2] + offset - spos1_c, cell_cv)
                pairs.setdefault((a1, a2), []).append((R_c, offset))

                # Every entry is also stored in the reverse direction
                # with negated displacement and offset, except for an
                # atom paired with itself in the same cell, which must
                # be stored only once.
                on_site = a1 == a2 and not offset.any()
                if not on_site:
                    pairs.setdefault((a2, a1), []).append((-R_c, -offset))

        self.r_and_offset_aao = pairs

    def get(self, a1, a2):
        """Return the list of (R_c, offset) for (a1, a2), or None."""
        return self.r_and_offset_aao.get((a1, a2))

    def get_atompairs(self):
        """Return all registered (a1, a2) pairs as a sorted list."""
        return sorted(self.r_and_offset_aao)
58
59
class TCIExpansions:
    """Precomputed two-center integral expansions for a set of species.

    From the basis functions (phit) and projectors (pt) of each species
    this object builds, once, the expansions needed for

      * basis/basis overlaps       (O_expansions),
      * basis/basis kinetic terms  (T_expansions),
      * projector/basis overlaps   (P_expansions),

    and records the cutoff radii per species.  The calculator factories
    below combine these with an actual atomic geometry.
    """

    def __init__(self, phit_Ij, pt_Ij, I_a):
        """Create expansions.

        phit_Ij, pt_Ij: for each species I, the list of basis functions
            and projectors, respectively.  The objects must provide
            get_cutoff() and get_angular_momentum_number().
        I_a: for each atom a, the index I of its species.
        """
        assert len(pt_Ij) == len(phit_Ij)

        # Cutoffs by species:
        pt_rcmax_I = get_cutoffs(pt_Ij)
        phit_rcmax_I = get_cutoffs(phit_Ij)
        rcmax_I = [max(rc1, rc2) for rc1, rc2
                   in zip(pt_rcmax_I, phit_rcmax_I)]

        # The extra 1e-3 keeps max() well-defined when there are no species.
        transformer = FourierTransformer(rcmax=max(rcmax_I + [1e-3]), ng=2**10)
        tsoc = TwoSiteOverlapCalculator(transformer)
        msoc = ManySiteOverlapCalculator(tsoc, I_a, I_a)
        phit_Ijq = msoc.transform(phit_Ij)
        pt_Ijq = msoc.transform(pt_Ij)
        pt_l_Ij = get_lvalues(pt_Ij)
        phit_l_Ij = get_lvalues(phit_Ij)
        self.O_expansions = msoc.calculate_expansions(phit_l_Ij, phit_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.T_expansions = msoc.calculate_kinetic_expansions(phit_l_Ij,
                                                              phit_Ijq)
        self.P_expansions = msoc.calculate_expansions(pt_l_Ij, pt_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.I_a = I_a  # Actually I_a belongs outside, like spos_ac.
        self.rcmax_I = rcmax_I
        self.phit_rcmax_I = phit_rcmax_I
        self.pt_rcmax_I = pt_rcmax_I

    @classmethod
    def new_from_setups(cls, setups):
        """Alternate constructor taking a setups container.

        setups must be iterable over the setup of each atom and have a
        ``setups`` mapping whose values are the distinct setup objects;
        each setup provides ``phit_j`` and ``pt_j``.  The species index
        of a setup is its position among those values.

        Returns an instance of ``cls``, so subclasses get an object of
        their own type.
        """
        setups_I = list(setups.setups.values())
        I_setup = {setup: I for I, setup in enumerate(setups_I)}
        I_a = [I_setup[setup] for setup in setups]

        return cls([s.phit_j for s in setups_I],
                   [s.pt_j for s in setups_I],
                   I_a)

    def get_tci_calculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype):
        """Return a TCICalculator using these expansions."""
        return TCICalculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype)

    def get_manytci_calculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                               timer):
        """Return a ManyTCICalculator using these expansions."""
        return ManyTCICalculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                                 timer)
107
108
class TCICalculator:
    """Two-center integral calculator for individual atom pairs.

    The calculator knows nothing about parallelization; it can be asked
    about any pair of atoms a1, a2.

    Obtain a calculator from a TCIExpansions object, e.g.

      tci = tciexpansions.get_tci_calculator(cell_cv, spos_ac, pbc_c,
                                             ibzk_qc, dtype)

    and then evaluate interatomic matrices as follows.

    Projector/basis overlap <pt_i^a1|phi_mu> between atoms a1, a2:

      P_qim = tci.P(a1, a2)

    Derivatives of the above with respect to movement of a2:

      dPdR_qvim = tci.dPdR(a1, a2)

    Basis/basis overlap and kinetic matrix elements between atoms a1, a2:

      O_qmm, T_qmm = tci.O_T(a1, a2)

    Derivative of the above wrt. position of a2:

      dOdR_qvmm, dTdR_qvmm = tci.dOdR_dTdR(a1, a2)

    When the two atoms are out of range of each other, the projector
    variants return None and the basis variants return (None, None).
    """

    def __init__(self, tciexpansions, cell_cv, spos_ac, pbc_c, ibzk_qc,
                 dtype):
        self.tciexpansions = tciexpansions
        self.dtype = dtype
        self.ibzk_qc = ibzk_qc

        # XXX It is somewhat nasty that rcmax depends on how long our
        # longest orbital happens to be
        # Translate the per-species cutoffs into per-atom cutoffs:
        species_a = tciexpansions.I_a
        cutoff_a = [tciexpansions.rcmax_I[I] for I in species_a]
        self.pt_rcmax_a = np.array([tciexpansions.pt_rcmax_I[I]
                                    for I in species_a])
        self.phit_rcmax_a = np.array([tciexpansions.phit_rcmax_I[I]
                                      for I in species_a])

        self.a1a2 = AtomPairRegistry(cutoff_a, pbc_c, cell_cv, spos_ac)

        # Phase factors are only needed if some k-point is nonzero.
        self.get_phases = BlochPhases if ibzk_qc.any() else NullPhases

        self.O_T = self._tci_shortcut(False, False)
        self.P = self._tci_shortcut(True, False)
        self.dOdR_dTdR = self._tci_shortcut(False, True)
        self.dPdR = self._tci_shortcut(True, True)

    def _tci_shortcut(self, P, derivative):
        """Return a function of (a1, a2) with P and derivative fixed."""
        return lambda a1, a2: self._calculate(a1, a2, P, derivative)

    def _calculate(self, a1, a2, P=False, derivative=False):
        """Calculate overlap of functions between atoms a1 and a2.

        With P true, the projector/basis overlap is evaluated and a
        single array is returned; otherwise the basis/basis overlap and
        kinetic matrices are returned as a tuple.  With derivative true,
        the arrays get an extra axis of length 3 after the k-point axis.
        If no displacement between the atoms is within range, None
        (for P) or (None, None) is returned.
        """
        no_overlap = None if P else (None, None)

        # Quick exit if the pair is not registered at all.
        candidates = self.a1a2.get(a1, a2)
        if candidates is None:
            return no_overlap

        if P:
            rcut1 = self.pt_rcmax_a[a1]
        else:
            rcut1 = self.phit_rcmax_a[a1]
        rcut2 = self.phit_rcmax_a[a2]
        maxdist = rcut1 + rcut2

        # Keep only the displacements that are shorter than maxdist:
        within_range = [pair for pair in candidates
                        if np.linalg.norm(pair[0]) < maxdist]
        if not within_range:  # There was no overlap after all
            return no_overlap

        dtype = self.dtype
        get_phases = self.get_phases
        ibzk_qc = self.ibzk_qc
        nq = len(ibzk_qc)
        phit_rcmax_a = self.phit_rcmax_a
        pt_rcmax_a = self.pt_rcmax_a

        if derivative:
            make_displacement = DerivativeAtomicDisplacement
            shape = (nq, 3)
        else:
            make_displacement = AtomicDisplacement
            shape = (nq,)

        expansions = self.tciexpansions
        if P:
            P_expansion = expansions.P_expansions.get(a1, a2)
            P_qim = P_expansion.zeros(shape, dtype=dtype)
            result = P_qim
        else:
            O_expansion = expansions.O_expansions.get(a1, a2)
            T_expansion = expansions.T_expansions.get(a1, a2)
            O_qmm = O_expansion.zeros(shape, dtype=dtype)
            T_qmm = T_expansion.zeros(shape, dtype=dtype)
            result = O_qmm, T_qmm

        # Accumulate the contribution of every displacement in range.
        for R_c, offset in within_range:
            distance = np.linalg.norm(R_c)
            phases = get_phases(ibzk_qc, offset)
            disp = make_displacement(None, a1, a2, R_c, offset, phases)

            if P:
                assert distance < pt_rcmax_a[a1] + phit_rcmax_a[a2]
                disp.evaluate_overlap(P_expansion, P_qim)
            else:
                assert distance < phit_rcmax_a[a1] + phit_rcmax_a[a2]
                disp.evaluate_overlap(O_expansion, O_qmm)
                disp.evaluate_overlap(T_expansion, T_qmm)

        return result
227
228
class ManyTCICalculator:
    """Assemble per-pair two-center integrals into full matrices.

    Wraps a TCICalculator and places the matrices it returns for
    individual atom pairs into arrays indexed by global basis function
    (M) and projector (I) indices.
    """

    def __init__(self, tciexpansions, setups, gd, spos_ac, ibzk_qc, dtype,
                 timer):
        self.tci = tciexpansions.get_tci_calculator(gd.cell_cv, spos_ac,
                                                    gd.pbc_c,
                                                    ibzk_qc, dtype)

        self.setups = setups
        self.dtype = dtype
        self.timer = timer
        self.natoms = len(setups)
        self.nq = len(ibzk_qc)
        self.Pindices = setups.projector_indices()
        self.Mindices = setups.basis_indices()
        self.nao = self.Mindices.max

    # @timer('tci-projectors')
    def P_aqMi(self, my_atom_indices, derivative=False):
        """Return dense projector/basis matrices for the given atoms.

        The result is a dictionary mapping each atom a1 of
        my_atom_indices to an array of shape (nq, nao, ni), or
        (nq, 3, nao, ni) if derivative is true, where ni is
        setups[a1].ni.  Each block is the transposed complex conjugate
        of what the pair calculator returns; pairs without overlap give
        zeros.  Derivatives are multiplied by -1.
        NOTE(review): the sign flip presumably turns the derivative wrt.
        a2 into one wrt. a1 -- confirm.
        """
        if derivative:
            pair_P = self.tci.dPdR
            leading_shape = (self.nq, 3, self.nao)
        else:
            pair_P = self.tci.P
            leading_shape = (self.nq, self.nao)

        Mindices = self.Mindices
        P_axMi = {}

        for a1 in my_atom_indices:
            ni = self.setups[a1].ni
            P_xMi = np.empty(leading_shape + (ni,), self.dtype)

            # Fill the rows belonging to the basis functions of each atom.
            for a2 in range(self.natoms):
                N1, N2 = Mindices[a2]
                P_xim = pair_P(a1, a2)
                if P_xim is None:
                    P_xMi[..., N1:N2, :] = 0.0
                else:
                    P_xMi[..., N1:N2, :] = P_xim.swapaxes(-2, -1).conj()

            if derivative:
                P_xMi *= -1.0
            P_axMi[a1] = P_xMi

        return P_axMi

    # @timer('tci-sparseprojectors')
    def P_qIM(self, my_atom_indices):
        """Return one sparse (CSR) projector/basis matrix per k-point.

        Rows are global projector indices and columns global basis
        function indices.  Only the rows belonging to the atoms in
        my_atom_indices are filled.
        """
        nq = self.nq
        pair_P = self.tci.P
        matrix_shape = (self.Pindices.max, self.Mindices.max)
        dok_q = [sparse.dok_matrix(matrix_shape, dtype=self.dtype)
                 for _ in range(nq)]

        for a1 in my_atom_indices:
            I1, I2 = self.Pindices[a1]

            # We can stride a2 over e.g. bd.comm and then do bd.comm.sum().
            # How should we do comm.sum() on a sparse matrix though?
            for a2 in range(self.natoms):
                M1, M2 = self.Mindices[a2]
                P_qim = pair_P(a1, a2)
                if P_qim is None:
                    continue
                for q, P_IM in enumerate(dok_q):
                    P_IM[I1:I2, M1:M2] = P_qim[q]

        return [P_IM.tocsr() for P_IM in dok_q]

    # @timer('tci-bfs')
    def O_qMM_T_qMM(self, gdcomm, Mstart, Mstop, ignore_upper=False,
                    derivative=False):
        """Return overlap and kinetic matrices for rows Mstart to Mstop.

        Both arrays have shape (nq, Mstop - Mstart, nao), with an extra
        axis of length 3 after the first one if derivative is true.

        Only atom pairs with a2 <= a1 are evaluated, and a2 is strided
        over the ranks of gdcomm, so each rank fills only part of the
        lower block triangle.
        NOTE(review): presumably the caller sums the result over gdcomm
        -- confirm.

        Unless ignore_upper is true, the upper triangle is filled in
        with tri2full from the lower one (complex conjugate, negated
        for derivatives); this requires the row range to span all of
        nao.
        """
        mynao = Mstop - Mstart
        Mindices = self.Mindices

        if derivative:
            pair_O_T = self.tci.dOdR_dTdR
            leading_shape = (self.nq, 3)
        else:
            pair_O_T = self.tci.O_T
            leading_shape = (self.nq,)
        shape = leading_shape + (mynao, self.nao)

        O_xMM = np.zeros(shape, self.dtype)
        T_xMM = np.zeros(shape, self.dtype)

        # XXX the a1/a2 loops are not yet well load balanced.
        for a1 in range(self.natoms):
            M1, M2 = Mindices[a1]
            touches_my_rows = M2 > Mstart and M1 < Mstop
            if not touches_my_rows:
                continue

            # Rows of this atom, relative to Mstart and clipped to our range.
            myM1 = max(M1 - Mstart, 0)
            myM2 = min(M2 - Mstart, mynao)
            nM = myM2 - myM1

            assert nM > 0, nM

            # Rows to pick from the pair matrix of a1.
            # (Slice may go beyond end of matrix but OK)
            m1 = max(Mstart - M1, 0)
            rows = slice(m1, m1 + nM)

            a2max = a1 + 1  # if not derivative else self.natoms

            for a2 in range(gdcomm.rank, a2max, gdcomm.size):
                O_xmm, T_xmm = pair_O_T(a1, a2)
                if O_xmm is None:
                    continue

                N1, N2 = Mindices[a2]
                O_xMM[..., myM1:myM2, N1:N2] = O_xmm[..., rows, :]
                T_xMM[..., myM1:myM2, N1:N2] = T_xmm[..., rows, :]

        # reshape() fails on size-0 arrays
        if ignore_upper or not O_xMM.size:
            return O_xMM, T_xMM

        assert mynao == self.nao
        assert O_xMM.shape[-2:] == (self.nao, self.nao)

        if derivative:
            def lumap(arr, out):
                np.conj(arr, out)
                out *= -1.0
        else:
            lumap = np.conj

        for arr_xMM in [O_xMM, T_xMM]:
            for tmp_MM in arr_xMM.reshape(-1, self.nao, self.nao):
                tri2full(tmp_MM, UL='L', map=lumap)

        return O_xMM, T_xMM
359