1from math import pi, sqrt
2from warnings import warn
3from typing import Optional, List, Dict
4
5import numpy as np
6from ase.units import Bohr, Ha
7
8from gpaw import GPAW
9from gpaw.atom.shapefunc import shape_functions
10from gpaw.fftw import get_efficient_fft_size
11from gpaw.grid_descriptor import GridDescriptor
12from gpaw.lfc import LocalizedFunctionsCollection as LFC
13from gpaw.utilities import h2gpts
14from gpaw.wavefunctions.pw import PWDescriptor
15from gpaw.mpi import serial_comm
16from gpaw.setup import Setup
17from gpaw.spline import Spline
18from gpaw.typing import Array3D
19
20
class Interpolator:
    """Interpolate arrays from one grid to another.

    Two ``PWDescriptor`` objects are built, one per grid descriptor, and
    the actual work is delegated to ``PWDescriptor.interpolate``.
    """

    def __init__(self, gd1, gd2, dtype=float):
        """Set up descriptors for the source and target grids.

        gd1:
            Grid descriptor of the grid the input arrays live on.
        gd2:
            Grid descriptor of the grid to interpolate to.
        dtype:
            Data type handed to both ``PWDescriptor`` objects.
        """
        source_pd = PWDescriptor(0.0, gd1, dtype)
        target_pd = PWDescriptor(0.0, gd2, dtype)
        self.pd1, self.pd2 = source_pd, target_pd

    def interpolate(self, a_r):
        """Return ``a_r`` interpolated from the first grid to the second.

        Only the first element of what ``PWDescriptor.interpolate``
        returns is passed on to the caller.
        """
        interpolated = self.pd1.interpolate(a_r, self.pd2)
        return interpolated[0]
28
29
# Value passed as the ``points`` argument of ``rgd.spline`` when radial
# functions are turned into splines.
POINTS: int = 200
31
32
class PS2AE:
    """Transform PS to AE wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.

    The same fine grid (``self.gd``) is also used for interpolating the
    pseudo density and the electrostatic potential.
    """
    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h=None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h: float
            Deprecated alias for grid_spacing.  If given, it overrides
            grid_spacing and a warning is issued.
        """
        if h is not None:
            # stacklevel=2 makes the warning point at the caller's line
            # instead of this module.
            warn('Please use grid_spacing=... instead of h=...',
                 stacklevel=2)
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd

        # Serial copy of the calculator's wave-function grid:
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        N_c = h2gpts(grid_spacing / Bohr, gd.cell_cv)
        N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
        gd2 = self.gd = GridDescriptor(N_c, gd.cell_cv, comm=serial_comm)
        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        # Built lazily by the dphi property:
        self._dphi: Optional[LFC] = None  # PAW correction

        # Volume per grid point of the final grid, converted with Bohr**3:
        self.dv = self.gd.dv * Bohr**3

    @property
    def dphi(self) -> LFC:
        """Localized functions phi - phit on the final grid.

        The LFC object is created on first access and cached in
        ``self._dphi``; later accesses return the cached object.
        """
        if self._dphi is not None:
            return self._dphi

        # Atoms sharing a setup object share the list of splines:
        splines: Dict[Setup, List[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                # Cutoff 10 % beyond the largest rcut_j of the setup:
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,
                                            setup.data.phit_jg):
                    dphi_g = (phi_g - phit_g)[:gcut]
                    # Use the module-level POINTS so that all splines in
                    # this module are built with the same resolution:
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj, kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic:
            Return periodic part of wave-function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        # Always fetch the periodic part; the phase factor is applied
        # explicitly below.
        u_r = self.calc.get_pseudo_wave_function(n, k, s,
                                                 pad=True, periodic=True)
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()

        if gamma:
            # No phase factor needed when all components of k_c are zero.
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)

        if ae:
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            # Only rank 0 of wfs.world adds the correction; the result
            # is broadcast from rank 0 afterwards.
            if wfs.world.rank == 0:
                psi_R = u_R * eikr_R
                # Slice the coefficients of band n into one block of
                # setup.ni entries per atom:
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2
                dphi.add(psi_R, P_ai, k)
                # Back to the periodic part:
                u_R = psi_R / eikr_R

            wfs.world.broadcast(u_R, 0)

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True) -> Array3D:
        """Interpolate pseudo density.

        The pseudo density is summed over the first ``dens.nspins``
        components before it is interpolated to the final grid.  The
        returned array is divided by Bohr**3.

        add_compensation_charges: bool
            Add the ``ghat_l`` functions of every setup, weighted by the
            multipole moments ``Q_aL`` (with ``Nv / sqrt(4 pi)`` added
            to the first component), to the interpolated density.
        """
        dens = self.calc.density
        gd1 = dens.gd
        # The density grid must not be distributed:
        assert gd1.comm.size == 1
        interpolator = Interpolator(gd1, self.gd)
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd, [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)
            # Work on copies so that dens.Q_aL is left untouched:
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                Q_aL[a] = Q_L.copy()
                Q_aL[a][0] += dens.setups[a].Nv / (4 * pi)**0.5
            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        # Work in units of Ha internally; converted back on return.
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add atom-centered corrections to the potential, in place.

        v_R:
            Potential on the final grid (``self.gd``).  Modified in
            place; nothing is returned.
        rcgauss: float
            Width of the gaussian representing the nuclear charge.
            ``get_electrostatic_potential`` passes its own value divided
            by Bohr.
        """
        dens = self.calc.density
        # Switch D_asp and Q_aL to the serial partition while the
        # corrections are built; restored after the loop.
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())

        dv_a1 = []
        for a, D_sp in dens.D_asp.items():
            setup = dens.setups[a]
            c = setup.xc_correction
            rgd = c.rgd
            # Same shape function as the setup, restricted to lmax=0:
            params = setup.data.shape_function.copy()
            params['lmax'] = 0
            ghat_g = shape_functions(rgd, **params)[0]
            # Gaussian of width rcgauss scaled by setup.Z:
            Z_g = shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] * setup.Z
            D_q = np.dot(D_sp.sum(0), c.B_pqL[:, :, 0])
            # Radial density difference that sources the correction:
            dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)
            dn_g += 4 * pi * (c.nc_g - c.nct_g)
            dn_g -= Z_g
            dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)
            dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
            # Divide by r everywhere except the first grid point, which
            # instead gets the value of its neighbour:
            dv_g[1:] /= rgd.r_g[1:]
            dv_g[0] = dv_g[1]
            # Force the correction to zero at the last radial point:
            dv_g[-1] = 0.0
            dv_a1.append([rgd.spline(dv_g, points=POINTS)])

        dens.D_asp.redistribute(dens.atom_partition)
        dens.Q_aL.redistribute(dens.atom_partition)

        # dv_a1 is empty when this rank received no atoms above; v_R is
        # then filled in by the broadcast from rank 0.
        if dv_a1:
            dv = LFC(self.gd, dv_a1)
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        dens.gd.comm.broadcast(v_R, 0)
240
241
def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a cdft weight function given on the fine grid.

    calc:
        Calculator; only ``calc.density.finegd`` is used.
    weight:
        Weight function on the grid ``calc.density.finegd``.
    h: float
        Grid spacing of the target grid; divided by Bohr before it is
        handed to ``h2gpts``.
    n: int
        Second argument of ``get_efficient_fft_size`` for every axis.

    Returns the weight function interpolated to the target grid.
    """
    fine_gd = calc.density.finegd

    # Collect the full array on every rank, then zero-pad it.
    padded = fine_gd.zero_pad(fine_gd.collect(weight, broadcast=True))

    # Serial descriptor with the shape and cell of the fine grid.
    serial_in = GridDescriptor(fine_gd.N_c, fine_gd.cell_cv,
                               comm=serial_comm)
    w = np.zeros_like(padded)
    serial_in.distribute(padded, w)

    # Serial descriptor for the target grid.
    target_N_c = np.array([get_efficient_fft_size(npts, n)
                           for npts in h2gpts(h / Bohr, fine_gd.cell_cv)])
    serial_out = GridDescriptor(target_N_c, fine_gd.cell_cv,
                                comm=serial_comm)

    return Interpolator(serial_in, serial_out).interpolate(w)
261