1# -*- coding: utf-8 -*-
2import numbers
3from math import factorial as fac
4from math import pi
5
6import _gpaw
7import gpaw
8import gpaw.fftw as fftw
9import numpy as np
10from ase.units import Bohr, Ha
11from ase.utils.timing import timer
12from gpaw.arraydict import ArrayDict
13from gpaw.band_descriptor import BandDescriptor
14from gpaw.blacs import BlacsDescriptor, BlacsGrid, Redistributor
15from gpaw.density import Density
16from gpaw.hamiltonian import Hamiltonian
17from gpaw.lcao.overlap import fbt
18from gpaw.lfc import BaseLFC
19from gpaw.matrix_descriptor import MatrixDescriptor
20from gpaw.spherical_harmonics import Y, nablarlYL
21from gpaw.spline import Spline
22from gpaw.typing import Array3D
23from gpaw.utilities import unpack
24from gpaw.utilities.blas import axpy, mmm, r2k, rk
25from gpaw.utilities.progressbar import ProgressBar
26from gpaw.wavefunctions.arrays import PlaneWaveExpansionWaveFunctions
27from gpaw.wavefunctions.fdpw import FDPWWaveFunctions
28from gpaw.wavefunctions.mode import Mode
29
30
def pad(array, N):
    """Return ``array`` extended with trailing zeros to length ``N``.

    ``None`` is passed straight through, and an array that already has
    length ``N`` is returned as-is (no copy is made).  Otherwise a new
    complex array of length ``N`` is returned.
    """
    if array is None:
        return None
    if len(array) == N:
        return array
    padded = np.zeros(N, complex)
    padded[:len(array)] = array
    return padded
42
43
class PW(Mode):
    """Plane-wave mode.

    Stores the user-supplied plane-wave parameters and, when called,
    creates the PWWaveFunctions object that uses them.
    """

    name = 'pw'

    def __init__(self, ecut=340, fftwflags=fftw.MEASURE, cell=None,
                 gammacentered=False,
                 pulay_stress=None, dedecut=None,
                 force_complex_dtype=False):
        """Plane-wave basis mode.

        ecut: float
            Plane-wave cutoff in eV.
        fftwflags: int
            Flags for making an FFTW plan.  There are 4 possibilities
            (default is MEASURE)::

                from gpaw.fftw import ESTIMATE, MEASURE, PATIENT, EXHAUSTIVE

        cell: 3x3 ndarray
            Use this unit cell to choose the plane waves.
        gammacentered: bool
            Center the grid of chosen plane waves around the
            gamma point or q/k-vector
        pulay_stress: float or None
            Pulay-stress correction.
        dedecut: float or None or 'estimate'
            Estimate of derivative of total energy with respect to
            plane-wave cutoff.  Used to calculate pulay_stress.
        force_complex_dtype: bool
            Handed on unchanged to Mode.__init__.

        Only one of dedecut and pulay_stress can be used.
        """

        self.gammacentered = gammacentered
        # The cutoff is kept divided by Ha; todict() multiplies it back.
        self.ecut = ecut / Ha
        # Don't do expensive planning in dry-run mode:
        self.fftwflags = fftw.MEASURE if gpaw.dry_run else fftwflags
        self.dedecut = dedecut
        if pulay_stress is None:
            self.pulay_stress = None
        else:
            self.pulay_stress = pulay_stress * Bohr**3 / Ha

        assert pulay_stress is None or dedecut is None

        self.cell_cv = None if cell is None else cell / Bohr

        Mode.__init__(self, force_complex_dtype)

    def __call__(self, parallel, initksl, gd, **kwargs):
        """Create the PWWaveFunctions object for the grid ``gd``.

        If a reference cell was given to the constructor, the cutoff
        is scaled by (reference volume / actual volume)**(2/3).  All
        extra keyword arguments are forwarded to PWWaveFunctions.
        """
        volume = abs(np.linalg.det(gd.cell_cv))

        ecut = self.ecut
        if self.cell_cv is not None:
            reference_volume = abs(np.linalg.det(self.cell_cv))
            ecut = self.ecut * (reference_volume / volume)**(2 / 3.0)

        # Work out the dE/depsilon value handed to the wave functions:
        if self.pulay_stress is not None:
            dedepsilon = self.pulay_stress * volume
        elif self.dedecut is None:
            dedepsilon = 0.0
        elif self.dedecut == 'estimate':
            # Resolved later by the wave-function object.
            dedepsilon = 'estimate'
        else:
            dedepsilon = self.dedecut * 2 / 3 * ecut

        return PWWaveFunctions(ecut, self.gammacentered,
                               self.fftwflags, dedepsilon,
                               parallel, initksl, gd=gd,
                               **kwargs)

    def todict(self):
        """Return the mode parameters as a dictionary.

        Starts from Mode.todict() and adds 'ecut' and 'gammacentered';
        'cell', 'pulay_stress' and 'dedecut' are added only when set.
        """
        dct = Mode.todict(self)
        dct['ecut'] = self.ecut * Ha
        dct['gammacentered'] = self.gammacentered

        have_cell = self.cell_cv is not None
        have_stress = self.pulay_stress is not None
        have_dedecut = self.dedecut is not None
        if have_cell:
            dct['cell'] = self.cell_cv * Bohr
        if have_stress:
            dct['pulay_stress'] = self.pulay_stress * Ha / Bohr**3
        if have_dedecut:
            dct['dedecut'] = self.dedecut
        return dct
130
131
132class PWDescriptor:
133    ndim = 1  # all 3d G-vectors are stored in a 1d ndarray
134
    def __init__(self, ecut, gd, dtype=None, kd=None,
                 fftwflags=fftw.MEASURE, gammacentered=False):
        """Set up the plane-wave bookkeeping for the grid described by gd.

        ecut: float or None
            Cutoff used to select G-vectors (|G+k|^2 <= 2 * ecut).  Must
            not exceed the largest value supported by the grid spacing;
            None selects 0.9999 times that largest value.
        gd: grid-descriptor
            Real-space grid.  Must be periodic in all directions.
        dtype: type or None
            float or complex.  None means float when kd is None or
            kd.gamma is true, and complex otherwise.
        kd: k-point descriptor or None
            Supplies the k-points (kd.ibzk_qc).  None means one single
            zero k-vector.
        fftwflags: int
            Flags passed on to the two FFT plans.
        gammacentered: bool
            Select the G-vectors using |G|^2 instead of |G+k|^2.
        """

        assert gd.pbc_c.all()

        self.gd = gd
        self.fftwflags = fftwflags

        N_c = gd.N_c
        self.comm = gd.comm

        # Largest cutoff allowed for this grid (from the grid spacings):
        ecut0 = 0.5 * pi**2 / (self.gd.h_cv**2).sum(1).max()
        if ecut is None:
            ecut = 0.9999 * ecut0
        else:
            assert ecut <= ecut0

        self.ecut = ecut

        if dtype is None:
            if kd is None or kd.gamma:
                dtype = float
            else:
                dtype = complex
        self.dtype = dtype
        self.gammacentered = gammacentered

        # i_Qc holds signed integer G-vector indices for every point of
        # the reciprocal-space work array tmp_Q.
        if dtype == float:
            # Only N_c[2] // 2 + 1 points are stored along the last axis;
            # that axis keeps its non-negative indices, the first two
            # are wrapped to a signed range.
            Nr_c = N_c.copy()
            Nr_c[2] = N_c[2] // 2 + 1
            i_Qc = np.indices(Nr_c).transpose((1, 2, 3, 0))
            i_Qc[..., :2] += N_c[:2] // 2
            i_Qc[..., :2] %= N_c[:2]
            i_Qc[..., :2] -= N_c[:2] // 2
            self.tmp_Q = fftw.empty(Nr_c, complex)
            # Real-space work array shares memory with tmp_Q:
            self.tmp_R = self.tmp_Q.view(float)[:, :, :N_c[2]]
        else:
            i_Qc = np.indices(N_c).transpose((1, 2, 3, 0))
            i_Qc += N_c // 2
            i_Qc %= N_c
            i_Qc -= N_c // 2
            self.tmp_Q = fftw.empty(N_c, complex)
            # Complex transforms are done in place:
            self.tmp_R = self.tmp_Q

        # Plans for tmp_R -> tmp_Q (sign -1) and tmp_Q -> tmp_R (sign +1):
        self.fftplan = fftw.FFTPlan(self.tmp_R, self.tmp_Q, -1, fftwflags)
        self.ifftplan = fftw.FFTPlan(self.tmp_Q, self.tmp_R, 1, fftwflags)

        # Calculate reciprocal lattice vectors:
        B_cv = 2.0 * pi * gd.icell_cv
        i_Qc.shape = (-1, 3)
        self.G_Qv = np.dot(i_Qc, B_cv)

        self.kd = kd
        if kd is None:
            self.K_qv = np.zeros((1, 3))
            self.only_one_k_point = True
        else:
            self.K_qv = np.dot(kd.ibzk_qc, B_cv)
            self.only_one_k_point = (kd.nbzkpts == 1)

        # Map from vectors inside sphere to fft grid:
        self.Q_qG = []
        G2_qG = []
        Q_Q = np.arange(len(i_Qc), dtype=np.int32)

        self.ng_q = []
        for q, K_v in enumerate(self.K_qv):
            G2_Q = ((self.G_Qv + K_v)**2).sum(axis=1)
            if gammacentered:
                mask_Q = ((self.G_Qv**2).sum(axis=1) <= 2 * ecut)
            else:
                mask_Q = (G2_Q <= 2 * ecut)

            if self.dtype == float:
                # Keep a point only if its third index is > 0, or its
                # second index is > 0, or its second index is 0 and its
                # first index is >= 0.
                # NOTE(review): presumably this drops G-vectors that are
                # redundant for a real-valued function -- confirm.
                mask_Q &= ((i_Qc[:, 2] > 0) |
                           (i_Qc[:, 1] > 0) |
                           ((i_Qc[:, 0] >= 0) & (i_Qc[:, 1] == 0)))
            Q_G = Q_Q[mask_Q]
            self.Q_qG.append(Q_G)
            G2_qG.append(G2_Q[Q_G])
            ng = len(Q_G)
            self.ng_q.append(ng)

        self.ngmin = min(self.ng_q)
        self.ngmax = max(self.ng_q)

        if kd is not None:
            # Make min/max agree across the k-point communicator:
            self.ngmin = kd.comm.min(self.ngmin)
            self.ngmax = kd.comm.max(self.ngmax)

        # Distribute things:
        # Each rank of gd.comm owns the slice
        # [rank * maxmyng, (rank + 1) * maxmyng) of every k-point's list
        # of G-vectors; maxmyng is ngmax / S rounded up.
        S = gd.comm.size
        self.maxmyng = (self.ngmax + S - 1) // S
        ng1 = gd.comm.rank * self.maxmyng
        ng2 = ng1 + self.maxmyng

        self.G2_qG = []
        self.myQ_qG = []
        self.myng_q = []
        for q, G2_G in enumerate(G2_qG):
            self.G2_qG.append(G2_G[ng1:ng2].copy())
            myQ_G = self.Q_qG[q][ng1:ng2]
            self.myQ_qG.append(myQ_G)
            self.myng_q.append(len(myQ_G))

        # Work buffer with room for maxmyng coefficients from every rank;
        # only allocated when there is more than one rank.
        if S > 1:
            self.tmp_G = np.empty(self.maxmyng * S, complex)
        else:
            self.tmp_G = None
244
245    def get_reciprocal_vectors(self, q=0, add_q=True):
246        """Returns reciprocal lattice vectors plus q, G + q,
247        in xyz coordinates."""
248
249        if add_q:
250            q_v = self.K_qv[q]
251            return self.G_Qv[self.myQ_qG[q]] + q_v
252        return self.G_Qv[self.myQ_qG[q]]
253
254    def __getstate__(self):
255        return (self.ecut, self.gd, self.dtype, self.kd, self.fftwflags)
256
    def __setstate__(self, state):
        """Rebuild the descriptor by re-running __init__.

        state: tuple
            Positional arguments for __init__, as produced by
            __getstate__.  All derived arrays and FFT plans are
            recomputed rather than restored.
        """
        self.__init__(*state)
259
260    def estimate_memory(self, mem):
261        nbytes = (self.tmp_R.nbytes +
262                  self.G_Qv.nbytes +
263                  len(self.K_qv) * (self.ngmax * 4 +
264                                    self.maxmyng * (8 + 4)))
265        mem.subnode('Arrays', nbytes)
266
267    def bytecount(self, dtype=float):
268        return self.ngmax * 16
269
270    def zeros(self, x=(), dtype=None, q=None, global_array=False):
271        """Return zeroed array.
272
273        The shape of the array will be x + (ng,) where ng is the number
274        of G-vectors for on this core.  Different k-points will have
275        different values for ng.  Therefore, the q index must be given,
276        unless we are describibg a real-valued function."""
277
278        a_xG = self.empty(x, dtype, q, global_array)
279        a_xG.fill(0.0)
280        return a_xG
281
282    def empty(self, x=(), dtype=None, q=None, global_array=False):
283        """Return empty array."""
284        if dtype is not None:
285            assert dtype == self.dtype
286        if isinstance(x, numbers.Integral):
287            x = (x,)
288        if q is None:
289            assert self.only_one_k_point
290            q = 0
291        if global_array:
292            shape = x + (self.ng_q[q],)
293        else:
294            shape = x + (self.myng_q[q],)
295        return np.empty(shape, complex)
296
297    def fft(self, f_R, q=None, Q_G=None, local=False):
298        """Fast Fourier transform.
299
300        Returns c(G) for G<Gc::
301
302                   --      -iG.R
303            c(G) = > f(R) e
304                   --
305                   R
306
307        If local=True, all cores will do an FFT without any
308        collect/scatter.
309        """
310
311        if local:
312            self.tmp_R[:] = f_R
313        else:
314            self.gd.collect(f_R, self.tmp_R)
315
316        if self.gd.comm.rank == 0 or local:
317            self.fftplan.execute()
318            if Q_G is None:
319                q = q or 0
320                Q_G = self.Q_qG[q]
321            f_G = self.tmp_Q.ravel()[Q_G]
322            if local:
323                return f_G
324        else:
325            f_G = None
326
327        return self.scatter(f_G, q)
328
    def ifft(self, c_G, q=None, local=False, safe=True, distribute=True):
        """Inverse fast Fourier transform.

        Returns::

                   1 --        iG.R
            f(R) = - > c(G) e
                   N --
                     G

        If local=True, all cores will do an iFFT without any
        gather/distribute.

        c_G: ndarray
            Plane-wave coefficients (this core's part unless local=True).
        q: int or None
            k-point index; None is only allowed with a single k-point.
        safe: bool
            Only relevant when the internal work array is returned
            (one core, local=True or distribute=False): True returns
            a copy, False returns the work array itself, which later
            transforms will overwrite.
        distribute: bool
            If False, return the full work array instead of distributing
            it over the domain.
        """
        assert q is not None or self.only_one_k_point
        q = q or 0
        if not local:
            c_G = self.gather(c_G, q)
        comm = self.gd.comm
        # Normalization 1/N with N the number of real-space grid points:
        scale = 1.0 / self.tmp_R.size
        if comm.rank == 0 or local:
            # Same as:
            #
            #    self.tmp_Q[:] = 0.0
            #    self.tmp_Q.ravel()[self.Q_qG[q]] = scale * c_G
            #
            # but much faster:
            Q_G = self.Q_qG[q]
            assert len(c_G) == len(Q_G)
            _gpaw.pw_insert(c_G, Q_G, scale, self.tmp_Q)

            if self.dtype == float:
                # Fill entries of the plane with last index 0 with complex
                # conjugates of mirrored entries of the same plane.
                # NOTE(review): presumably this restores the coefficients
                # that were left out of the stored G-vectors for real
                # functions -- confirm.
                t = self.tmp_Q[:, :, 0]
                n, m = self.gd.N_c[:2] // 2 - 1
                t[0, -m:] = t[0, m:0:-1].conj()
                t[n:0:-1, -m:] = t[-n:, m:0:-1].conj()
                t[-n:, -m:] = t[n:0:-1, m:0:-1].conj()
                t[-n:, 0] = t[n:0:-1, 0].conj()
            self.ifftplan.execute()
        if comm.size == 1 or local or not distribute:
            if safe:
                return self.tmp_R.copy()
            return self.tmp_R
        return self.gd.distribute(self.tmp_R)
372
373    def scatter(self, a_G, q=None):
374        """Scatter coefficients from master to all cores."""
375        comm = self.gd.comm
376        if comm.size == 1:
377            return a_G
378
379        mya_G = np.empty(self.maxmyng, complex)
380        comm.scatter(pad(a_G, self.maxmyng * comm.size), mya_G, 0)
381        return mya_G[:self.myng_q[q or 0]]
382
383    def gather(self, a_G, q=None):
384        """Gather coefficients on master."""
385        comm = self.gd.comm
386
387        if comm.size == 1:
388            return a_G
389
390        mya_G = pad(a_G, self.maxmyng)
391        if comm.rank == 0:
392            a_G = self.tmp_G
393        else:
394            a_G = None
395        comm.gather(mya_G, 0, a_G)
396        if comm.rank == 0:
397            return a_G[:self.ng_q[q or 0]]
398
    def alltoall1(self, a_rG, q):
        """Gather coefficients from a_rG[r] on rank r.

        On rank r, an array of all G-vector coefficients will be returned.
        These will be gathered from a_rG[r] on all the cores.

        a_rG: ndarray
            One row per receiving rank; row r holds this core's part of
            the coefficients destined for rank r.  There may be fewer
            rows than cores.
        q: int
            k-point index.

        Ranks with index >= len(a_rG) receive nothing and return None.
        The returned array is a view of the work buffer tmp_G.
        """
        comm = self.gd.comm
        if comm.size == 1:
            return a_rG[0]
        N = len(a_rG)
        ng = self.ng_q[q]
        # Send myng_q[q] coefficients (row r of a_rG) to each of the
        # first N ranks and nothing to the rest:
        ssize_r = np.zeros(comm.size, int)
        ssize_r[:N] = self.myng_q[q]
        soffset_r = np.arange(comm.size) * self.myng_q[q]
        soffset_r[N:] = 0
        # The block from rank r goes to position r * maxmyng of the
        # full array (clipped to the total number of G-vectors):
        roffset_r = (np.arange(comm.size) * self.maxmyng).clip(max=ng)
        rsize_r = np.zeros(comm.size, int)
        if comm.rank < N:
            rsize_r[:-1] = roffset_r[1:] - roffset_r[:-1]
            rsize_r[-1] = ng - roffset_r[-1]
        b_G = self.tmp_G[:ng]
        comm.alltoallv(a_rG, ssize_r, soffset_r, b_G, rsize_r, roffset_r)
        if comm.rank < N:
            return b_G
423
    def alltoall2(self, a_G, q, b_rG):
        """Scatter all coefs. from rank r to B_rG[r] on other cores.

        a_G: ndarray
            Full coefficient array held by this rank.  Only ranks with
            index < len(b_rG) actually send anything.
        q: int
            k-point index.
        b_rG: ndarray
            One row per sending rank.  The received blocks are added
            to it in place; nothing is returned.
        """
        comm = self.gd.comm
        if comm.size == 1:
            b_rG[0] += a_G
            return
        N = len(b_rG)
        ng = self.ng_q[q]
        # Receive myng_q[q] coefficients from each of the first N ranks
        # into consecutive rows:
        rsize_r = np.zeros(comm.size, int)
        rsize_r[:N] = self.myng_q[q]
        roffset_r = np.arange(comm.size) * self.myng_q[q]
        roffset_r[N:] = 0
        # The block for rank r starts at position r * maxmyng of a_G
        # (clipped to the total number of G-vectors):
        soffset_r = (np.arange(comm.size) * self.maxmyng).clip(max=ng)
        ssize_r = np.zeros(comm.size, int)
        if comm.rank < N:
            ssize_r[:-1] = soffset_r[1:] - soffset_r[:-1]
            ssize_r[-1] = ng - soffset_r[-1]
        # Receive into the work buffer, then accumulate into b_rG:
        tmp_rG = self.tmp_G[:b_rG.size].reshape(b_rG.shape)
        comm.alltoallv(a_G, ssize_r, soffset_r, tmp_rG, rsize_r, roffset_r)
        b_rG += tmp_rG
444
    def integrate(self, a_xg, b_yg=None,
                  global_integral=True, hermitian=False):
        """Integrate function(s) over domain.

        a_xg: ndarray
            Function(s) to be integrated.
        b_yg: ndarray
            If present, integrate a_xg.conj() * b_yg.
        global_integral: bool
            If the array(s) are distributed over several domains, then the
            total sum will be returned.  To get the local contribution
            only, use global_integral=False.
        hermitian: bool
            Result is hermitian.

        Returns a scalar when both inputs are 1-d, otherwise an array of
        shape a_xg.shape[:-1] + b_yg.shape[:-1].  For array results
        global_integral=False is only accepted on a single core.
        """

        if b_yg is None:
            # Only one array:
            assert self.dtype == float and self.gd.comm.size == 1
            # The integral is the real part of the first coefficient
            # times the volume element:
            return a_xg[..., 0].real * self.gd.dv

        # Work with 2-d arrays (functions x coefficients):
        if a_xg.ndim == 1:
            A_xg = a_xg.reshape((1, len(a_xg)))
        else:
            A_xg = a_xg
        if b_yg.ndim == 1:
            B_yg = b_yg.reshape((1, len(b_yg)))
        else:
            B_yg = b_yg

        alpha = self.gd.dv / self.gd.N_c.prod()

        if self.dtype == float:
            # Treat the complex coefficients as pairs of floats.
            # NOTE(review): the factor 2 presumably accounts for only
            # half of the G-vectors being stored for real functions --
            # confirm.
            alpha *= 2
            A_xg = A_xg.view(float)
            B_yg = B_yg.view(float)

        result_yx = np.zeros((len(B_yg), len(A_xg)), self.dtype)

        # Matrix products via the helpers from gpaw.utilities.blas; the
        # result is written to result_yx:
        if a_xg is b_yg:
            rk(alpha, A_xg, 0.0, result_yx)
        elif hermitian:
            r2k(0.5 * alpha, A_xg, B_yg, 0.0, result_yx)
        else:
            mmm(alpha, B_yg, 'N', A_xg, 'C', 0.0, result_yx)

        if self.dtype == float and self.gd.comm.rank == 0:
            # Remove half of the (doubled) contribution from the first
            # entry of the float view.
            # NOTE(review): presumably the G=0 term, which must only be
            # counted once -- confirm.
            correction_yx = np.outer(B_yg[:, 0], A_xg[:, 0])
            if hermitian:
                result_yx -= 0.25 * alpha * (correction_yx + correction_yx.T)
            else:
                result_yx -= 0.5 * alpha * correction_yx

        xshape = a_xg.shape[:-1]
        yshape = b_yg.shape[:-1]
        result = result_yx.T.reshape(xshape + yshape)

        if result.ndim == 0:
            if global_integral:
                return self.gd.comm.sum(result.item())
            return result.item()
        else:
            assert global_integral or self.gd.comm.size == 1
            # In-place sum over the domain communicator:
            self.gd.comm.sum(result.T)
            return result
510
    def interpolate(self, a_R, pd):
        """Transfer a_R to the larger grid of pd via reciprocal space.

        The Fourier coefficients of a_R are copied into the center of
        the (zeroed) reciprocal-space work array of pd, which is then
        transformed back to real space.  The FFT work is done on rank 0.

        a_R: ndarray
            Function on this descriptor's grid (distributed).
        pd: PWDescriptor
            Target descriptor.  Its grid must have more points than this
            one along every axis.

        Returns a tuple: the function on pd's grid (distributed over
        pd.gd) and the plane-wave coefficients of a_R on this
        descriptor (scattered over the cores).

        Raises ValueError if the target grid is not larger along every
        axis.
        """
        if (pd.gd.N_c <= self.gd.N_c).any():
            raise ValueError('Too few points in target grid!')

        self.gd.collect(a_R, self.tmp_R[:])

        if self.gd.comm.rank == 0:
            self.fftplan.execute()

            a_Q = self.tmp_Q
            b_Q = pd.tmp_Q

            e0, e1, e2 = 1 - self.gd.N_c % 2  # even or odd size
            # [a:b) is the index range in the (shifted) target array that
            # receives the source coefficients:
            a0, a1, a2 = pd.gd.N_c // 2 - self.gd.N_c // 2
            b0, b1, b2 = self.gd.N_c + (a0, a1, a2)

            if self.dtype == float:
                # Only half of the last axis is stored, starting at
                # index 0, and that axis is not shifted:
                b2 = (b2 - a2) // 2 + 1
                a2 = 0
                axes = (0, 1)
            else:
                axes = (0, 1, 2)

            b_Q[:] = 0.0
            b_Q[a0:b0, a1:b1, a2:b2] = np.fft.fftshift(a_Q, axes=axes)

            # For an even number of source points along an axis, the
            # outermost plane is halved and duplicated on the opposite
            # side of the inserted block:
            if e0:
                b_Q[a0, a1:b1, a2:b2] *= 0.5
                b_Q[b0, a1:b1, a2:b2] = b_Q[a0, a1:b1, a2:b2]
                b0 += 1
            if e1:
                b_Q[a0:b0, a1, a2:b2] *= 0.5
                b_Q[a0:b0, b1, a2:b2] = b_Q[a0:b0, a1, a2:b2]
                b1 += 1
            if self.dtype == complex:
                if e2:
                    b_Q[a0:b0, a1:b1, a2] *= 0.5
                    b_Q[a0:b0, a1:b1, b2] = b_Q[a0:b0, a1:b1, a2]
            else:
                # Half-stored axis: only halve the last stored plane.
                if e2:
                    b_Q[a0:b0, a1:b1, b2 - 1] *= 0.5

            b_Q[:] = np.fft.ifftshift(b_Q, axes=axes)
            pd.ifftplan.execute()

            a_G = a_Q.ravel()[self.Q_qG[0]]
        else:
            a_G = None

        # The inverse transform is unnormalized; scale by one over the
        # number of source grid points.
        return (pd.gd.distribute(pd.tmp_R) * (1.0 / self.tmp_R.size),
                self.scatter(a_G))
562
    def restrict(self, a_R, pd):
        """Transfer a_R to the smaller grid of pd via reciprocal space.

        The central block of the Fourier coefficients of a_R is copied
        to the reciprocal-space work array of pd, which is then
        transformed back to real space.  The FFT work is done on rank 0.

        a_R: ndarray
            Function on this descriptor's grid (distributed).
        pd: PWDescriptor
            Target descriptor with the smaller grid.

        Returns a tuple: the function on pd's grid (distributed over
        pd.gd) and its plane-wave coefficients on pd (scattered over
        the cores).
        """
        self.gd.collect(a_R, self.tmp_R[:])

        if self.gd.comm.rank == 0:
            a_Q = pd.tmp_Q
            b_Q = self.tmp_Q

            e0, e1, e2 = 1 - pd.gd.N_c % 2  # even or odd size
            # [a:b) is the index range in the (shifted) source array that
            # is copied to the target:
            a0, a1, a2 = self.gd.N_c // 2 - pd.gd.N_c // 2
            b0, b1, b2 = pd.gd.N_c // 2 + self.gd.N_c // 2 + 1

            if self.dtype == float:
                # Only half of the last axis is stored, starting at
                # index 0, and that axis is not shifted:
                b2 = pd.gd.N_c[2] // 2 + 1
                a2 = 0
                axes = (0, 1)
            else:
                axes = (0, 1, 2)

            self.fftplan.execute()
            b_Q[:] = np.fft.fftshift(b_Q, axes=axes)

            # For an even number of target points along an axis, the two
            # outermost planes of the block are averaged into the first
            # one and the block is shortened by one:
            if e0:
                b_Q[a0, a1:b1, a2:b2] += b_Q[b0 - 1, a1:b1, a2:b2]
                b_Q[a0, a1:b1, a2:b2] *= 0.5
                b0 -= 1
            if e1:
                b_Q[a0:b0, a1, a2:b2] += b_Q[a0:b0, b1 - 1, a2:b2]
                b_Q[a0:b0, a1, a2:b2] *= 0.5
                b1 -= 1
            if self.dtype == complex and e2:
                b_Q[a0:b0, a1:b1, a2] += b_Q[a0:b0, a1:b1, b2 - 1]
                b_Q[a0:b0, a1:b1, a2] *= 0.5
                b2 -= 1

            a_Q[:] = b_Q[a0:b0, a1:b1, a2:b2]
            a_Q[:] = np.fft.ifftshift(a_Q, axes=axes)
            # NOTE(review): the factor 1/8 presumably is the ratio of the
            # number of grid points (a factor 2 along each axis) --
            # confirm.
            a_G = a_Q.ravel()[pd.Q_qG[0]] / 8
            pd.ifftplan.execute()
        else:
            a_G = None

        # The inverse transform is unnormalized; scale by one over the
        # number of source grid points.
        return (pd.gd.distribute(pd.tmp_R) * (1.0 / self.tmp_R.size),
                pd.scatter(a_G))
606
607
class PWMapping:
    """Index map between the plane waves of two descriptors.

    For every G-vector of pd1 (k-point index 0), the position of the
    same G-vector among the locally stored G-vectors of pd2 is
    recorded (-1 when it is not stored locally).
    """

    def __init__(self, pd1, pd2):
        """Mapping from pd1 to pd2."""
        shape1_c = np.array(pd1.tmp_Q.shape)
        shape2_c = pd2.tmp_Q.shape

        # Split the flat FFT-grid indices of pd1 into three indices:
        flat1_G = pd1.Q_qG[0]
        index_Gc = np.empty((len(flat1_G), 3), int)
        index_Gc[:, 0], rest_G = divmod(flat1_G, shape1_c[1] * shape1_c[2])
        index_Gc[:, 1], index_Gc[:, 2] = divmod(rest_G, shape1_c[2])

        # Wrap to signed indices on grid 1, then to non-negative
        # indices on grid 2.  For float dtype the last axis is left
        # untouched.
        naxes = 2 if pd1.dtype == float else 3
        half_c = shape1_c[:naxes] // 2
        wrapped_Gc = index_Gc[:, :naxes]  # view into index_Gc
        wrapped_Gc += half_c
        wrapped_Gc %= shape1_c[:naxes]
        wrapped_Gc -= half_c
        wrapped_Gc %= shape2_c[:naxes]

        # Flat index on the FFT grid of pd2:
        flat2_G = index_Gc[:, 2] + shape2_c[2] * (
            index_Gc[:, 1] + shape2_c[1] * index_Gc[:, 0])

        # Lookup table: flat index on grid 2 -> local G index on pd2
        # (-1 marks points that are not stored locally).
        lookup_Q = np.full(shape2_c, -1, int).ravel()
        lookup_Q[pd2.myQ_qG[0]] = np.arange(pd2.myng_q[0])
        G2_G1 = lookup_Q[flat2_G]

        self.pd1 = pd1
        self.pd2 = pd2

        if pd1.gd.comm.size != 1:
            # Keep only the G-vectors of pd1 that are stored locally
            # on pd2, and remember which ones they are.
            found_G1 = (G2_G1 != -1)
            self.G2_G1 = G2_G1[found_G1]
            self.G1 = np.arange(pd1.ngmax)[found_G1]
        else:
            self.G2_G1 = G2_G1
            self.G1 = None

    def add_to1(self, a_G1, b_G2):
        """Do a += b * scale, where a is on pd1 and b on pd2."""
        pd1 = self.pd1
        scale = pd1.tmp_R.size / self.pd2.tmp_R.size

        comm = pd1.gd.comm
        if comm.size == 1:
            a_G1 += b_G2[self.G2_G1] * scale
            return

        # Every core fills in the coefficients it holds, the buffers
        # are summed, and each core picks out its own slice.
        buf_G1 = pd1.tmp_G
        buf_G1[:] = 0.0
        buf_G1[self.G1] = b_G2[self.G2_G1]
        comm.sum(buf_G1)
        start = comm.rank * pd1.maxmyng
        stop = start + pd1.myng_q[0]
        a_G1 += buf_G1[start:stop] * scale

    def add_to2(self, a_G2, b_G1):
        """Do a += b * scale, where a is on pd2 and b on pd1."""
        pd1 = self.pd1
        scaled_G1 = b_G1 * (self.pd2.tmp_R.size / pd1.tmp_R.size)
        if pd1.gd.comm.size == 1:
            a_G2[self.G2_G1] += scaled_G1
            return

        # Collect the scaled coefficients from all cores, then add the
        # ones that are stored locally on pd2.
        all_G1 = pd1.tmp_G
        pd1.gd.comm.all_gather(pad(scaled_G1, pd1.maxmyng), all_G1)
        a_G2[self.G2_G1] += all_G1[self.G1]
668
669
def count_reciprocal_vectors(ecut, gd, q_c):
    """Count the grid G-vectors with |G + q|^2 <= 2 * ecut.

    q_c is given in units of the reciprocal lattice vectors.  Only
    works without domain decomposition (gd.comm.size must be 1).
    """
    assert gd.comm.size == 1
    npts_c = gd.N_c
    shift_c = npts_c // 2

    # Signed integer indices of all points of the FFT grid:
    index_Qc = np.indices(npts_c).transpose((1, 2, 3, 0))
    index_Qc += shift_c
    index_Qc %= npts_c
    index_Qc -= shift_c
    index_Qc.shape = (-1, 3)

    B_cv = 2.0 * pi * gd.icell_cv
    shifted_Qv = np.dot(index_Qc, B_cv) + np.dot(q_c, B_cv)
    norm2_Q = (shifted_Qv**2).sum(axis=1)
    return (norm2_Q <= 2 * ecut).sum()
684
685
class Preconditioner:
    """Preconditioner for KS equation.

    From:

      Teter, Payne and Allen, Phys. Rev. B 40, 12255 (1989)

    as modified by:

      Kresse and Furthmüller, Phys. Rev. B 54, 11169 (1996)
    """

    def __init__(self, G2_qG, pd):
        """Store the squared G-vector lengths (one array per k-point
        index q) and the plane-wave descriptor used for integration."""
        self.pd = pd
        self.G2_qG = G2_qG

    def calculate_kinetic_energy(self, psit_xG, kpt):
        """Return the real part of <psi| 0.5 * G^2 |psi> per function.

        A 1-d input gives a scalar; otherwise an array with one value
        for each row of psit_xG is returned."""
        if psit_xG.ndim == 1:
            # Treat a single function as a batch of one.
            return self.calculate_kinetic_energy(psit_xG[np.newaxis], kpt)[0]
        T_G = 0.5 * self.G2_qG[kpt.q]
        ekin_x = [self.pd.integrate(T_G * psit_G, psit_G).real
                  for psit_G in psit_xG]
        return np.array(ekin_x)

    def __call__(self, R_xG, kpt, ekin_x, out=None):
        """Apply the preconditioner to the residual(s) R_xG.

        The result is written to out (allocated when not given) and
        returned.  ekin_x holds one kinetic energy per residual, or
        a single value when R_xG is 1-d."""
        G2_G = self.G2_qG[kpt.q]
        if out is None:
            out = np.empty_like(R_xG)
        if R_xG.ndim != 1:
            for PR_G, R_G, ekin in zip(out, R_xG, ekin_x):
                _gpaw.pw_precond(G2_G, R_G, ekin, PR_G)
            return out
        _gpaw.pw_precond(G2_G, R_xG, ekin_x, out)
        return out
719
720
class NonCollinearPreconditioner(Preconditioner):
    """Preconditioner for wave functions with two spinor components."""

    def calculate_kinetic_energy(self, psit_xsG, kpt):
        """Return the kinetic energy summed over the next-to-last axis."""
        *outer_shape, ng = psit_xsG.shape
        flat_xG = psit_xsG.reshape((-1, ng))
        ekin_xs = Preconditioner.calculate_kinetic_energy(self, flat_xG, kpt)
        return ekin_xs.reshape(tuple(outer_shape)).sum(-1)

    def __call__(self, R_sG, kpt, ekin, out=None):
        """Precondition both components using the same kinetic energy."""
        ekin_s = [ekin] * 2
        return Preconditioner.__call__(self, R_sG, kpt, ekin_s, out)
730
731
732class PWWaveFunctions(FDPWWaveFunctions):
733    mode = 'pw'
734
735    def __init__(self, ecut, gammacentered, fftwflags, dedepsilon,
736                 parallel, initksl,
737                 reuse_wfs_method, collinear,
738                 gd, nvalence, setups, bd, dtype,
739                 world, kd, kptband_comm, timer):
740        self.ecut = ecut
741        self.gammacentered = gammacentered
742        self.fftwflags = fftwflags
743        self.dedepsilon = dedepsilon  # Pulay correction for stress tensor
744
745        self.ng_k = None  # number of G-vectors for all IBZ k-points
746
747        FDPWWaveFunctions.__init__(self, parallel, initksl,
748                                   reuse_wfs_method=reuse_wfs_method,
749                                   collinear=collinear,
750                                   gd=gd, nvalence=nvalence, setups=setups,
751                                   bd=bd, dtype=dtype, world=world, kd=kd,
752                                   kptband_comm=kptband_comm, timer=timer)
753
754    def empty(self, n=(), global_array=False, realspace=False, q=None):
755        if isinstance(n, numbers.Integral):
756            n = (n,)
757        if realspace:
758            return self.gd.empty(n, self.dtype, global_array)
759        elif global_array:
760            return np.zeros(n + (self.pd.ngmax,), complex)
761        elif q is None:
762            return np.zeros(n + (self.pd.maxmyng,), complex)
763        else:
764            return self.pd.empty(n, self.dtype, q)
765
766    def integrate(self, a_xg, b_yg=None, global_integral=True):
767        return self.pd.integrate(a_xg, b_yg, global_integral)
768
769    def bytes_per_wave_function(self):
770        return 16 * self.pd.ngmax
771
772    def set_setups(self, setups):
773        self.timer.start('PWDescriptor')
774        self.pd = PWDescriptor(self.ecut, self.gd, self.dtype, self.kd,
775                               self.fftwflags, self.gammacentered)
776        self.timer.stop('PWDescriptor')
777
778        # Build array of number of plane wave coefficiants for all k-points
779        # in the IBZ:
780        self.ng_k = np.zeros(self.kd.nibzkpts, dtype=int)
781        for kpt in self.kpt_u:
782            if kpt.s != 1:  # avoid double counting (only sum over s=0 or None)
783                self.ng_k[kpt.k] = len(self.pd.Q_qG[kpt.q])
784        self.kd.comm.sum(self.ng_k)
785
786        self.pt = PWLFC([setup.pt_j for setup in setups], self.pd)
787
788        FDPWWaveFunctions.set_setups(self, setups)
789
790        if self.dedepsilon == 'estimate':
791            dedecut = self.setups.estimate_dedecut(self.ecut)
792            self.dedepsilon = dedecut * 2 / 3 * self.ecut
793
794    def get_pseudo_partial_waves(self):
795        return PWLFC([setup.get_partial_waves_for_atomic_orbitals()
796                      for setup in self.setups], self.pd)
797
798    def __str__(self):
799        s = 'Wave functions: Plane wave expansion\n'
800        s += '  Cutoff energy: %.3f eV\n' % (self.pd.ecut * Ha)
801
802        if self.dtype == float:
803            s += ('  Number of coefficients: %d (reduced to %d)\n' %
804                  (self.pd.ngmax * 2 - 1, self.pd.ngmax))
805        else:
806            s += ('  Number of coefficients (min, max): %d, %d\n' %
807                  (self.pd.ngmin, self.pd.ngmax))
808
809        stress = self.dedepsilon / self.gd.volume * Ha / Bohr**3
810        dedecut = 1.5 * self.dedepsilon / self.ecut
811        s += ('  Pulay-stress correction: {:.6f} eV/Ang^3 '
812              '(de/decut={:.6f})\n'.format(stress, dedecut))
813
814        if fftw.FFTPlan is fftw.NumpyFFTPlan:
815            s += "  Using Numpy's FFT\n"
816        else:
817            s += '  Using FFTW library\n'
818        return s + FDPWWaveFunctions.__str__(self)
819
820    def make_preconditioner(self, block=1):
821        if self.collinear:
822            return Preconditioner(self.pd.G2_qG, self.pd)
823        return NonCollinearPreconditioner(self.pd.G2_qG, self.pd)
824
    @timer('Apply H')
    def apply_pseudo_hamiltonian(self, kpt, ham, psit_xG, Htpsit_xG):
        """Apply the pseudo Hamiltonian i.e. without PAW corrections.

        kpt: k-point object
            Supplies the spin index s and the k-point index q.
        ham: Hamiltonian
            Supplies the effective potential vt_sG.
        psit_xG: ndarray
            Wave functions (this core's plane-wave coefficients).
        Htpsit_xG: ndarray
            Output; overwritten with the result.

        Non-collinear calculations are delegated to
        apply_pseudo_hamiltonian_nc.
        """
        if not self.collinear:
            self.apply_pseudo_hamiltonian_nc(kpt, ham, psit_xG, Htpsit_xG)
            return

        N = len(psit_xG)
        S = self.gd.comm.size

        # NOTE(review): broadcast=True presumably gives every core the
        # full potential -- confirm against gd.collect.
        vt_R = self.gd.collect(ham.vt_sG[kpt.s], broadcast=True)
        Q_G = self.pd.Q_qG[kpt.q]
        T_G = 0.5 * self.pd.G2_qG[kpt.q]

        # Work on batches of up to S wave functions, so that each core
        # can transform one complete wave function:
        for n1 in range(0, N, S):
            n2 = min(n1 + S, N)
            # Full coefficient array of one wave function on each core
            # (None on cores that did not get one):
            psit_G = self.pd.alltoall1(psit_xG[n1:n2], kpt.q)
            with self.timer('HMM T'):
                # Kinetic part, done on the distributed coefficients:
                np.multiply(T_G, psit_xG[n1:n2], Htpsit_xG[n1:n2])
            if psit_G is not None:
                # To real space (work array, no copy), multiply by the
                # potential in place and transform back:
                psit_R = self.pd.ifft(psit_G, kpt.q, local=True, safe=False)
                psit_R *= vt_R
                self.pd.fftplan.execute()
                vtpsit_G = self.pd.tmp_Q.ravel()[Q_G]
            else:
                # Cores without a wave function still take part in the
                # exchange below; alltoall2 sends nothing from them.
                vtpsit_G = self.pd.tmp_G
            # Distribute the potential part back and add it:
            self.pd.alltoall2(vtpsit_G, kpt.q, Htpsit_xG[n1:n2])

        ham.xc.apply_orbital_dependent_hamiltonian(
            kpt, psit_xG, Htpsit_xG, ham.dH_asp)
855
856    def apply_pseudo_hamiltonian_nc(self, kpt, ham, psit_xG, Htpsit_xG):
857        Htpsit_xG[:] = 0.5 * self.pd.G2_qG[kpt.q] * psit_xG
858        v, x, y, z = ham.vt_xG
859        iy = y * 1j
860        for psit_sG, Htpsit_sG in zip(psit_xG, Htpsit_xG):
861            a = self.pd.ifft(psit_sG[0], kpt.q)
862            b = self.pd.ifft(psit_sG[1], kpt.q)
863            Htpsit_sG[0] += self.pd.fft(a * (v + z) + b * (x - iy), kpt.q)
864            Htpsit_sG[1] += self.pd.fft(a * (x + iy) + b * (v - z), kpt.q)
865
866    def add_orbital_density(self, nt_G, kpt, n):
867        axpy(1.0, abs(self.pd.ifft(kpt.psit_nG[n], kpt.q))**2, nt_G)
868
869    def add_to_density_from_k_point_with_occupation(self, nt_xR, kpt, f_n):
870        if not self.collinear:
871            self.add_to_density_from_k_point_with_occupation_nc(
872                nt_xR, kpt, f_n)
873            return
874
875        comm = self.gd.comm
876
877        nt_R = self.gd.zeros(global_array=True)
878
879        for n1 in range(0, self.bd.mynbands, comm.size):
880            n2 = min(n1 + comm.size, self.bd.mynbands)
881            psit_G = self.pd.alltoall1(kpt.psit.array[n1:n2], kpt.q)
882            if psit_G is not None:
883                f = f_n[n1 + comm.rank]
884                psit_R = self.pd.ifft(psit_G, kpt.q, local=True, safe=False)
885                # Same as nt_R += f * abs(psit_R)**2, but much faster:
886                _gpaw.add_to_density(f, psit_R, nt_R)
887
888        comm.sum(nt_R)
889        nt_R = self.gd.distribute(nt_R)
890        nt_xR[kpt.s] += nt_R
891
892    def add_to_density_from_k_point_with_occupation_nc(self, nt_xR, kpt, f_n):
893        for f, psit_sG in zip(f_n, kpt.psit.array):
894            p1 = self.pd.ifft(psit_sG[0], kpt.q)
895            p2 = self.pd.ifft(psit_sG[1], kpt.q)
896            p11 = p1.real**2 + p1.imag**2
897            p22 = p2.real**2 + p2.imag**2
898            p12 = p1.conj() * p2
899            nt_xR[0] += f * (p11 + p22)
900            nt_xR[1] += 2 * f * p12.real
901            nt_xR[2] += 2 * f * p12.imag
902            nt_xR[3] += f * (p11 - p22)
903
904    def calculate_kinetic_energy_density(self):
905        if self.kpt_u[0].f_n is None:
906            return None
907
908        taut_sR = self.gd.zeros(self.nspins)
909        for kpt in self.kpt_u:
910            G_Gv = self.pd.get_reciprocal_vectors(q=kpt.q)
911            for f, psit_G in zip(kpt.f_n, kpt.psit_nG):
912                for v in range(3):
913                    taut_sR[kpt.s] += 0.5 * f * abs(
914                        self.pd.ifft(1j * G_Gv[:, v] * psit_G, kpt.q))**2
915
916        self.kptband_comm.sum(taut_sR)
917        return taut_sR
918
919    def apply_mgga_orbital_dependent_hamiltonian(self, kpt, psit_xG,
920                                                 Htpsit_xG, dH_asp,
921                                                 dedtaut_R):
922        G_Gv = self.pd.get_reciprocal_vectors(q=kpt.q)
923        for psit_G, Htpsit_G in zip(psit_xG, Htpsit_xG):
924            for v in range(3):
925                a_R = self.pd.ifft(1j * G_Gv[:, v] * psit_G, kpt.q)
926                axpy(-0.5, 1j * G_Gv[:, v] *
927                     self.pd.fft(dedtaut_R * a_R, kpt.q),
928                     Htpsit_G)
929
930    def _get_wave_function_array(self, u, n, realspace=True, periodic=False):
931        kpt = self.kpt_u[u]
932        psit_G = kpt.psit_nG[n]
933
934        if realspace:
935            psit_R = self.pd.ifft(psit_G, kpt.q)
936            if self.kd.gamma or periodic:
937                return psit_R
938
939            k_c = self.kd.ibzk_kc[kpt.k]
940            eikr_R = self.gd.plane_wave(k_c)
941            return psit_R * eikr_R
942
943        return psit_G
944
    def get_wave_function_array(self, n, k, s, realspace=True,
                                cut=True, periodic=False):
        """Return the wave function of band *n*, k-point *k* and spin *s*.

        The rank that owns the state (according to
        ``kd.get_rank_and_index`` and ``bd.who_has``) fetches it with
        ``_get_wave_function_array``.  In real space it is collected with
        ``gd.collect``; otherwise the coefficients are gathered with
        ``pd.gather`` and, where that returns an array, zero-padded to
        ``pd.ngmax`` (*cut* must be false in that case).

        If the owner is not world rank 0, its domain master (rank 0 of
        ``gd.comm``) sends the array to world rank 0, which receives it into
        a freshly allocated array.

        World rank 0 returns the array; every other rank returns ``np.nan``.
        """
        # Locate the owner of the requested state.
        kpt_rank, q = self.kd.get_rank_and_index(k)
        u = q * self.nspins + s
        band_rank, myn = self.bd.who_has(n)

        rank = self.world.rank
        if (self.kd.comm.rank == kpt_rank and
            self.bd.comm.rank == band_rank):
            psit_G = self._get_wave_function_array(u, myn, realspace, periodic)

            if realspace:
                psit_G = self.gd.collect(psit_G)
            else:
                assert not cut
                tmp_G = self.pd.gather(psit_G, self.kpt_u[u].q)
                if tmp_G is not None:
                    # Pad with zeros up to the largest number of plane
                    # waves (pd.ngmax).
                    ng = self.pd.ngmax
                    if self.collinear:
                        psit_G = np.zeros(ng, complex)
                    else:
                        psit_G = np.zeros((2, ng), complex)
                    psit_G[..., :tmp_G.shape[-1]] = tmp_G

            if rank == 0:
                return psit_G

            # The domain master sends this to the global master.
            if self.gd.comm.rank == 0:
                self.world.ssend(psit_G, 0, 1398)

        if rank == 0:
            # Allocate the full wave function and receive it.
            shape = () if self.collinear else (2,)
            psit_G = self.empty(shape, global_array=True,
                                realspace=realspace)
            # XXX this will fail when using non-standard nesting
            # of communicators.
            world_rank = (kpt_rank * self.gd.comm.size *
                          self.bd.comm.size +
                          band_rank * self.gd.comm.size)
            self.world.receive(psit_G, world_rank, 1398)
            return psit_G

        # We return a number instead of None on all the slaves.  Most of
        # the time the return value will be ignored on the slaves, but
        # in some cases it will be multiplied by some other number and
        # then ignored.  Allowing for this will simplify some code here
        # and there.
        return np.nan
995
    def write(self, writer, write_wave_functions=False):
        """Write wave-function related data with *writer*.

        ``FDPWWaveFunctions.write`` is always called.  When
        *write_wave_functions* is true, two arrays are added as well:

        coefficients:
            The result of ``get_wave_function_array(n, k, s,
            realspace=False, cut=False)`` for every spin, k-point and
            band, multiplied by ``Bohr**-1.5``.
        indices:
            One row of length ``pd.ngmax`` per k-point (``kd.nibzkpts``
            rows) holding the ``pd.Q_qG`` indices, padded with -1.
        """
        FDPWWaveFunctions.write(self, writer)

        if not write_wave_functions:
            return

        # The non-collinear layout has no spin axis but a spinor axis of
        # length 2.
        if self.collinear:
            shape = (self.nspins,
                     self.kd.nibzkpts, self.bd.nbands, self.pd.ngmax)
        else:
            shape = (self.kd.nibzkpts, self.bd.nbands, 2, self.pd.ngmax)

        writer.add_array('coefficients', shape, complex)

        c = Bohr**-1.5
        for s in range(self.nspins):
            for k in range(self.kd.nibzkpts):
                for n in range(self.bd.nbands):
                    psit_G = self.get_wave_function_array(n, k, s,
                                                          realspace=False,
                                                          cut=False)
                    writer.fill(psit_G * c)

        writer.add_array('indices', (self.kd.nibzkpts, self.pd.ngmax),
                         np.int32)

        # Only band-communicator rank 0 takes part in writing the indices.
        if self.bd.comm.rank > 0:
            return

        Q_G = np.empty(self.pd.ngmax, np.int32)
        kk = 0  # counts the k-points written so far
        for r in range(self.kd.comm.size):
            for q, k in enumerate(self.kd.get_indices(r)):
                ng = self.ng_k[k]
                if r == self.kd.comm.rank:
                    # This rank holds the indices; ship them to rank 0 of
                    # the k-point communicator unless we are rank 0.
                    Q_G[:ng] = self.pd.Q_qG[q]
                    if r > 0:
                        self.kd.comm.send(Q_G, 0)
                if self.kd.comm.rank == 0:
                    if r > 0:
                        self.kd.comm.receive(Q_G, r)
                    Q_G[ng:] = -1
                    writer.fill(Q_G)
                    # k-points must arrive in consecutive order.
                    assert k == kk
                    kk += 1
1041
    def read(self, reader):
        """Read wave functions from *reader*.

        ``FDPWWaveFunctions.read`` is always called.  If the file has a
        'coefficients' entry, the stored 'indices' are checked against
        ``pd.Q_qG`` and every k-point gets a
        ``PlaneWaveExpansionWaveFunctions`` object backed by a proxy for
        its part of the 'coefficients' array.  With more than one rank in
        ``self.world`` the data is read into memory right away.
        """
        FDPWWaveFunctions.read(self, reader)

        if 'coefficients' not in reader.wave_functions:
            return

        # Check that the plane-wave indices in the file match the current
        # descriptor (once per k-point: only for spin 0).
        Q_kG = reader.wave_functions.indices
        for kpt in self.kpt_u:
            if kpt.s == 0:
                Q_G = Q_kG[kpt.k]
                ng = self.ng_k[kpt.k]
                assert (Q_G[:ng] == self.pd.Q_qG[kpt.q]).all()
                assert (Q_G[ng:] == -1).all()

        # Scale factor applied to the stored coefficients.
        c = reader.bohr**1.5
        if reader.version < 0:
            c = 1  # old gpw file
        for kpt in self.kpt_u:
            ng = self.ng_k[kpt.k]
            index = (kpt.s, kpt.k) if self.collinear else (kpt.k,)
            psit_nG = reader.wave_functions.proxy('coefficients', *index)
            psit_nG.scale = c
            # Only the first ng entries of the last axis are used; the
            # file rows have length pd.ngmax (see the shape used in write).
            psit_nG.length_of_last_dimension = ng

            kpt.psit = PlaneWaveExpansionWaveFunctions(
                self.bd.nbands, self.pd, self.dtype, psit_nG,
                kpt=kpt.q, dist=(self.bd.comm, self.bd.comm.size),
                spin=kpt.s, collinear=self.collinear)

        if self.world.size > 1:
            # Read to memory:
            for kpt in self.kpt_u:
                kpt.psit.read_from_file()
1075
    def hs(self, ham, q=-1, s=0, md=None):
        """Build Hamiltonian and overlap matrices in the plane-wave basis.

        ham:
            Object providing ``vt_sG`` and ``dH_asp``.
        q: int
            k-point index.
        s: int
            Spin index.
        md:
            Matrix descriptor or None.  With None, full ``npw x npw``
            matrices are built.  Otherwise the matrices come from
            ``md.zeros`` and only the row range reported by
            ``md.my_blocks`` is filled.

        Returns the tuple ``(H_GG, S_GG)``.
        """
        npw = len(self.pd.Q_qG[q])
        N = self.pd.tmp_R.size

        if md is None:
            H_GG = np.zeros((npw, npw), complex)
            S_GG = np.zeros((npw, npw), complex)
            G1 = 0
            G2 = npw
        else:
            H_GG = md.zeros(dtype=complex)
            S_GG = md.zeros(dtype=complex)
            if S_GG.size == 0:
                # Nothing stored on this rank.
                return H_GG, S_GG
            G1, G2 = next(md.my_blocks(S_GG))[:2]

        # Kinetic term: a stride of npw + 1 starting at G1 addresses the
        # elements (i, G1 + i), i.e. the global diagonal inside the local
        # row block.
        H_GG.ravel()[G1::npw + 1] = (0.5 * self.pd.gd.dv / N *
                                     self.pd.G2_qG[q][G1:G2])
        # Potential term: apply vt to one unit vector at a time to get
        # one matrix row per plane wave.
        for G in range(G1, G2):
            x_G = self.pd.zeros(q=q)
            x_G[G] = 1.0
            H_GG[G - G1] += (self.pd.gd.dv / N *
                             self.pd.fft(ham.vt_sG[s] *
                                         self.pd.ifft(x_G, q), q))

        S_GG.ravel()[G1::npw + 1] = self.pd.gd.dv / N

        # Atomic corrections: block-diagonal matrices sandwiched between
        # the expanded functions from self.pt.
        f_GI = self.pt.expand(q)
        nI = f_GI.shape[1]
        dH_II = np.zeros((nI, nI))
        dS_II = np.zeros((nI, nI))
        I1 = 0
        for a in self.pt.my_atom_indices:
            dH_ii = unpack(ham.dH_asp[a][s])
            dS_ii = self.setups[a].dO_ii
            I2 = I1 + len(dS_ii)
            dH_II[I1:I2, I1:I2] = dH_ii / N**2
            dS_II[I1:I2, I1:I2] = dS_ii / N**2
            I1 = I2

        H_GG += np.dot(f_GI[G1:G2].conj(), np.dot(dH_II, f_GI.T))
        S_GG += np.dot(f_GI[G1:G2].conj(), np.dot(dS_II, f_GI.T))

        return H_GG, S_GG
1120
    @timer('Full diag')
    def diagonalize_full_hamiltonian(self, ham, atoms, log,
                                     nbands=None, ecut=None, scalapack=None,
                                     expert=False):
        """Diagonalize the full plane-wave Hamiltonian for all local k-points.

        ham:
            Hamiltonian object passed on to ``self.hs``.
        atoms:
            Object whose ``get_scaled_positions()`` is used to set the
            positions.
        log:
            Callable used for text output; ``log.fd`` is given to the
            progress bar.
        nbands: int or None
            Number of bands to keep.  If None it is derived from
            ``pd.ngmin`` (when *ecut* is also None) or from *ecut*.  The
            value is rounded up to a multiple of the band-communicator
            size.
        ecut: float or None
            Cutoff used to choose *nbands*; divided by ``Ha`` before use.
        scalapack: list, tuple or other
            ``(nprow, npcol, blocksize)`` for the ScaLapack grid.  Any
            other value selects a grid automatically with block-size 64.
            Only used when the band communicator has more than one rank.
        expert: bool
            If true, ``iu=nbands`` is passed to the serial diagonalizer.

        Raises ValueError if ``self.dtype`` is not complex or if the
        domain communicator has more than one rank.

        Returns the number of bands actually used.
        """

        if self.dtype != complex:
            raise ValueError(
                'Please use mode=PW(..., force_complex_dtype=True)')

        if self.gd.comm.size > 1:
            raise ValueError(
                "Please use parallel={'domain': 1}")

        S = self.bd.comm.size

        if nbands is None and ecut is None:
            # Largest multiple of S not exceeding pd.ngmin.
            nbands = self.pd.ngmin // S * S
        elif nbands is None:
            ecut /= Ha
            vol = abs(np.linalg.det(self.gd.cell_cv))
            # NOTE(review): looks like the free-electron estimate of the
            # number of plane waves below ecut -- confirm.
            nbands = int(vol * ecut**1.5 * 2**0.5 / 3 / pi**2)

        # Round up to a multiple of the band-communicator size.
        if nbands % S != 0:
            nbands += S - nbands % S

        assert nbands <= self.pd.ngmin

        if expert:
            iu = nbands
        else:
            iu = None

        self.bd = bd = BandDescriptor(nbands, self.bd.comm)
        self.occupations.bd = bd

        log('Diagonalizing full Hamiltonian ({} lowest bands)'.format(nbands))
        log('Matrix size (min, max): {}, {}'.format(self.pd.ngmin,
                                                    self.pd.ngmax))
        # Three ngmax x ngmax matrices at 16 bytes per element, in MB.
        mem = 3 * self.pd.ngmax**2 * 16 / S / 1024**2
        log('Approximate memory used per core to store H_GG, S_GG: {:.3f} MB'
            .format(mem))
        log('Notice: Up to twice the amount of memory might be allocated\n'
            'during diagonalization algorithm.')
        log('The least memory is required when the parallelization is purely\n'
            'over states (bands) and not k-points, set '
            "GPAW(..., parallel={'kpt': 1}, ...).")

        if S > 1:
            if isinstance(scalapack, (list, tuple)):
                nprow, npcol, b = scalapack
                assert nprow * npcol == S, (nprow, npcol, S)
            else:
                # Largest divisor of S not above round(sqrt(S)) gives the
                # number of rows.
                nprow = int(round(S**0.5))
                while S % nprow != 0:
                    nprow -= 1
                npcol = S // nprow
                b = 64
            log('ScaLapack grid: {}x{},'.format(nprow, npcol),
                'block-size:', b)
            # bg: S x 1 grid used for building H and S and for the final
            # wave functions; bg2: nprow x npcol grid used for the
            # diagonalization.
            bg = BlacsGrid(bd.comm, S, 1)
            bg2 = BlacsGrid(bd.comm, nprow, npcol)
            scalapack = True
        else:
            scalapack = False

        self.set_positions(atoms.get_scaled_positions())
        self.kpt_u[0].projections = None
        self.allocate_arrays_for_projections(self.pt.my_atom_indices)

        myslice = bd.get_slice()

        pb = ProgressBar(log.fd)
        nkpt = len(self.kpt_u)

        for u, kpt in enumerate(self.kpt_u):
            pb.update(u / nkpt)
            npw = len(self.pd.Q_qG[kpt.q])
            if scalapack:
                # Rows per rank, rounded up.
                mynpw = -(-npw // S)
                md = BlacsDescriptor(bg, npw, npw, mynpw, npw)
                md2 = BlacsDescriptor(bg2, npw, npw, b, b)
            else:
                md = md2 = MatrixDescriptor(npw, npw)

            with self.timer('Build H and S'):
                H_GG, S_GG = self.hs(ham, kpt.q, kpt.s, md)

            if scalapack:
                # Move from the row layout to the block layout.
                r = Redistributor(bd.comm, md, md2)
                H_GG = r.redistribute(H_GG)
                S_GG = r.redistribute(S_GG)

            psit_nG = md2.empty(dtype=complex)
            eps_n = np.empty(npw)

            with self.timer('Diagonalize'):
                if not scalapack:
                    md2.general_diagonalize_dc(H_GG, S_GG, psit_nG, eps_n,
                                               iu=iu)
                else:
                    md2.general_diagonalize_dc(H_GG, S_GG, psit_nG, eps_n)
            # Free the matrices before the eigenvectors are redistributed.
            del H_GG, S_GG

            kpt.eps_n = eps_n[myslice].copy()

            if scalapack:
                # Back to a row layout with bd.maxmynbands rows per rank.
                md3 = BlacsDescriptor(bg, npw, npw, bd.maxmynbands, npw)
                r = Redistributor(bd.comm, md2, md3)
                psit_nG = r.redistribute(psit_nG)

            kpt.psit = PlaneWaveExpansionWaveFunctions(
                self.bd.nbands, self.pd, self.dtype,
                psit_nG[:bd.mynbands].copy(),
                kpt=kpt.q, dist=(self.bd.comm, self.bd.comm.size),
                spin=kpt.s, collinear=self.collinear)
            del psit_nG

            with self.timer('Projections'):
                self.pt.integrate(kpt.psit_nG, kpt.P_ani, kpt.q)

            kpt.f_n = None

        pb.finish()

        self.calculate_occupation_numbers()

        return nbands
1248
1249    def initialize_from_lcao_coefficients(self, basis_functions):
1250        psit_nR = self.gd.empty(1, self.dtype)
1251
1252        for kpt in self.kpt_u:
1253            if self.kd.gamma:
1254                emikr_R = 1.0
1255            else:
1256                k_c = self.kd.ibzk_kc[kpt.k]
1257                emikr_R = self.gd.plane_wave(-k_c)
1258            kpt.psit = PlaneWaveExpansionWaveFunctions(
1259                self.bd.nbands, self.pd, self.dtype, kpt=kpt.q,
1260                dist=(self.bd.comm, -1, 1),
1261                spin=kpt.s, collinear=self.collinear)
1262            psit_nG = kpt.psit.array
1263            if psit_nG.ndim == 3:
1264                N, S, G = psit_nG.shape
1265                psit_nG = psit_nG.reshape((N * S, G))
1266            for n, psit_G in enumerate(psit_nG):
1267                psit_nR[:] = 0.0
1268                basis_functions.lcao_to_grid(kpt.C_nM[n:n + 1],
1269                                             psit_nR, kpt.q)
1270                psit_G[:] = self.pd.fft(psit_nR[0] * emikr_R, kpt.q)
1271            kpt.C_nM = None
1272
1273    def random_wave_functions(self, mynao):
1274        rs = np.random.RandomState(self.world.rank)
1275        for kpt in self.kpt_u:
1276            if kpt.psit is None:
1277                kpt.psit = PlaneWaveExpansionWaveFunctions(
1278                    self.bd.nbands, self.pd, self.dtype, kpt=kpt.q,
1279                    dist=(self.bd.comm, -1, 1),
1280                    spin=kpt.s, collinear=self.collinear)
1281
1282            array = kpt.psit.array[mynao:]
1283            weight_G = 1.0 / (1.0 + self.pd.G2_qG[kpt.q])
1284            array.real = rs.uniform(-1, 1, array.shape) * weight_G
1285            array.imag = rs.uniform(-1, 1, array.shape) * weight_G
1286
    def estimate_memory(self, mem):
        """Add memory estimates to *mem*.

        Calls ``FDPWWaveFunctions.estimate_memory`` and then lets the
        plane-wave descriptor report into the 'PW-descriptor' sub-node.
        """
        FDPWWaveFunctions.estimate_memory(self, mem)
        self.pd.estimate_memory(mem.subnode('PW-descriptor'))
1290
1291    def get_kinetic_stress(self):
1292        sigma_vv = np.zeros((3, 3), dtype=complex)
1293        pd = self.pd
1294        dOmega = pd.gd.dv / pd.gd.N_c.prod()
1295        if pd.dtype == float:
1296            dOmega *= 2
1297        K_qv = self.pd.K_qv
1298        for kpt in self.kpt_u:
1299            G_Gv = pd.get_reciprocal_vectors(q=kpt.q, add_q=False)
1300            psit2_G = 0.0
1301            for n, f in enumerate(kpt.f_n):
1302                psit2_G += f * np.abs(kpt.psit_nG[n])**2
1303            for alpha in range(3):
1304                Ga_G = G_Gv[:, alpha] + K_qv[kpt.q, alpha]
1305                for beta in range(3):
1306                    Gb_G = G_Gv[:, beta] + K_qv[kpt.q, beta]
1307                    sigma_vv[alpha, beta] += (psit2_G * Ga_G * Gb_G).sum()
1308
1309        sigma_vv *= -dOmega
1310        self.world.sum(sigma_vv)
1311        return sigma_vv
1312
1313
def ft(spline):
    """Return the transform of *spline* as a new ``Spline``.

    The spline is sampled on 2**10 equidistant points below 50.0 (its
    cutoff must not exceed that), multiplied by 4 pi, transformed with
    ``fbt`` onto 2 * 2**10 k-points and divided by ``k**(2 * l + 1)``.
    The k = 0 value, which is excluded from that division, is set from an
    explicit sum over the sampled values.
    """
    l = spline.get_angular_momentum_number()
    rc = 50.0
    N = 2**10
    assert spline.get_cutoff() <= rc

    # Radial grid in real space.
    dr = rc / N
    r_r = np.arange(N) * dr
    f_r = spline.map(r_r) * (4 * pi)

    # Radial grid in reciprocal space (twice as many points).
    dk = pi / 2 / rc
    k_q = np.arange(2 * N) * dk

    f_q = fbt(l, f_r, r_r, k_q)
    f_q[1:] /= k_q[1:]**(2 * l + 1)
    # Value at k = 0, where the division above cannot be used.
    moment = np.dot(f_r, r_r**(2 + 2 * l))
    f_q[0] = (moment *
              dr * 2**l * fac(l) / fac(2 * l + 1))

    return Spline(l, k_q[-1], f_q)
1332
1333
1334class PWLFC(BaseLFC):
1335    def __init__(self, spline_aj, pd, blocksize=5000, comm=None):
1336        """Reciprocal-space plane-wave localized function collection.
1337
1338        spline_aj: list of list of spline objects
1339            Splines.
1340        pd: PWDescriptor
1341            Plane-wave descriptor object.
1342        blocksize: int
1343            Block-size to use when looping over G-vectors.  Use None for
1344            doing all G-vectors in one big block.
1345        comm: communicator
1346            Communicator for operations that support parallelization
1347            over planewaves (only integrate so far)."""
1348
1349        self.pd = pd
1350        self.spline_aj = spline_aj
1351
1352        self.dtype = pd.dtype
1353
1354        self.initialized = False
1355
1356        # These will be filled in later:
1357        self.Y_qGL = []
1358        self.emiGR_qGa = []
1359        self.f_qGs = []
1360        self.l_s = None
1361        self.a_J = None
1362        self.s_J = None
1363        self.lmax = None
1364
1365        if blocksize is not None:
1366            if pd.ngmax <= blocksize:
1367                # No need to block G-vectors
1368                blocksize = None
1369        self.blocksize = blocksize
1370
1371        # These are set later in set_potitions():
1372        self.eikR_qa = None
1373        self.my_atom_indices = None
1374        self.my_indices = None
1375        self.pos_av = None
1376        self.nI = None
1377
1378        if comm is None:
1379            comm = pd.gd.comm
1380        else:
1381            assert False
1382        self.comm = comm
1383
1384    def initialize(self):
1385        """Initialize position-independent stuff."""
1386        if self.initialized:
1387            return
1388
1389        splines = {}  # Dict[Spline, int]
1390        for spline_j in self.spline_aj:
1391            for spline in spline_j:
1392                if spline not in splines:
1393                    splines[spline] = len(splines)
1394        nsplines = len(splines)
1395
1396        nJ = sum(len(spline_j) for spline_j in self.spline_aj)
1397
1398        self.f_qGs = [np.empty((mynG, nsplines)) for mynG in self.pd.myng_q]
1399        self.l_s = np.empty(nsplines, np.int32)
1400        self.a_J = np.empty(nJ, np.int32)
1401        self.s_J = np.empty(nJ, np.int32)
1402
1403        # Fourier transform radial functions:
1404        J = 0
1405        done = set()  # Set[Spline]
1406        for a, spline_j in enumerate(self.spline_aj):
1407            for spline in spline_j:
1408                s = splines[spline]  # get spline index
1409                if spline not in done:
1410                    f = ft(spline)
1411                    for f_Gs, G2_G in zip(self.f_qGs, self.pd.G2_qG):
1412                        G_G = G2_G**0.5
1413                        f_Gs[:, s] = f.map(G_G)
1414                    self.l_s[s] = spline.get_angular_momentum_number()
1415                    done.add(spline)
1416                self.a_J[J] = a
1417                self.s_J[J] = s
1418                J += 1
1419
1420        self.lmax = max(self.l_s, default=-1)
1421
1422        # Spherical harmonics:
1423        for q, K_v in enumerate(self.pd.K_qv):
1424            G_Gv = self.pd.get_reciprocal_vectors(q=q)
1425            Y_GL = np.empty((len(G_Gv), (self.lmax + 1)**2))
1426            for L in range((self.lmax + 1)**2):
1427                Y_GL[:, L] = Y(L, *G_Gv.T)
1428            self.Y_qGL.append(Y_GL)
1429
1430        self.initialized = True
1431
1432    def estimate_memory(self, mem):
1433        splines = set()
1434        lmax = -1
1435        for spline_j in self.spline_aj:
1436            for spline in spline_j:
1437                splines.add(spline)
1438                l = spline.get_angular_momentum_number()
1439                lmax = max(lmax, l)
1440        nbytes = ((len(splines) + (lmax + 1)**2) *
1441                  sum(G2_G.nbytes for G2_G in self.pd.G2_qG))
1442        mem.subnode('Arrays', nbytes)
1443
1444    def get_function_count(self, a):
1445        return sum(2 * spline.get_angular_momentum_number() + 1
1446                   for spline in self.spline_aj[a])
1447
1448    def set_positions(self, spos_ac, atom_partition=None):
1449        self.initialize()
1450        kd = self.pd.kd
1451        if kd is None or kd.gamma:
1452            self.eikR_qa = np.ones((1, len(spos_ac)))
1453        else:
1454            self.eikR_qa = np.exp(2j * pi * np.dot(kd.ibzk_qc, spos_ac.T))
1455
1456        self.pos_av = np.dot(spos_ac, self.pd.gd.cell_cv)
1457
1458        del self.emiGR_qGa[:]
1459        G_Qv = self.pd.G_Qv
1460        for Q_G in self.pd.myQ_qG:
1461            GR_Ga = np.dot(G_Qv[Q_G], self.pos_av.T)
1462            self.emiGR_qGa.append(np.exp(-1j * GR_Ga))
1463
1464        if atom_partition is None:
1465            assert self.comm.size == 1
1466            rank_a = np.zeros(len(spos_ac), int)
1467        else:
1468            rank_a = atom_partition.rank_a
1469
1470        self.my_atom_indices = []
1471        self.my_indices = []
1472        I1 = 0
1473        for a, rank in enumerate(rank_a):
1474            I2 = I1 + self.get_function_count(a)
1475            if rank == self.comm.rank:
1476                self.my_atom_indices.append(a)
1477                self.my_indices.append((a, I1, I2))
1478            I1 = I2
1479        self.nI = I1
1480
    def expand(self, q=-1, G1=0, G2=None, cc=False):
        """Expand functions in plane-waves.

        q: int
            k-point index.
        G1: int
            Start G-vector index.
        G2: int
            End G-vector index.  None means up to the last G-vector of
            k-point *q*.
        cc: bool
            Complex conjugate.

        Returns an array with one column per function (``self.nI``
        columns).  For a complex plane-wave descriptor it has ``G2 - G1``
        complex rows; otherwise it is a real array with ``2 * (G2 - G1)``
        rows in the interleaved layout described below.
        """
        if G2 is None:
            G2 = self.Y_qGL[q].shape[0]

        emiGR_Ga = self.emiGR_qGa[q][G1:G2]
        f_Gs = self.f_qGs[q][G1:G2]
        Y_GL = self.Y_qGL[q][G1:G2]

        if self.pd.dtype == complex:
            f_GI = np.empty((G2 - G1, self.nI), complex)
        else:
            # Special layout because BLAS does not have real-complex
            # multiplications.  f_GI(G,I) layout:
            #
            #    real(G1, 0),   real(G1, 1),   ...
            #    imag(G1, 0),   imag(G1, 1),   ...
            #    real(G1+1, 0), real(G1+1, 1), ...
            #    imag(G1+1, 0), imag(G1+1, 1), ...
            #    ...

            f_GI = np.empty((2 * (G2 - G1), self.nI))

        if True:
            # Fast C-code:
            _gpaw.pwlfc_expand(f_Gs, emiGR_Ga, Y_GL,
                               self.l_s, self.a_J, self.s_J,
                               cc, f_GI)
            return f_GI

        # Equivalent slow Python code (never reached because of the
        # unconditional return above; kept as a reference):
        f_GI = np.empty((G2 - G1, self.nI), complex)
        I1 = 0
        for J, (a, s) in enumerate(zip(self.a_J, self.s_J)):
            l = self.l_s[s]
            I2 = I1 + 2 * l + 1
            f_GI[:, I1:I2] = (f_Gs[:, s] *
                              emiGR_Ga[:, a] *
                              Y_GL[:, l**2:(l + 1)**2].T *
                              (-1.0j)**l).T
            I1 = I2
        if cc:
            f_GI = f_GI.conj()
        if self.pd.dtype == float:
            # Reinterpret as the interleaved real/imag row layout.
            f_GI = f_GI.T.copy().view(float).T.copy()

        return f_GI
1538
1539    def block(self, q=-1, ensure_same_number_of_blocks=False):
1540        nG = self.Y_qGL[q].shape[0]
1541        B = self.blocksize
1542        if B:
1543            G1 = 0
1544            while G1 < nG:
1545                G2 = min(G1 + B, nG)
1546                yield G1, G2
1547                G1 = G2
1548            if ensure_same_number_of_blocks:
1549                # Make sure we yield the same number of times:
1550                nb = (self.pd.maxmyng + B - 1) // B
1551                mynb = (nG + B - 1) // B
1552                if mynb < nb:
1553                    yield nG, nG  # empty block
1554        else:
1555            yield 0, nG
1556
1557    def add(self, a_xG, c_axi=1.0, q=-1, f0_IG=None):
1558        c_xI = np.empty(a_xG.shape[:-1] + (self.nI,), self.pd.dtype)
1559
1560        if isinstance(c_axi, float):
1561            assert q == -1 and a_xG.ndim == 1
1562            c_xI[:] = c_axi
1563        else:
1564            assert q != -1 or self.pd.only_one_k_point
1565            if self.comm.size != 1:
1566                c_xI[:] = 0.0
1567            for a, I1, I2 in self.my_indices:
1568                c_xI[..., I1:I2] = c_axi[a] * self.eikR_qa[q][a].conj()
1569            if self.comm.size != 1:
1570                self.comm.sum(c_xI)
1571
1572        nx = np.prod(c_xI.shape[:-1], dtype=int)
1573        c_xI = c_xI.reshape((nx, self.nI))
1574        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.pd.dtype)
1575
1576        for G1, G2 in self.block(q):
1577            if f0_IG is None:
1578                f_GI = self.expand(q, G1, G2, cc=False)
1579            else:
1580                1 / 0
1581                # f_IG = f0_IG
1582
1583            if self.pd.dtype == float:
1584                # f_IG = f_IG.view(float)
1585                G1 *= 2
1586                G2 *= 2
1587
1588            mmm(1.0 / self.pd.gd.dv, c_xI, 'N', f_GI, 'T',
1589                1.0, a_xG[:, G1:G2])
1590
    def integrate(self, a_xG, c_axi=None, q=-1):
        """Multiply *a_xG* with the expanded functions and store per atom.

        a_xG: ndarray
            Plane-wave coefficients; the last axis is the G-vector axis.
        c_axi: mapping or None
            Output arrays indexed by atom.  If None, it is created with
            ``self.dict``.
        q: int
            k-point index.

        The functions are expanded block by block (complex-conjugated for
        a complex plane-wave descriptor) and multiplied with *a_xG* using
        ``mmm`` and the prefactor ``1 / N_c.prod()`` (twice that for a
        float descriptor).  The result is summed over ``self.comm``,
        multiplied by ``eikR_qa[q][a]`` and written into ``c_axi[a]`` for
        the atoms in ``self.my_indices``.

        Returns *c_axi*.
        """
        c_xI = np.zeros(a_xG.shape[:-1] + (self.nI,), self.pd.dtype)

        # b_xI is a 2-d view of c_xI, so writing to it fills c_xI.
        nx = np.prod(c_xI.shape[:-1], dtype=int)
        b_xI = c_xI.reshape((nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1]))

        alpha = 1.0 / self.pd.gd.N_c.prod()
        if self.pd.dtype == float:
            alpha *= 2
            a_xG = a_xG.view(float)

        if c_axi is None:
            c_axi = self.dict(a_xG.shape[:-1])

        # x is the factor given to mmm for the existing content of b_xI:
        # 0.0 for the first block, 1.0 for the following ones.
        x = 0.0
        for G1, G2 in self.block(q):
            f_GI = self.expand(q, G1, G2, cc=self.pd.dtype == complex)
            if self.pd.dtype == float:
                if G1 == 0 and self.comm.rank == 0:
                    # NOTE(review): presumably compensates for the factor
                    # 2 in alpha for the first (G = 0) row -- confirm.
                    f_GI[0] *= 0.5
                # Indices into the real view of a_xG.
                G1 *= 2
                G2 *= 2
            mmm(alpha, a_xG[:, G1:G2], 'N', f_GI, 'N', x, b_xI)
            x = 1.0

        self.comm.sum(b_xI)
        for a, I1, I2 in self.my_indices:
            c_axi[a][:] = self.eikR_qa[q][a] * c_xI[..., I1:I2]

        return c_axi
1622
    def matrix_elements(self, psit, out):
        """Call ``integrate`` for *psit* with the transposed arrays of *out*.

        psit:
            Object with ``array`` and ``kpt`` attributes.
        out:
            Mapping from atom index to array; ``integrate`` receives the
            ``.T`` of every array.
        """
        P_ani = {a: P_in.T for a, P_in in out.items()}
        self.integrate(psit.array, P_ani, psit.kpt)
1626
    def derivative(self, a_xG, c_axiv=None, q=-1):
        """Like ``integrate`` but with an extra factor per direction.

        a_xG: ndarray
            Plane-wave coefficients; the last axis is the G-vector axis.
        c_axiv: mapping or None
            Output arrays indexed by atom, with a trailing axis of length
            3 for the direction.  If None, it is created with
            ``self.dict(..., derivative=True)``.
        q: int
            k-point index.

        For each direction ``v`` the expanded (conjugated) functions are
        multiplied by the ``v`` component of the G-vectors (plus ``K_v``
        for a complex descriptor) before being multiplied with *a_xG*.

        Returns *c_axiv*.
        """
        c_vxI = np.zeros((3,) + a_xG.shape[:-1] + (self.nI,), self.pd.dtype)
        # b_vxI is a view of c_vxI with all middle axes flattened.
        nx = np.prod(c_vxI.shape[1:-1], dtype=int)
        b_vxI = c_vxI.reshape((3, nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.pd.dtype)

        alpha = 1.0 / self.pd.gd.N_c.prod()

        if c_axiv is None:
            c_axiv = self.dict(a_xG.shape[:-1], derivative=True)

        K_v = self.pd.K_qv[q]

        # x is the factor given to mmm for the existing content of the
        # output: 0.0 for the first block, 1.0 afterwards.
        x = 0.0
        for G1, G2 in self.block(q):
            f_GI = self.expand(q, G1, G2, cc=True)
            G_Gv = self.pd.G_Qv[self.pd.myQ_qG[q][G1:G2]]
            if self.pd.dtype == float:
                # Rows alternate between real and imaginary parts (see
                # expand); the two are swapped here while multiplying by
                # the G-vector component.
                d_GI = np.empty_like(f_GI)
                for v in range(3):
                    d_GI[::2] = f_GI[1::2] * G_Gv[:, v, np.newaxis]
                    d_GI[1::2] = f_GI[::2] * G_Gv[:, v, np.newaxis]
                    mmm(2 * alpha,
                        a_xG[:, 2 * G1:2 * G2], 'N',
                        d_GI, 'N',
                        x, b_vxI[v])
            else:
                for v in range(3):
                    mmm(-alpha,
                        a_xG[:, G1:G2], 'N',
                        f_GI * (G_Gv[:, v] + K_v[v])[:, np.newaxis], 'N',
                        x, b_vxI[v])
            x = 1.0

        self.comm.sum(c_vxI)

        for v in range(3):
            if self.pd.dtype == float:
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = c_vxI[v, ..., I1:I2]
            else:
                # The complex case still needs the factor 1j and the
                # per-atom factor from eikR_qa.
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = (1.0j * self.eikR_qa[q][a] *
                                         c_vxI[v, ..., I1:I2])

        return c_axiv
1673
1674    def stress_tensor_contribution(self, a_xG, c_axi=1.0, q=-1):
1675        cache = {}
1676        things = []
1677        I1 = 0
1678        lmax = 0
1679        for a, spline_j in enumerate(self.spline_aj):
1680            for spline in spline_j:
1681                if spline not in cache:
1682                    s = ft(spline)
1683                    G_G = self.pd.G2_qG[q]**0.5
1684                    f_G = []
1685                    dfdGoG_G = []
1686                    for G in G_G:
1687                        f, dfdG = s.get_value_and_derivative(G)
1688                        if G < 1e-10:
1689                            G = 1.0
1690                        f_G.append(f)
1691                        dfdGoG_G.append(dfdG / G)
1692                    f_G = np.array(f_G)
1693                    dfdGoG_G = np.array(dfdGoG_G)
1694                    cache[spline] = (f_G, dfdGoG_G)
1695                else:
1696                    f_G, dfdGoG_G = cache[spline]
1697                l = spline.l
1698                lmax = max(l, lmax)
1699                I2 = I1 + 2 * l + 1
1700                things.append((a, l, I1, I2, f_G, dfdGoG_G))
1701                I1 = I2
1702
1703        if isinstance(c_axi, float):
1704            c_axi = dict((a, c_axi) for a in range(len(self.pos_av)))
1705
1706        G0_Gv = self.pd.get_reciprocal_vectors(q=q)
1707
1708        stress_vv = np.zeros((3, 3))
1709        for G1, G2 in self.block(q, ensure_same_number_of_blocks=True):
1710            G_Gv = G0_Gv[G1:G2]
1711            Z_LvG = np.array([nablarlYL(L, G_Gv.T)
1712                              for L in range((lmax + 1)**2)])
1713            aa_xG = a_xG[..., G1:G2]
1714            for v1 in range(3):
1715                for v2 in range(3):
1716                    stress_vv[v1, v2] += self._stress_tensor_contribution(
1717                        v1, v2, things, G1, G2, G_Gv, aa_xG, c_axi, q, Z_LvG)
1718
1719        self.comm.sum(stress_vv)
1720
1721        return stress_vv
1722
1723    def _stress_tensor_contribution(self, v1, v2, things, G1, G2,
1724                                    G_Gv, a_xG, c_axi, q, Z_LvG):
1725        f_IG = np.empty((self.nI, G2 - G1), complex)
1726        emiGR_Ga = self.emiGR_qGa[q][G1:G2]
1727        Y_LG = self.Y_qGL[q].T
1728        for a, l, I1, I2, f_G, dfdGoG_G in things:
1729            L1 = l**2
1730            L2 = (l + 1)**2
1731            f_IG[I1:I2] = (emiGR_Ga[:, a] * (-1.0j)**l *
1732                           (dfdGoG_G[G1:G2] * G_Gv[:, v1] * G_Gv[:, v2] *
1733                            Y_LG[L1:L2, G1:G2] +
1734                            f_G[G1:G2] * G_Gv[:, v1] * Z_LvG[L1:L2, v2]))
1735
1736        c_xI = np.zeros(a_xG.shape[:-1] + (self.nI,), self.pd.dtype)
1737
1738        x = np.prod(c_xI.shape[:-1], dtype=int)
1739        b_xI = c_xI.reshape((x, self.nI))
1740        a_xG = a_xG.reshape((x, a_xG.shape[-1]))
1741
1742        alpha = 1.0 / self.pd.gd.N_c.prod()
1743        if self.pd.dtype == float:
1744            alpha *= 2
1745            if G1 == 0 and self.pd.gd.comm.rank == 0:
1746                f_IG[:, 0] *= 0.5
1747            f_IG = f_IG.view(float)
1748            a_xG = a_xG.copy().view(float)
1749
1750        mmm(alpha, a_xG, 'N', f_IG, 'C', 0.0, b_xI)
1751        self.comm.sum(b_xI)
1752
1753        stress = 0.0
1754        for a, I1, I2 in self.my_indices:
1755            stress -= self.eikR_qa[q][a] * (c_axi[a] * c_xI[..., I1:I2]).sum()
1756        return stress.real
1757
1758
class PseudoCoreKineticEnergyDensityLFC(PWLFC):
    """PWLFC variant whose add()/derivative() take real-space arrays."""

    def add(self, tauct_R):
        """Add the expanded functions to ``tauct_R`` in place.

        The expansion is summed over axis 1, viewed as complex, scaled by
        ``1 / dv`` and inverse Fourier transformed before being added.
        """
        scale = 1.0 / self.pd.gd.dv
        expansion_G = self.expand().sum(1).view(complex)
        tauct_R += self.pd.ifft(scale * expansion_G)

    def derivative(self, dedtaut_R, dF_aiv):
        """Fourier transform ``dedtaut_R`` and call PWLFC.derivative()."""
        dedtaut_G = self.pd.fft(dedtaut_R)
        PWLFC.derivative(self, dedtaut_G, dF_aiv)
1766
1767
class ReciprocalSpaceDensity(Density):
    """Density that keeps plane-wave representations of its arrays.

    Two plane-wave descriptors are used: ``pd2`` on the coarse grid and
    ``pd3`` on the fine grid, with ``map23`` mapping between them.
    """

    def __init__(self, ecut,
                 gd, finegd, nspins, collinear, charge, redistributor,
                 background_charge=None):
        """Create the descriptors ``pd2``/``pd3`` and the mapping ``map23``.

        ``ecut`` is capped by a limit computed from the coarse grid
        spacing; ``pd3`` uses four times the (capped) cutoff.
        """
        Density.__init__(self, gd, finegd, nspins, collinear, charge,
                         redistributor=redistributor,
                         background_charge=background_charge)

        # Upper limit for the cutoff, from the largest squared spacing:
        h2_max = (gd.h_cv**2).sum(1).max()
        ecut_limit = 0.5 * pi**2 / h2_max
        ecut = min(ecut, ecut_limit)

        self.pd2 = PWDescriptor(ecut, gd)
        self.pd3 = PWDescriptor(4 * ecut, finegd)
        self.map23 = PWMapping(self.pd2, self.pd3)

        # Filled in later by set_positions(), interpolate_pseudo_density()
        # and calculate_pseudo_charge() respectively:
        self.nct_q = None
        self.nt_Q = None
        self.rhot_q = None

    def initialize(self, setups, timer, magmom_av, hund):
        """Initialize the base class and build the ``nct``/``ghat`` PWLFCs."""
        Density.initialize(self, setups, timer, magmom_av, hund)

        # Setups without an nct spline contribute an empty list:
        nct_aj = [[] if setup.nct is None else [setup.nct]
                  for setup in setups]
        self.nct = PWLFC(nct_aj, self.pd2)

        # NOTE: blocksize=256 and comm=self.xc_redistributor.comm are not
        # passed here (they were left commented out in the original code).
        ghat_aj = [setup.ghat_l for setup in setups]
        self.ghat = PWLFC(ghat_aj, self.pd3)

    def set_positions(self, spos_ac, atom_partition):
        """Update positions and recompute ``nct_q`` and ``nct_G``."""
        Density.set_positions(self, spos_ac, atom_partition)
        self.nct_q = self.pd2.zeros()
        weight = 1.0 / self.nspins
        self.nct.add(self.nct_q, weight)
        self.nct_G = self.pd2.ifft(self.nct_q)

    def interpolate_pseudo_density(self, comp_charge=None):
        """Interpolate pseudo density to fine grid."""
        if comp_charge is None:
            comp_charge, _Q_aL = self.calculate_multipole_moments()

        if self.nt_xg is None:
            # First call: allocate the fine-grid arrays and nt_Q.
            self.nt_xg = self.finegd.empty(self.ncomponents)
            self.nt_sg = self.nt_xg[:self.nspins]
            self.nt_vg = self.nt_xg[self.nspins:]
            self.nt_Q = self.pd2.empty()

        self.nt_Q[:] = 0.0

        for x, (nt_G, nt_g) in enumerate(zip(self.nt_xG, self.nt_xg)):
            fine_g, nt_Q = self.pd2.interpolate(nt_G, self.pd3)
            nt_g[:] = fine_g
            # Only the first nspins components are accumulated in nt_Q:
            if x < self.nspins:
                self.nt_Q += nt_Q

    def interpolate(self, in_xR, out_xR=None):
        """Interpolate array(s)."""
        if out_xR is None:
            out_xR = self.finegd.empty(in_xR.shape[:-3])

        # Collapse all leading axes so that we can loop over 3-d arrays:
        coarse_xR = in_xR.reshape((-1,) + in_xR.shape[-3:])
        fine_xR = out_xR.reshape((-1,) + out_xR.shape[-3:])

        for coarse_R, fine_R in zip(coarse_xR, fine_xR):
            result = self.pd2.interpolate(coarse_R, self.pd3)
            fine_R[:] = result[0]

        return out_xR

    distribute_and_interpolate = interpolate

    def calculate_pseudo_charge(self):
        """Assemble ``rhot_q`` on the ``pd3`` plane-wave set."""
        self.rhot_q = self.pd3.zeros()
        Q_aL = self.Q.calculate(self.D_asp)
        self.ghat.add(self.rhot_q, Q_aL)
        self.map23.add_to2(self.rhot_q, self.nt_Q)
        self.background_charge.add_fourier_space_charge_to(self.pd3,
                                                           self.rhot_q)
        # The first coefficient on rank 0 is explicitly zeroed:
        if self.gd.comm.rank == 0:
            self.rhot_q[0] = 0.0

    def get_pseudo_core_kinetic_energy_density_lfc(self):
        """Return a PseudoCoreKineticEnergyDensityLFC for ``setup.tauct``."""
        tauct_aj = [[setup.tauct] for setup in self.setups]
        return PseudoCoreKineticEnergyDensityLFC(tauct_aj, self.pd2)

    def calculate_dipole_moment(self):
        """Return a length-3 vector computed from ``rhot_q``.

        The value is evaluated on rank 0 of ``pd3.comm`` from the
        imaginary part of the gathered coefficients and then broadcast.
        """
        pd = self.pd3
        N_c = pd.tmp_Q.shape

        # Masks selecting coefficients with index 0 along each axis:
        i0_q, i1_q, i2_q = np.unravel_index(pd.Q_qG[0], N_c)
        m0_q = i0_q == 0
        m1_q = i1_q == 0
        m2_q = i2_q == 0

        rhot_q = self.pd3.gather(self.rhot_q)
        if pd.comm.rank != 0:
            d_v = np.empty(3)
        else:
            irhot_q = rhot_q.imag
            d_c = []
            # For axis c, keep coefficients whose other two indices are 0:
            for mask_q in (m1_q & m2_q, m0_q & m2_q, m0_q & m1_q):
                rhot_s = irhot_q[mask_q]
                weight_s = 1.0 / np.arange(1, len(rhot_s))
                d_c.append(np.dot(rhot_s[1:], weight_s))
            d_v = -np.dot(d_c, pd.gd.cell_cv) / pi * pd.gd.dv
        pd.comm.broadcast(d_v, 0)
        return d_v
1875
1876
class ReciprocalSpacePoissonSolver:
    """Poisson solver acting directly on plane-wave coefficients."""

    def __init__(self, pd, realpbc_c):
        """Store the descriptor and a private copy of its ``G2_qG[0]``.

        pd:
            Plane-wave descriptor providing ``G2_qG`` and ``gd.comm``.
        realpbc_c:
            Stored unchanged as ``self.realpbc_c``.
        """
        self.pd = pd
        self.realpbc_c = realpbc_c
        # Work on a copy: the descriptor's own array must keep its real
        # value for the first element, since other users of pd read it too.
        self.G2_q = pd.G2_qG[0].copy()
        if pd.gd.comm.rank == 0:
            # Avoid division by zero:
            self.G2_q[0] = 1.0

    def initialize(self):
        """Nothing to initialize."""
        pass

    def get_stencil(self):
        """Return a placeholder string; no stencil is used here."""
        return '????'

    def estimate_memory(self, mem):
        """Nothing is reported to ``mem``."""
        pass

    def todict(self):
        """Return an empty dictionary (no parameters to store)."""
        return {}

    def solve(self, vHt_q, dens):
        """Fill ``vHt_q`` in place with ``4 * pi * dens.rhot_q / G2_q``."""
        vHt_q[:] = 4 * pi * dens.rhot_q
        vHt_q /= self.G2_q
1901
1902
def integrate(pd, a, b):
    """Integrate ``a`` against ``b`` locally.

    Thin wrapper around ``pd.integrate()`` with ``global_integral=False``,
    i.e. without calling pd.gd.comm.sum().
    """
    local_value = pd.integrate(a, b, global_integral=False)
    return local_value
1906
1907
1908class ReciprocalSpaceHamiltonian(Hamiltonian):
1909    def __init__(self, gd, finegd, pd2, pd3, nspins, collinear,
1910                 setups, timer, xc, world, xc_redistributor,
1911                 vext=None,
1912                 psolver=None, redistributor=None, realpbc_c=None):
1913
1914        assert redistributor is not None  # XXX should not be like this
1915        Hamiltonian.__init__(self, gd, finegd, nspins, collinear, setups,
1916                             timer, xc, world, vext=vext,
1917                             redistributor=redistributor)
1918
1919        self.vbar = PWLFC([[setup.vbar] for setup in setups], pd2)
1920        self.pd2 = pd2
1921        self.pd3 = pd3
1922        self.xc_redistributor = xc_redistributor
1923
1924        self.vHt_q = pd3.empty()
1925
1926        if psolver is None:
1927            psolver = ReciprocalSpacePoissonSolver(pd3, realpbc_c)
1928        elif isinstance(psolver, dict):
1929            direction = psolver['dipolelayer']
1930            assert len(psolver) == 1
1931            from gpaw.dipole_correction import DipoleCorrection
1932            psolver = DipoleCorrection(
1933                ReciprocalSpacePoissonSolver(pd3, realpbc_c), direction)
1934        self.poisson = psolver
1935        self.npoisson = 0
1936
1937        self.vbar_Q = None
1938        self.vt_Q = None
1939        self.estress = None
1940
1941    @property
1942    def xc_gd(self):
1943        if self.xc_redistributor is None:
1944            return self.finegd
1945        return self.xc_redistributor.aux_gd
1946
1947    def set_positions(self, spos_ac, atom_partition):
1948        Hamiltonian.set_positions(self, spos_ac, atom_partition)
1949        self.vbar_Q = self.pd2.zeros()
1950        self.vbar.add(self.vbar_Q)
1951
1952    def update_pseudo_potential(self, dens):
1953        ebar = integrate(self.pd2, self.vbar_Q, dens.nt_Q)
1954        with self.timer('Poisson'):
1955            self.poisson.solve(self.vHt_q, dens)
1956            epot = 0.5 * integrate(self.pd3, self.vHt_q, dens.rhot_q)
1957
1958        if self.vext is None:
1959            v_q = self.vHt_q
1960            eext = 0.0
1961        else:
1962            v_q = self.vext.get_potentialq(self.finegd, self.pd3).copy()
1963            eext = integrate(self.pd3, v_q, dens.rhot_q)
1964            v_q += self.vHt_q
1965
1966        self.vt_Q = self.vbar_Q.copy()
1967        dens.map23.add_to1(self.vt_Q, v_q)
1968
1969        self.vt_sG[:] = self.pd2.ifft(self.vt_Q)
1970
1971        self.timer.start('XC 3D grid')
1972
1973        nt_xg = dens.nt_xg
1974
1975        # If we have a redistributor, we want to do the
1976        # good old distribute-calculate-collect:
1977        redist = self.xc_redistributor
1978        if redist is not None:
1979            nt_xg = redist.distribute(nt_xg)
1980
1981        vxct_xg = np.zeros_like(nt_xg)
1982        exc = self.xc.calculate(self.xc_gd, nt_xg, vxct_xg)
1983        exc /= self.finegd.comm.size
1984        if redist is not None:
1985            vxct_xg = redist.collect(vxct_xg)
1986
1987        x = 0
1988        for vt_G, vxct_g in zip(self.vt_xG, vxct_xg):
1989            vxc_G, vxc_Q = self.pd3.restrict(vxct_g, self.pd2)
1990            if x < self.nspins:
1991                vt_G += vxc_G
1992                self.vt_Q += vxc_Q / self.nspins
1993            else:
1994                vt_G[:] = vxc_G
1995            x += 1
1996
1997        self.timer.stop('XC 3D grid')
1998
1999        energies = np.array([epot, ebar, eext, exc])
2000        self.estress = self.gd.comm.sum(epot + ebar)
2001        return energies
2002
2003    def calculate_atomic_hamiltonians(self, density):
2004        def getshape(a):
2005            return sum(2 * l + 1
2006                       for l, _ in enumerate(self.setups[a].ghat_l)),
2007        W_aL = ArrayDict(self.atomdist.aux_partition, getshape, float)
2008
2009        if self.vext:
2010            vext_q = self.vext.get_potentialq(self.finegd, self.pd3)
2011            density.ghat.integrate(self.vHt_q + vext_q, W_aL)
2012        else:
2013            density.ghat.integrate(self.vHt_q, W_aL)
2014
2015        return self.atomdist.to_work(self.atomdist.from_aux(W_aL))
2016
2017    def calculate_kinetic_energy(self, density):
2018        ekin = 0.0
2019        for vt_G, nt_G in zip(self.vt_xG, density.nt_xG):
2020            ekin -= integrate(self.gd, vt_G, nt_G)
2021        ekin += integrate(self.gd, self.vt_sG, density.nct_G).sum()
2022        return ekin
2023
2024    def restrict(self, in_xR, out_xR=None):
2025        """Restrict array."""
2026        if out_xR is None:
2027            out_xR = self.gd.empty(in_xR.shape[:-3])
2028
2029        a_xR = in_xR.reshape((-1,) + in_xR.shape[-3:])
2030        b_xR = out_xR.reshape((-1,) + out_xR.shape[-3:])
2031
2032        for in_R, out_R in zip(a_xR, b_xR):
2033            out_R[:] = self.pd3.restrict(in_R, self.pd2)[0]
2034
2035        return out_xR
2036
2037    restrict_and_collect = restrict
2038
2039    def calculate_forces2(self, dens, ghat_aLv, nct_av, vbar_av):
2040        if self.vext:
2041            vext_q = self.vext.get_potentialq(self.finegd, self.pd3)
2042            dens.ghat.derivative(self.vHt_q + vext_q, ghat_aLv)
2043        else:
2044            dens.ghat.derivative(self.vHt_q, ghat_aLv)
2045        dens.nct.derivative(self.vt_Q, nct_av)
2046        self.vbar.derivative(dens.nt_Q, vbar_av)
2047
2048    def get_electrostatic_potential(self, dens: Density) -> Array3D:
2049        self.poisson.solve(self.vHt_q, dens)
2050        return self.pd3.ifft(self.vHt_q, distribute=False)
2051