1"""Module for evaluating two-center integrals.
2
3Contains classes for evaluating integrals of the form::
4
5             /
6            |   _   _a    _   _b   _
7    Theta = | f(r - R ) g(r - R ) dr ,
8            |
9           /
10
11with f and g each being given as a radial function times a spherical
12harmonic.
13
14Important classes
15-----------------
16
17Low-level:
18
19 * OverlapExpansion: evaluate the overlap between a pair of functions (or a
20   function with itself) for some displacement vector: <f | g>.  An overlap
21   expansion can be created once for a pair of splines f and g, and actual
22   values of the overlap can then be evaluated for several different
23   displacement vectors.
24 * FourierTransformer: create OverlapExpansion object from pair of splines.
25
26Mid-level:
27
28 * TwoSiteOverlapExpansions: evaluate overlaps between two *sets* of functions,
29   where functions in the same set reside on the same location: <f_j | g_j>.
30 * TwoSiteOverlapCalculator: create TwoSiteOverlapExpansions object from
31   pair of lists of splines.
32
33High-level:
34
35 * ManySiteOverlapExpansions:  evaluate overlaps with many functions in many
36   locations: <f_aj | g_aj>.
37 * ManySiteOverlapCalculator: create ManySiteOverlapExpansions object from
38   pair of lists of splines nested by atom and orbital number.
39
40The low-level classes do the actual work, while the higher-level ones depend
41on the lower-level ones.
42
43"""
44
45from math import pi, factorial as fac
46
47import numpy as np
48from numpy.fft import ifft
49
50from ase.neighborlist import PrimitiveNeighborList
51from ase.data import covalent_radii
52from ase.units import Bohr
53
54import _gpaw
55from gpaw.gaunt import gaunt
56from gpaw.spherical_harmonics import Yl, nablarlYL
57from gpaw.spline import Spline
58from gpaw.utilities.tools import tri2full
59from gpaw.utilities.timing import nulltimer
60
# Module-wide timer object.  Methods below call timer.start(name) and
# timer.stop(name) on it; it is initialised to gpaw's nulltimer
# (presumably a no-op) and may be replaced externally when profiling.
timer = nulltimer  # XXX global timer object, only for hacking

# Triangle argument passed to tri2full() when NewTwoCenterIntegrals
# completes the Theta and T matrices from the triangle it computed.
UL = 'L'
64
# Generate the coefficients for the Fourier-Bessel transform used by fbt().
# For every n < LMAX, C[n] is a complex array of length n + 1 with
#
#   C[n][s] = i**s (n + s)! / (s! 2**s (n - s)!) * (-i)**(n + 1),
#
# s = 0, ..., n.  fbt(l, ...) uses C[l], so l must be smaller than LMAX.
LMAX = 7


def _fourier_bessel_coefficients(lmax):
    """Return the list [C[0], ..., C[lmax - 1]] of coefficient arrays."""
    coefficients = []
    for n in range(lmax):
        c = np.zeros(n + 1, complex)
        for s in range(n + 1):
            # Same operation order as the original loop so the values are
            # bit-identical: ((i**s * (n+s)!) / denominator) * (-i)**(n+1).
            c[s] = ((1.0j)**s * fac(n + s) / (fac(s) * 2**s * fac(n - s)) *
                    (-1.0j)**(n + 1))
        coefficients.append(c)
    return coefficients


C = _fourier_bessel_coefficients(LMAX)
76
77
def fbt(l, f, r, k):
    """Fast Bessel transform.

    The following integral is calculated using l+1 FFTs::

                    oo
                   /
              l+1 |  2           l
      g(k) = k    | r dr j (kr) r f (r)
                  |       l
                 /
                  0

    ``r`` is the uniform input grid; its spacing is taken as ``r[1]``, so
    ``r[0]`` is expected to be 0.  ``k`` is the output grid, and each FFT
    uses ``2 * len(k)`` points.  The coefficients come from the module-level
    table ``C``, so ``l`` must be smaller than ``LMAX``.
    """
    dr = r[1]
    nk = len(k)
    nfft = 2 * nk
    g = np.zeros(nk)
    for n, coef in enumerate(C[l]):
        transform = ifft(coef * f * r**(1 + l - n), nfft)[:nk].real
        g += dr * nfft * k**(l - n) * transform
    return g
99
100
def spherical_harmonics(R_c, lmax=LMAX):
    """Return a list whose entry l (0 <= l < lmax) is Yl evaluated at R_c.

    Entry l is a float array of length 2 * l + 1 filled in by
    ``Yl(l, R_c, out)``.  The naming (rlY) suggests r**l times the real
    spherical harmonics.
    """
    R_c = np.asarray(R_c)

    def _values_for(l):
        out = np.empty(2 * l + 1)
        Yl(l, R_c, out)
        return out

    return [_values_for(l) for l in range(lmax)]
109
110
def spherical_harmonics_and_derivatives(R_c, lmax=LMAX):
    """Return ``(rlY_lm, drlYdR_lmc)`` evaluated at R_c.

    ``rlY_lm`` is the list produced by ``spherical_harmonics(R_c, lmax)``.
    ``drlYdR_lmc[l]`` is a float array of shape (2 * l + 1, 3) whose row m
    is ``nablarlYL(l**2 + m, R_c)``.
    """
    R_c = np.asarray(R_c)
    rlY_lm = spherical_harmonics(R_c, lmax)
    drlYdR_lmc = []
    for l in range(len(rlY_lm)):
        nm = 2 * l + 1
        drlYdR_mc = np.empty((nm, 3))
        # Combined index L runs over l**2, ..., l**2 + 2l.
        for m, L in enumerate(range(l**2, l**2 + nm)):
            drlYdR_mc[m] = nablarlYL(L, R_c)
        drlYdR_lmc.append(drlYdR_mc)
    return rlY_lm, drlYdR_lmc
122
123
class BaseOverlapExpansionSet:
    """Common base for objects producing result blocks of a fixed shape.

    ``shape`` is the trailing shape of every result block.  Subclasses pass
    a (rows, columns) tuple.
    """

    def __init__(self, shape):
        self.shape = shape

    def zeros(self, shape=(), dtype=float):
        """Return a zero array of shape ``shape + self.shape``.

        ``shape`` gives extra leading dimensions and must be a tuple, since
        it is concatenated with ``self.shape``.
        """
        full_shape = shape + self.shape
        return np.zeros(full_shape, dtype=dtype)
130
131
class OverlapExpansion(BaseOverlapExpansionSet):
    """A list of real-space splines representing an overlap integral.

    The overlap between a function with angular momentum ``la`` and one with
    angular momentum ``lb`` is expanded in the splines ``spline_l``.  Their
    angular momenta are l = (la + lb) % 2, (la + lb) % 2 + 2, ... in list
    order.  Results are blocks of shape (2 * la + 1, 2 * lb + 1).
    """
    def __init__(self, la, lb, spline_l):
        BaseOverlapExpansionSet.__init__(self, (2 * la + 1, 2 * lb + 1))
        self.la = la
        self.lb = lb
        self.spline_l = spline_l
        # Argument for gaunt() that covers both la and lb.
        self.lmaxgaunt = max(la, lb)
        # Exclusive bound on the l values of the splines; callers use it as
        # the number of l values for which harmonics are needed.
        self.lmaxspline = (la + lb) % 2 + 2 * len(spline_l)
        # Underlying C spline objects, handed to _gpaw.tci_overlap.
        self.cspline_l = [s.spline for s in spline_l]

    def get_gaunt(self, l):
        """Return the Gaunt block G[m_a, m_b, m] for the given spline l."""
        la, lb = self.la, self.lb
        G_LLL = gaunt(max(la, lb))
        return G_LLL[la**2:(la + 1)**2,
                     lb**2:(lb + 1)**2,
                     l**2:(l + 1)**2]

    def gaunt_iter(self):
        """Yield ``(l, spline, G_mmm)`` for every spline, in list order."""
        l0 = (self.la + self.lb) % 2
        for i, spline in enumerate(self.spline_l):
            l = l0 + 2 * i
            yield l, spline, self.get_gaunt(l)

    def old_evaluate(self, r, rlY_lm):
        """Get overlap between localized functions.

        Pure-Python version: apply Gaunt coefficients to the list of
        real-space splines describing the overlap integral.  Splines whose
        value at r has magnitude <= 1e-10 are skipped."""
        timer.start('oe eval')
        x_mi = self.zeros()
        for l, spline, G_mmm in self.gaunt_iter():
            value = spline(r)
            if abs(value) > 1e-10:
                x_mi += value * np.dot(G_mmm, rlY_lm[l])
        timer.stop('oe eval')
        return x_mi

    def evaluate(self, r, rlY_lm, G_LLL, x_mi, _nil=np.empty(0)):
        """Accumulate the overlap at distance r into x_mi using C code.

        ``_nil`` is an empty placeholder for the derivative-only arguments.
        """
        _gpaw.tci_overlap(self.la, self.lb, G_LLL, self.cspline_l,
                          r, rlY_lm, x_mi,
                          False, _nil, _nil, _nil)

    def old_derivative(self, r, Rhat_c, rlY_lm, G_LLL, drlYdR_lmc):
        """Get derivative of overlap between localized functions.

        This function assumes r > 0.  If r = 0, i.e. if the functions
        reside on the same atom, the derivative is zero in any case.
        ``G_LLL`` is accepted but unused; Gaunt blocks come from
        gaunt_iter()."""
        timer.start('oldderiv')
        dxdR_cmi = self.zeros((3,))
        for l, spline, G_mmm in self.gaunt_iter():
            value, dvalue_dr = spline.get_value_and_derivative(r)
            if abs(value) > 1e-10:
                GrlY_mi = np.dot(G_mmm, rlY_lm[l])
                # Radial part: d/dR of the spline along the unit vector.
                dxdR_cmi += dvalue_dr * Rhat_c[:, None, None] * GrlY_mi
                # Angular part: spline value times the harmonic gradient.
                angular_mmc = np.dot(G_mmm, drlYdR_lmc[l])
                dxdR_cmi += value * angular_mmc.transpose(2, 0, 1)
        timer.stop('oldderiv')
        return dxdR_cmi

    def derivative(self, r, Rhat_c, rlY_L, G_LLL, drlYdR_Lc, dxdR_cmi,
                   _nil=np.empty(0)):
        """Accumulate the overlap derivative into dxdR_cmi using C code.

        ``_nil`` is an empty placeholder for the value-only output.
        """
        _gpaw.tci_overlap(self.la, self.lb, G_LLL, self.cspline_l,
                          r, rlY_L, _nil,
                          True, Rhat_c, drlYdR_Lc, dxdR_cmi)
203
204
class TwoSiteOverlapExpansions(BaseOverlapExpansionSet):
    """Overlap expansions between all functions of two sites.

    ``oe_jj[ja, jb]`` is the OverlapExpansion between function ja (angular
    momentum ``la_j[ja]``) on the first site and function jb (``lb_j[jb]``)
    on the second.  Results are matrices whose rows and columns run over all
    (j, m) pairs of the respective site.
    """
    def __init__(self, la_j, lb_j, oe_jj):
        self.oe_jj = oe_jj
        nrows = sum(2 * l + 1 for l in la_j)
        ncols = sum(2 * l + 1 for l in lb_j)
        BaseOverlapExpansionSet.__init__(self, (nrows, ncols))
        if oe_jj.size == 0:
            self.lmaxgaunt = 0
            self.lmaxspline = 0
        else:
            expansions = oe_jj.ravel()
            self.lmaxgaunt = max(oe.lmaxgaunt for oe in expansions)
            self.lmaxspline = max(oe.lmaxspline for oe in expansions)

    def slice(self, x_xMM):
        """Yield ``(view, oe)`` pairs of x_xMM sub-blocks and expansions."""
        assert x_xMM.shape[-2:] == self.shape
        row = 0
        for oe_j in self.oe_jj:
            col = 0
            next_row = row
            for oe in oe_j:
                next_row = row + oe.shape[0]
                next_col = col + oe.shape[1]
                yield x_xMM[..., row:next_row, col:next_col], oe
                col = next_col
            row = next_row

    def evaluate(self, r, rlY_lm):
        """Return the overlap matrix for distance r and harmonics rlY_lm."""
        timer.start('tsoe eval')
        x_MM = self.zeros()
        G_LLL = gaunt(self.lmaxgaunt)
        rlY_L = rlY_lm.toarray(self.lmaxspline)
        for block_mm, oe in self.slice(x_MM):
            oe.evaluate(r, rlY_L, G_LLL, block_mm)
        timer.stop('tsoe eval')
        return x_MM

    def derivative(self, r, Rhat, rlY_lm, drlYdR_lmc):
        """Return the overlap derivative with shape (3,) + self.shape."""
        x_cMM = self.zeros((3,))
        G_LLL = gaunt(self.lmaxgaunt)
        rlY_L = rlY_lm.toarray(self.lmaxspline)
        drlYdR_Lc = drlYdR_lmc.toarray(self.lmaxspline)
        for block_cmm, oe in self.slice(x_cMM):
            oe.derivative(r, Rhat, rlY_L, G_LLL, drlYdR_Lc, block_cmm)
        return x_cMM
250
251
class ManySiteOverlapExpansions(BaseOverlapExpansionSet):
    """Overlap expansions between functions on many atoms.

    ``tsoe_II[I1, I2]`` holds the TwoSiteOverlapExpansions between atom
    types I1 and I2.  ``I1_a`` / ``I2_a`` map atom indices to types.
    ``M1_a`` / ``M2_a`` give each atom's first row / column in the full
    matrix.
    """
    def __init__(self, tsoe_II, I1_a, I2_a):
        self.tsoe_II = tsoe_II
        self.I1_a = I1_a
        self.I2_a = I2_a

        nrows_a = [tsoe_II[I, 0].shape[0] for I in I1_a]
        ncols_a = [tsoe_II[0, I].shape[1] for I in I2_a]

        self.M1_a = []
        M1 = 0
        for nrows in nrows_a:
            self.M1_a.append(M1)
            M1 += nrows

        self.M2_a = []
        M2 = 0
        for ncols in ncols_a:
            self.M2_a.append(M2)
            M2 += ncols

        shape = (sum(nrows_a), sum(ncols_a))
        assert (M1, M2) == shape
        BaseOverlapExpansionSet.__init__(self, shape)

    def get(self, a1, a2):
        """Return the two-site expansion for atoms a1 and a2."""
        return self.tsoe_II[self.I1_a[a1], self.I2_a[a2]]

    def getslice(self, a1, a2, x_xMM):
        """Return ``(view, tsoe)``: the a1/a2 sub-block of x_xMM and its
        expansion."""
        tsoe = self.get(a1, a2)
        row0 = self.M1_a[a1]
        col0 = self.M2_a[a2]
        nrows, ncols = tsoe.shape
        return x_xMM[..., row0:row0 + nrows, col0:col0 + ncols], tsoe

    def evaluate_slice(self, disp, x_qxMM):
        """Add the contribution of displacement ``disp`` into x_qxMM."""
        block_qxmm, tsoe = self.getslice(disp.a1, disp.a2, x_qxMM)
        disp.evaluate_overlap(tsoe, block_qxmm)
293
294
class DomainDecomposedExpansions(BaseOverlapExpansionSet):
    """Wrap a ManySiteOverlapExpansions, restricted to local second atoms.

    Only displacements whose second atom ``a2`` is contained in
    ``local_indices`` are evaluated.
    """
    def __init__(self, msoe, local_indices):
        BaseOverlapExpansionSet.__init__(self, msoe.shape)
        self.msoe = msoe
        self.local_indices = local_indices

    def evaluate_slice(self, disp, x_xqMM):
        if disp.a2 not in self.local_indices:
            return
        self.msoe.evaluate_slice(disp, x_xqMM)
304
305
class ManySiteDictionaryWrapper(DomainDecomposedExpansions):
    """Evaluate expansions into per-atom arrays stored in a dictionary.

    Used with dictionaries such as P_aqMi and dPdR_aqcMi, keyed by the
    second atom index.  Works only with NeighborPairs, not SimpleAtomIter:
    every pair is seen once, so the reversed displacement is also evaluated
    here to compensate.
    """

    def getslice(self, a1, a2, xdict_aqxMi):
        """Return the rows of ``xdict_aqxMi[a2]`` belonging to atom a1,
        together with the matching two-site expansion."""
        msoe = self.msoe
        tsoe = msoe.tsoe_II[msoe.I1_a[a1], msoe.I2_a[a2]]
        start = msoe.M1_a[a1]
        stop = start + tsoe.shape[0]
        return xdict_aqxMi[a2][..., start:stop, :], tsoe

    def evaluate_slice(self, disp, x_aqxMi):
        a1, a2 = disp.a1, disp.a2
        if a2 in x_aqxMi:
            block_qxmi, oe = self.getslice(a1, a2, x_aqxMi)
            disp.evaluate_overlap(oe, block_qxmi)
        if a1 in x_aqxMi and a1 != a2:
            # Mirror contribution for the pair seen only once.
            block2_qxmi, oe2 = self.getslice(a2, a1, x_aqxMi)
            disp.reverse().evaluate_overlap(oe2, block2_qxmi)
326
327
class BlacsOverlapExpansions(BaseOverlapExpansionSet):
    """Evaluate a ManySiteOverlapExpansions into a row-distributed matrix.

    The local matrix holds rows ``Mmystart:Mmystart + mynao`` of the full
    matrix.  ``astart:aend`` is the range of first atoms whose rows can
    intersect that local range.  Only second atoms contained in
    ``local_indices`` are evaluated.
    """
    def __init__(self, msoe, local_indices, Mmystart, mynao):
        BaseOverlapExpansionSet.__init__(self, msoe.shape)
        self.msoe = msoe
        self.local_indices = local_indices
        self.Mmystart = Mmystart
        self.mynao = mynao

        M_a = msoe.M1_a
        natoms = len(M_a)
        Mmyend = Mmystart + mynao

        # astart: last atom whose first row is <= Mmystart.
        a = 0
        while a < natoms and M_a[a] <= Mmystart:
            a += 1
        self.astart = a - 1

        # aend: first atom (from astart on) whose first row is >= Mmyend.
        a = self.astart
        while a < natoms and M_a[a] < Mmyend:
            a += 1
        self.aend = a

    def evaluate_slice(self, disp, x_xqNM):
        a1 = disp.a1
        a2 = disp.a2
        if not (a2 in self.local_indices and
                self.astart <= a1 < self.aend):
            return
        # Only a2 <= a1 contributions are needed; the mirror half is not
        # evaluated here.
        msoe = self.msoe
        tsoe = msoe.tsoe_II[msoe.I1_a[a1], msoe.I2_a[a2]]
        block_qxmm = tsoe.zeros(x_xqNM.shape[:-2], dtype=x_xqNM.dtype)
        disp.evaluate_overlap(tsoe, block_qxmm)

        # Rows of atom a1 relative to the local row offset ...
        row_start = msoe.M1_a[a1] - self.Mmystart
        row_end = row_start + tsoe.shape[0]
        # ... clipped to the locally stored rows.
        local_start = max(0, row_start)
        local_end = min(self.mynao, row_end)
        col_start = msoe.M2_a[a2]
        col_end = col_start + tsoe.shape[1]
        x_xqNM[..., local_start:local_end, col_start:col_end] += \
            block_qxmm[..., local_start - row_start:local_end - row_start, :]
372
373
class SimpleAtomIter:
    """Brute-force iteration over all atom pairs and nearby cell images.

    Offsets run over every integer vector with components in
    ``-offsetsteps..offsetsteps``.
    """
    def __init__(self, cell_cv, spos1_ac, spos2_ac, offsetsteps=0):
        self.cell_cv = cell_cv
        self.spos1_ac = spos1_ac
        self.spos2_ac = spos2_ac
        self.offsetsteps = offsetsteps

    def iter(self):
        """Yield all atom index pairs and corresponding displacements."""
        steps = range(-self.offsetsteps, self.offsetsteps + 1)
        offsets = np.array([(i, j, k)
                            for i in steps
                            for j in steps
                            for k in steps])
        for a1, spos1_c in enumerate(self.spos1_ac):
            for a2, spos2_c in enumerate(self.spos2_ac):
                for offset in offsets:
                    shift_c = spos2_c - spos1_c + offset
                    yield a1, a2, np.dot(shift_c, self.cell_cv), offset
392
393
class NeighborPairs:
    """Class for looping over pairs of atoms using a neighbor list.

    Wraps ASE's PrimitiveNeighborList (skin 0, sorted, scaled positions)
    built from the per-atom cutoffs ``cutoff_a``.
    """
    def __init__(self, cutoff_a, cell_cv, pbc_c, self_interaction):
        self.cell_cv = cell_cv
        self.pbc_c = pbc_c
        self.neighbors = PrimitiveNeighborList(
            cutoff_a, skin=0, sorted=True,
            self_interaction=self_interaction,
            use_scaled_positions=True)

    def set_positions(self, spos_ac):
        """Store scaled positions and rebuild the neighbor list."""
        self.spos_ac = spos_ac
        self.neighbors.update(self.pbc_c, self.cell_cv, spos_ac)

    def iter(self):
        """Yield ``(a1, a2, R_c, offset)`` for every listed neighbor pair.

        R_c is the Cartesian vector from atom a1 to the image of atom a2
        shifted by ``offset`` cells.
        """
        cell_cv = self.cell_cv
        spos_ac = self.spos_ac
        for a1, spos1_c in enumerate(spos_ac):
            neighbor_a, offset_ac = self.neighbors.get_neighbors(a1)
            for a2, offset in zip(neighbor_a, offset_ac):
                R_c = np.dot(spos_ac[a2] + offset - spos1_c, cell_cv)
                yield a1, a2, R_c, offset
416
417
class PairFilter:
    """Base class wrapping another pair iterator.

    Forwards ``set_positions`` and, by default, ``iter`` unchanged.
    Subclasses override ``iter`` to transform the stream of
    ``(a1, a2, R_c, offset)`` tuples.
    """
    def __init__(self, pairs):
        self.pairs = pairs

    def set_positions(self, spos_ac):
        self.pairs.set_positions(spos_ac)

    def iter(self):
        wrapped = self.pairs
        return wrapped.iter()
427
428
class PairsWithSelfinteraction(PairFilter):
    """Pass pairs through and add the mirrored periodic self-image.

    For every pair of an atom with its own image at a nonzero offset, the
    opposite image (-R_c, -offset) is yielded too.  Presumably the wrapped
    neighbor list reports each such self-image only once.
    """
    def iter(self):
        for a1, a2, R_c, offset in self.pairs.iter():
            yield a1, a2, R_c, offset
            is_periodic_self_image = a1 == a2 and offset.any()
            if is_periodic_self_image:
                yield a1, a1, -R_c, -offset
435
436
class PairsBothWays(PairFilter):
    """Yield every pair followed by its reversed counterpart."""
    def iter(self):
        for pair in self.pairs.iter():
            a1, a2, R_c, offset = pair
            yield a1, a2, R_c, offset
            yield a2, a1, -R_c, -offset
442
443
class OppositeDirection(PairFilter):
    """Yield every pair reversed: atoms swapped, R_c and offset negated."""
    def iter(self):
        for a1, a2, R_c, offset in self.pairs.iter():
            reversed_pair = (a2, a1, -R_c, -offset)
            yield reversed_pair
448
449
class FourierTransformer:
    """Fourier-Bessel transform splines and build overlap expansions.

    Real-space functions are sampled on ``r_g`` (``ng`` points, spacing
    ``dr = rcmax / ng``).  The reciprocal grid ``k_q`` has ``Q // 2 = 2 * ng``
    points with spacing ``dk = 2 * pi / (Q * dr)``, where ``Q = 4 * ng``.
    """
    def __init__(self, rcmax, ng):
        self.ng = ng
        self.rcmax = rcmax
        self.dr = rcmax / self.ng
        self.r_g = np.arange(self.ng) * self.dr
        self.Q = 4 * self.ng
        self.dk = 2 * pi / self.Q / self.dr
        self.k_q = np.arange(self.Q // 2) * self.dk

    def transform(self, spline):
        """Return the Fourier-Bessel transform of ``spline`` on ``k_q``.

        Raises AssertionError if the spline's cutoff exceeds ``rcmax``.
        """
        assert spline.get_cutoff() <= self.rcmax, (spline.get_cutoff(),
                                                   self.rcmax)
        l = spline.get_angular_momentum_number()
        f_g = spline.map(self.r_g)
        f_q = fbt(l, f_g, self.r_g, self.k_q)
        return f_q

    def calculate_overlap_expansion(self, phit1phit2_q, l1, l2):
        """Calculate list of splines for one overlap integral.

        Given two Fourier transformed functions, return list of splines
        in real space necessary to evaluate their overlap.

          phi  (q) * phi  (q) --> [phi    (r), ..., phi    (r)] .
             l1         l2            lmin             lmax

        The overlap <phi1 | phi2> can then be calculated by linear
        combinations of the returned splines with appropriate Gaunt
        coefficients.  The splines are returned wrapped in an
        OverlapExpansion.
        """
        lmax = l1 + l2
        splines = []
        R = np.arange(self.Q // 2) * self.dr
        R1 = R.copy()
        R1[0] = 1.0
        k1 = self.k_q.copy()
        k1[0] = 1.0
        a_q = phit1phit2_q

        # The real-space result (len(R) points, spacing dr) is thinned to
        # roughly 256 points before it is splined.  The spline grid must
        # match the points actually sampled: with N = len(R[::stride])
        # samples at i * stride * dr plus the appended trailing zero, the
        # grid ends at N * stride * dr.  This equals 2 * rcmax exactly when
        # stride divides len(R) (e.g. ng = 2**10), and stays correct
        # otherwise.  stride is clamped to 1 for very small grids.
        stride = max(1, len(R) // 256)
        rcut = len(R[::stride]) * stride * self.dr

        for l in range(lmax % 2, lmax + 1, 2):
            a_g = (8 * fbt(l, a_q * k1**(-2 - lmax - l), self.k_q, R) /
                   R1**(2 * l + 1))
            if l == 0:
                a_g[0] = 8 * np.sum(a_q * k1**(-lmax)) * self.dk
            else:
                a_g[0] = a_g[1]  # XXXX
            a_g *= (-1)**((l1 - l2 - l) // 2)
            s = Spline(l, rcut, np.concatenate((a_g[::stride], [0.0])))
            splines.append(s)
        return OverlapExpansion(l1, l2, splines)

    def laplacian(self, f_jq):
        """Return 0.5 * k**2 * f_jq on the reciprocal grid ``k_q``.

        This is used to build kinetic-energy expansions.
        """
        return 0.5 * f_jq * self.k_q**2.0
504
505
class TwoSiteOverlapCalculator:
    """Build TwoSiteOverlapExpansions from lists of functions.

    ``transformer`` supplies ``transform``, ``laplacian`` and
    ``calculate_overlap_expansion`` (e.g. a FourierTransformer).
    """
    def __init__(self, transformer):
        self.transformer = transformer

    def transform(self, f_j):
        """Return an array with the transform of every function in f_j."""
        transform = self.transformer.transform
        return np.array([transform(f) for f in f_j])

    def calculate_expansions(self, la_j, fa_jq, lb_j, fb_jq):
        """Return the TwoSiteOverlapExpansions for all (ja, jb) pairs."""
        nja = len(la_j)
        njb = len(lb_j)
        assert nja == len(fa_jq) and njb == len(fb_jq)
        oe_jj = np.empty((nja, njb), dtype=object)
        make = self.transformer.calculate_overlap_expansion
        for ja, (la, fa_q) in enumerate(zip(la_j, fa_jq)):
            for jb, (lb, fb_q) in enumerate(zip(lb_j, fb_jq)):
                oe_jj[ja, jb] = make(fa_q * fb_q, la, lb)
        return TwoSiteOverlapExpansions(la_j, lb_j, oe_jj)

    def calculate_kinetic_expansions(self, l_j, f_jq):
        """Return expansions of <f_j | laplacian(f_j')> for one site type."""
        t_jq = self.transformer.laplacian(f_jq)
        return self.calculate_expansions(l_j, f_jq, l_j, t_jq)

    def laplacian(self, f_jq):
        """Apply the transformer's ``laplacian`` to f_jq."""
        return self.transformer.laplacian(f_jq)
533
534
class ManySiteOverlapCalculator:
    """Build ManySiteOverlapExpansions for all pairs of atom types."""
    def __init__(self, twosite_calculator, I1_a, I2_a):
        """Create a ManySiteOverlapCalculator.

        twosite_calculator: instance of TwoSiteOverlapCalculator
        I1_a, I2_a: mappings from atom index (as in spos_a) to unique atom
        type, for the first and second function set respectively."""
        self.twosite_calculator = twosite_calculator
        self.I1_a = I1_a
        self.I2_a = I2_a

    def transform(self, f_Ij):
        """Transform every per-type list of functions in f_Ij."""
        transform = self.twosite_calculator.transform
        return [transform(f_j) for f_j in f_Ij]

    def calculate_expansions(self, l1_Ij, f1_Ijq, l2_Ij, f2_Ijq):
        """Return expansions for every (type I1, type I2) combination.

        Naive implementation that loops over all type pairs; presumably
        only half of them are needed.
        """
        assert len(l1_Ij) == len(f1_Ijq) and len(l2_Ij) == len(f2_Ijq)
        tsoe_II = np.empty((len(l1_Ij), len(l2_Ij)), dtype=object)
        expand = self.twosite_calculator.calculate_expansions
        for I1, (l1_j, f1_jq) in enumerate(zip(l1_Ij, f1_Ijq)):
            for I2, (l2_j, f2_jq) in enumerate(zip(l2_Ij, f2_Ijq)):
                tsoe_II[I1, I2] = expand(l1_j, f1_jq, l2_j, f2_jq)
        return ManySiteOverlapExpansions(tsoe_II, self.I1_a, self.I2_a)

    def calculate_kinetic_expansions(self, l_Ij, f_Ijq):
        """Return expansions of <f | laplacian(f)> for all type pairs."""
        laplacian = self.twosite_calculator.laplacian
        t_Ijq = [laplacian(f_jq) for f_jq in f_Ijq]
        return self.calculate_expansions(l_Ij, f_Ijq, l_Ij, t_Ijq)
566
567
class AtomicDisplacement:
    """Displacement R_c from atom a1 to (an image of) atom a2.

    Knows how to evaluate an expansion at this displacement and how to add
    the result, multiplied by the phases from ``phases``, into a
    destination array.  ``factory`` provides ``displacementclass`` for
    building the reversed displacement.
    """
    def __init__(self, factory, a1, a2, R_c, offset, phases):
        self.factory = factory
        self.a1 = a1
        self.a2 = a2
        self.R_c = R_c
        self.offset = offset
        self.phases = phases
        self.r = np.linalg.norm(R_c)
        self._set_spherical_harmonics(R_c)

    def _set_spherical_harmonics(self, R_c):
        self.rlY_lm = LazySphericalHarmonics(R_c)

    # XXX new
    def evaluate_direct(self, oe, dst_xqmm):
        """Evaluate ``oe`` and add the phased result into dst_xqmm."""
        self.phases.apply(self.evaluate_direct_without_phases(oe), dst_xqmm)

    # XXX new
    def evaluate_direct_without_phases(self, oe):
        return oe.evaluate(self.r, self.rlY_lm)

    # XXX clean up unnecessary methods
    def overlap_without_phases(self, oe):
        return oe.evaluate(self.r, self.rlY_lm)

    def _evaluate_without_phases(self, oe):
        # Overridden by DerivativeAtomicDisplacement.
        return self.overlap_without_phases(oe)

    def evaluate_overlap(self, oe, dst_xqmm):
        """Evaluate ``oe`` (value or derivative, depending on the class) and
        add the phased result into dst_xqmm."""
        result_xmm = self._evaluate_without_phases(oe)
        timer.start('phases')
        self.phases.apply(result_xmm, dst_xqmm)
        timer.stop('phases')

    def reverse(self):
        """Return the displacement from a2 back to a1, with inverse phases."""
        make = self.factory.displacementclass
        return make(self.factory, self.a2, self.a1,
                    -self.R_c, -self.offset, self.phases.inverse())
608
609
class LazySphericalHarmonics:
    """Class for caching spherical harmonics as they are calculated.

    Behaves like a list Y_lm, but really calculates (or retrieves) Y_m
    once a given value of l is __getitem__'d.  ``toarray(lmax)`` returns
    the values for l = 0, ..., lmax - 1 flattened into one array.  Both
    access paths share the same per-l cache, so each l is evaluated at
    most once.
    """
    def __init__(self, R_c):
        self.R_c = np.asarray(R_c)
        self.Y_lm = {}  # cache: l -> array returned by evaluate(l)
        self.Y_lm1 = np.empty(0)  # flattened values for l < self.lmax
        self.lmax = 0  # number of l values currently held in Y_lm1

    def evaluate(self, l):
        """Compute the 2 * l + 1 values for angular momentum l via Yl."""
        rlY_m = np.empty(2 * l + 1)
        Yl(l, self.R_c, rlY_m)
        return rlY_m

    def __getitem__(self, l):
        Y_m = self.Y_lm.get(l)
        if Y_m is None:
            Y_m = self.evaluate(l)
            self.Y_lm[l] = Y_m
        return Y_m

    def toarray(self, lmax):
        """Return the values for all l < lmax, flattened and concatenated.

        The returned array is cached and extended on demand.  It may hold
        more than lmax values of l if a larger lmax was requested earlier.
        """
        if lmax > self.lmax:
            # Go through __getitem__ so values that were already fetched
            # individually are reused instead of recomputed.
            self.Y_lm1 = np.concatenate([self.Y_lm1] +
                                        [self[l].ravel()
                                         for l in range(self.lmax, lmax)])
            self.lmax = lmax
        return self.Y_lm1
641
642
class LazySphericalHarmonicsDerivative(LazySphericalHarmonics):
    """Lazy cache of harmonic gradients as computed by nablarlYL.

    ``evaluate(l)`` returns a float array of shape (2 * l + 1, 3).
    """
    def evaluate(self, l):
        nm = 2 * l + 1
        drlYdR_mc = np.empty((nm, 3))
        for m, L in enumerate(range(l**2, l**2 + nm)):
            drlYdR_mc[m] = nablarlYL(L, self.R_c)
        return drlYdR_mc
650
651
class DerivativeAtomicDisplacement(AtomicDisplacement):
    """Atomic displacement that evaluates derivatives instead of values."""
    def _set_spherical_harmonics(self, R_c):
        self.rlY_lm = LazySphericalHarmonics(R_c)
        self.drlYdr_lmc = LazySphericalHarmonicsDerivative(R_c)
        # Unit vector along R_c (self.r is set before this is called);
        # a zero vector when the displacement itself is zero.
        self.Rhat_c = R_c / self.r if R_c.any() else np.zeros(3)

    def derivative_without_phases(self, oe):
        return oe.derivative(self.r, self.Rhat_c, self.rlY_lm,
                             self.drlYdr_lmc)

    def _evaluate_without_phases(self, oe):  # override
        return self.derivative_without_phases(oe)
667
668
class NullPhases:
    """Phase applier for a single q-point without Bloch phases.

    The constructor arguments exist only to match BlochPhases and are
    ignored.
    """
    def __init__(self, ibzk_qc, offset):
        pass

    def apply(self, src_xMM, dst_qxMM):
        """Add src_xMM into the only q-entry of dst_qxMM."""
        assert len(dst_qxMM) == 1
        target_xMM = dst_qxMM[0]
        target_xMM[:] += src_xMM

    def inverse(self):
        """Without phases, the inverse is this same object."""
        return self
679
680
class BlochPhases:
    """Phase applier using phase_q = exp(-2 pi i k_q . offset).

    ``ibzk_qc`` holds the k-points and ``offset`` the cell offset of the
    displacement.
    """
    def __init__(self, ibzk_qc, offset):
        self.ibzk_qc = ibzk_qc
        self.offset = offset
        self.phase_q = np.exp(-2j * pi * np.dot(ibzk_qc, offset))

    def apply(self, src_xMM, dst_qxMM):
        """Add phase_q[q] * src_xMM into dst_qxMM[q] for every q."""
        assert dst_qxMM.dtype == complex, dst_qxMM.dtype
        for phase, target_xMM in zip(self.phase_q, dst_qxMM):
            target_xMM[:] += phase * src_xMM

    def inverse(self):
        """Return the applier with conjugate phases (negated k-points)."""
        return BlochPhases(-self.ibzk_qc, self.offset)
694
695
class TwoCenterIntegralCalculator:
    """Drive the evaluation of expansions over atom pairs.

    This class knows how to apply phases and whether to call the various
    derivative() or evaluate() methods.  Bloch phases are used only when
    ``ibzk_qc`` is given and contains a nonzero entry.
    """
    def __init__(self, ibzk_qc=None, derivative=False):
        self.ibzk_qc = ibzk_qc
        self.derivative = derivative
        self.displacementclass = (DerivativeAtomicDisplacement if derivative
                                  else AtomicDisplacement)
        gamma_only = ibzk_qc is None or not ibzk_qc.any()
        self.phaseclass = NullPhases if gamma_only else BlochPhases

    def calculate(self, atompairs, expansions, arrays):
        """For every displacement, evaluate each expansion into its array."""
        for disp in self.iter(atompairs):
            for expansion, X_qxMM in zip(expansions, arrays):
                expansion.evaluate_slice(disp, X_qxMM)

    def iter(self, atompairs):
        """Yield one displacement object per pair from ``atompairs``."""
        for a1, a2, R_c, offset in atompairs.iter():
            phases = self.phaseclass(self.ibzk_qc, offset)
            yield self.displacementclass(self, a1, a2, R_c, offset, phases)
725
726
class NewTwoCenterIntegrals:
    """Overlap (Theta), kinetic (T) and projector (P) two-center integrals.

    Expansions are built once per pair of setup types at construction.
    ``evaluate``/``derivative`` then fill the caller-provided arrays for
    given scaled positions.
    """
    def __init__(self, cell_cv, pbc_c, setups, ibzk_qc, gamma):
        self.cell_cv = cell_cv
        self.pbc_c = pbc_c
        self.ibzk_qc = ibzk_qc
        self.gamma = gamma

        timer.start('tci init')
        cutoff_I = []
        setups_I = setups.setups.values()
        I_setup = {}
        for I, setup in enumerate(setups_I):
            I_setup[setup] = I
            cutoff_I.append(max([func.get_cutoff()
                                 for func in setup.phit_j + setup.pt_j]))

        I_a = []
        for setup in setups:
            I_a.append(I_setup[setup])

        cutoff_a = [cutoff_I[I] for I in I_a]

        self.cutoff_a = cutoff_a  # convenient for writing the new new overlap
        self.I_a = I_a
        self.setups_I = setups_I
        self.atompairs = PairsWithSelfinteraction(NeighborPairs(cutoff_a,
                                                                cell_cv,
                                                                pbc_c,
                                                                True))

        # Neighbor list used only to detect atoms that are unphysically
        # close: per-atom cutoff of 1% of the covalent radius (in Bohr).
        scale = 0.01  # XXX minimal distance scale
        cutoff_close_a = [covalent_radii[int(s.Z)] / Bohr * scale
                          for s in setups]
        self.atoms_close = NeighborPairs(cutoff_close_a, cell_cv, pbc_c, False)

        rcmax = max(cutoff_I + [0.001])

        ng = 2**10
        transformer = FourierTransformer(rcmax, ng)
        tsoc = TwoSiteOverlapCalculator(transformer)
        self.msoc = ManySiteOverlapCalculator(tsoc, I_a, I_a)
        self.calculate_expansions()

        self.calculate = self.evaluate  # XXX compatibility

        self.set_matrix_distribution(None, None)
        timer.stop('tci init')

    def set_matrix_distribution(self, Mmystart, mynao):
        """Distribute matrices using BLACS.

        ``mynao=None`` disables the distribution.
        """
        # Range of basis functions for BLACS distribution of matrices:
        self.Mmystart = Mmystart
        self.mynao = mynao
        self.blacs = mynao is not None

    def calculate_expansions(self):
        """Build the Theta, T and P expansions for all setup-type pairs."""
        timer.start('tci calc exp')
        phit_Ij = [setup.phit_j for setup in self.setups_I]
        l_Ij = [[phit.get_angular_momentum_number() for phit in phit_j]
                for phit_j in phit_Ij]

        pt_l_Ij = [setup.l_j for setup in self.setups_I]
        pt_Ij = [setup.pt_j for setup in self.setups_I]
        phit_Ijq = self.msoc.transform(phit_Ij)
        pt_Ijq = self.msoc.transform(pt_Ij)

        msoc = self.msoc

        self.Theta_expansions = msoc.calculate_expansions(l_Ij, phit_Ijq,
                                                          l_Ij, phit_Ijq)
        self.T_expansions = msoc.calculate_kinetic_expansions(l_Ij, phit_Ijq)
        self.P_expansions = msoc.calculate_expansions(l_Ij, phit_Ijq,
                                                      pt_l_Ij, pt_Ijq)
        timer.stop('tci calc exp')

    def _check_atoms_not_too_close(self, spos_ac):
        """Raise RuntimeError listing every pair found by ``atoms_close``."""
        self.atoms_close.set_positions(spos_ac)
        lines = ['Atom %d and Atom %d in cell (%d, %d, %d)\n'
                 % (a1, a2, offset[0], offset[1], offset[2])
                 for a1, a2, R_c, offset in self.atoms_close.iter()]
        if lines:
            raise RuntimeError('Atoms too close!\n' + ''.join(lines))

    def _calculate(self, calc, spos_ac, Theta_qxMM, T_qxMM, P_aqxMi):
        """Zero the output arrays and fill them via ``calc``."""
        Theta_qxMM.fill(0.0)
        T_qxMM.fill(0.0)
        for P_qxMi in P_aqxMi.values():
            P_qxMi.fill(0.0)

        self._check_atoms_not_too_close(spos_ac)

        self.atompairs.set_positions(spos_ac)

        if self.blacs:
            # S and T matrices are distributed:
            expansions = [
                BlacsOverlapExpansions(self.Theta_expansions,
                                       P_aqxMi, self.Mmystart, self.mynao),
                BlacsOverlapExpansions(self.T_expansions,
                                       P_aqxMi, self.Mmystart, self.mynao)]
        else:
            expansions = [DomainDecomposedExpansions(self.Theta_expansions,
                                                     P_aqxMi),
                          DomainDecomposedExpansions(self.T_expansions,
                                                     P_aqxMi)]

        expansions.append(ManySiteDictionaryWrapper(self.P_expansions,
                                                    P_aqxMi))
        arrays = [Theta_qxMM, T_qxMM, P_aqxMi]
        timer.start('tci calculate')
        calc.calculate(OppositeDirection(self.atompairs), expansions, arrays)
        timer.stop('tci calculate')

    def evaluate(self, spos_ac, Theta_qMM, T_qMM, P_aqMi):
        """Fill Theta_qMM, T_qMM and P_aqMi for scaled positions spos_ac.

        Without BLACS, only one triangle is computed and tri2full() then
        completes each Theta and T matrix (UL = 'L').
        """
        calc = TwoCenterIntegralCalculator(self.ibzk_qc, derivative=False)
        self._calculate(calc, spos_ac, Theta_qMM, T_qMM, P_aqMi)
        if not self.blacs:
            for X_MM in list(Theta_qMM) + list(T_qMM):
                tri2full(X_MM, UL=UL)

    def derivative(self, spos_ac, dThetadR_qcMM, dTdR_qcMM, dPdR_aqcMi):
        """Fill the position derivatives of Theta, T and P.

        Without BLACS, each Cartesian component of dTheta and dT is
        completed by tri2full() with an anti-Hermitian mapping.
        """
        calc = TwoCenterIntegralCalculator(self.ibzk_qc, derivative=True)
        self._calculate(calc, spos_ac, dThetadR_qcMM, dTdR_qcMM, dPdR_aqcMi)

        def antihermitian(src, dst):
            np.conj(-src, dst)

        if not self.blacs:
            for X_cMM in list(dThetadR_qcMM) + list(dTdR_qcMM):
                for X_MM in X_cMM:
                    tri2full(X_MM, UL=UL, map=antihermitian)

    calculate_derivative = derivative  # XXX compatibility

    def estimate_memory(self, mem):
        mem.subnode('TCI not impl.', 0)
865