1"""This module implements a phonon perturbation."""
2
3__all__ = ["PhononPerturbation"]
4
5from math import pi
6import numpy as np
7import numpy.linalg as la
8
9from gpaw.utilities import unpack
10from gpaw.transformers import Transformer
11from gpaw.lfc import LocalizedFunctionsCollection as LFC
12from gpaw.dfpt.perturbation import Perturbation
13
14
class PhononPerturbation(Perturbation):
    """Implementation of a phonon perturbation.

    This class implements the change in the effective potential due to a
    displacement of an atom ``a`` in direction ``v`` with wave-vector ``q``.
    The action of the perturbing potential on a state vector is implemented in
    the ``apply`` member function.

    ``initialize`` must be called before ``set_av``. ``set_av`` evaluates the
    local potential with the LFC's and the Poisson solver that ``initialize``
    sets up. For non-Gamma calculations ``set_q`` must also be called before
    ``set_av``.

    """
    def __init__(self, calc, kd, poisson_solver, dtype=float, **kwargs):
        """Store useful objects, e.g. lfc's for the various atomic functions.

        Depending on whether the system is periodic or finite, Poisson's
        equation is solved with FFT or multigrid techniques, respectively.

        Parameters
        ----------
        calc: Calculator
            Ground-state calculation.
        kd: KPointDescriptor
            Descriptor for the q-vectors of the dynamical matrix.
        poisson_solver: PoissonSolver
            Solver for Poisson's equation; it is attached to the fine grid
            in ``initialize``.
        dtype: type
            Data type of the LFC's, grids and restrictor.
        kwargs:
            Accepted for interface compatibility; currently ignored.

        """
        self.kd = kd
        self.dtype = dtype
        self.poisson = poisson_solver

        # Bloch phases for the boundary conditions. ``phase_cd`` is selected
        # per q-vector in ``set_q``. At the Gamma point (wrt the q-vector) no
        # phases are needed, so it stays None. It is defined in both cases
        # so that ``get_phase_cd`` never fails with an AttributeError.
        self.phase_cd = None
        self.phase_qcd = None
        if not self.kd.gamma:
            assert self.kd.mynks == len(self.kd.ibzk_qc)

            sdisp_cd = calc.wfs.gd.sdisp_cd
            self.phase_qcd = [
                np.exp(2j * np.pi * sdisp_cd *
                       self.kd.ibzk_qc[q, :, np.newaxis])
                for q in range(self.kd.mynks)]

        # Store grid-descriptors
        self.gd = calc.density.gd
        self.finegd = calc.density.finegd

        # Steal setups for the lfc's
        setups = calc.wfs.setups

        # Store projector coefficients
        self.dH_asp = calc.hamiltonian.dH_asp.copy()

        # Localized functions:
        # core corrections
        self.nct = LFC(self.gd, [[setup.nct] for setup in setups],
                       integral=[setup.Nct for setup in setups],
                       dtype=self.dtype)
        # compensation charges
        # XXX what is the consequence of numerical errors in the integral ??
        self.ghat = LFC(self.finegd, [setup.ghat_l for setup in setups],
                        kd,
                        dtype=self.dtype)
        # vbar potential
        self.vbar = LFC(self.finegd, [[setup.vbar] for setup in setups],
                        kd,
                        dtype=self.dtype)

        # Expansion coefficients for the compensation charges
        self.Q_aL = calc.density.Q_aL.copy()

        # Grid transformer -- convert array from fine to coarse grid
        self.restrictor = Transformer(self.finegd,
                                      self.gd,
                                      nn=3,
                                      dtype=self.dtype)

        # Atom and cartesian coordinate of the perturbation
        self.a = None
        self.v = None

        # Local part of the perturbing potential on the coarse grid; it is
        # computed in ``calculate_local_potential``.
        self.v1_G = None

        # Local q-vector index of the perturbation. At Gamma the value -1 is
        # passed as ``q`` to the LFC methods. Otherwise it is set by ``set_q``.
        if self.kd.gamma:
            self.q = -1
        else:
            self.q = None

    def initialize(self, spos_ac):
        """Prepare the various attributes for a calculation.

        Parameters
        ----------
        spos_ac: ndarray
            Scaled atomic positions, passed on to the LFC's.

        """
        # Set positions on LFC's
        self.nct.set_positions(spos_ac)
        self.ghat.set_positions(spos_ac)
        self.vbar.set_positions(spos_ac)

        if not self.kd.gamma:
            # Phase factor exp(iq.r) needed to obtain the periodic part of
            # the lfc's. With q and r in scaled coordinates this is
            # exp(2 pi i sum_c q_c s_c).
            coor_vg = self.finegd.get_grid_point_coordinates()
            cell_cv = self.finegd.cell_cv
            # Convert to scaled coordinates. By the ASE/GPAW convention the
            # rows of cell_cv are the lattice vectors:
            #     r_v = sum_c s_c cell_cv[c, v]
            # so  s_c = sum_v r_v inv(cell_cv)[v, c].
            # Contracting over the FIRST axis of the inverse matters for
            # non-orthogonal cells. (np.dot(inv, r) would use its transpose.)
            scoor_cg = np.tensordot(la.inv(cell_cv), coor_vg, axes=(0, 0))
            # One phase grid per local q-vector: shape (nq,) + grid shape
            self.phase_qg = np.exp(
                2j * pi * np.tensordot(self.kd.ibzk_qc, scoor_cg,
                                       axes=(1, 0)))

        # XXX To be removed from this class !!
        # Setup the Poisson solver -- to be used on the fine grid
        self.poisson.set_grid_descriptor(self.finegd)
        self.poisson.initialize()

    def set_q(self, q):
        """Set the index of the q-vector of the perturbation.

        Parameters
        ----------
        q: int
            Local index into ``kd.ibzk_qc``.

        """
        assert not self.kd.gamma, "Gamma-point calculation"

        self.q = q

        # Update phases and Poisson solver
        self.phase_cd = self.phase_qcd[q]
        self.poisson.set_q(self.kd.ibzk_qc[q])

        # Invalidate calculated quantities
        # - local part of perturbing potential
        self.v1_G = None

    def set_av(self, a, v):
        """Set atom and cartesian component of the perturbation.

        This also recomputes the local part of the perturbing potential.

        Parameters
        ----------
        a: int
            Index of the atom.
        v: int
            Cartesian component (0, 1 or 2) of the atomic displacement.

        """
        assert self.q is not None

        self.a = a
        self.v = v

        # Update derivative of local potential
        self.calculate_local_potential()

    def get_phase_cd(self):
        """Override base class member function.

        Return the Bloch phases for the current q-vector. The result is
        ``None`` at the Gamma point, or if ``set_q`` has not been called yet.

        """
        return self.phase_cd

    def has_q(self):
        """Override base class member function.

        Return True unless this is a Gamma-point calculation.

        """
        return not self.kd.gamma

    def get_q(self):
        """Return q-vector (scaled coordinates) of the current perturbation."""

        assert not self.kd.gamma, "Gamma-point calculation."

        return self.kd.ibzk_qc[self.q]

    def solve_poisson(self, phi_g, rho_g):
        """Solve Poisson's equation for a Bloch-type charge distribution.

        At the Gamma point ``rho_g`` is periodic and is passed straight to
        the solver. Otherwise the Bloch phase exp(iq.r) is divided out,
        Poisson's equation is solved for the periodic part, and the phase is
        multiplied back onto the solution.

        Parameters
        ----------
        phi_g: ndarray
            Fine-grid array that receives the solution (modified in place).
        rho_g: ndarray
            Fine-grid array with the charge distribution. It is not modified.

        """
        assert self.q is not None, ("q-vector not set")

        # Gamma point calculation wrt the q-vector -> rho_g periodic
        if self.kd.gamma:
            # XXX NOTICE: solve_neutral
            self.poisson.solve_neutral(phi_g, rho_g)
        else:
            # Divide out the phase factor to get the periodic part
            rhot_g = rho_g / self.phase_qg[self.q]

            # Solve Poisson's equation for the periodic part of the potential
            # XXX NOTICE: solve_neutral
            self.poisson.solve_neutral(phi_g, rhot_g)

            # Return to Bloch form
            phi_g *= self.phase_qg[self.q]

    def calculate_local_potential(self):
        """Derivative of the local potential wrt an atomic displacement.

        The local part of the PAW potential has contributions from the
        compensation charges (``ghat``) and a spherical symmetric atomic
        potential (``vbar``).

        Sets ``vghat1_g`` (ghat part, fine grid), ``v1_g`` (total, fine grid)
        and ``v1_G`` (total, restricted to the coarse grid).

        """
        assert self.a is not None
        assert self.v is not None
        assert self.q is not None

        a = self.a
        v = self.v

        # Expansion coefficients for the ghat functions.
        # NOTE(review): this assumes atom ``a`` is present in the dicts
        # returned by ``ghat.dict``/``vbar.dict`` on this rank -- confirm
        # for domain-decomposed runs.
        Q_aL = self.ghat.dict(zero=True)
        # Remember sign convention for add_derivative method
        # And be sure not to change the dtype of the arrays by assigning values
        # to array elements.
        Q_aL[a][:] = -1 * self.Q_aL[a]

        # Grid for derivative of compensation charges
        ghat1_g = self.finegd.zeros(dtype=self.dtype)
        self.ghat.add_derivative(a, v, ghat1_g, c_axi=Q_aL, q=self.q)

        # Solve Poisson's eq. for the potential from the periodic part of the
        # compensation charge derivative
        v1_g = self.finegd.zeros(dtype=self.dtype)
        self.solve_poisson(v1_g, ghat1_g)

        # Store potential from the compensation charge
        self.vghat1_g = v1_g.copy()

        # Add derivative of vbar - sign convention in add_derivative method
        c_ai = self.vbar.dict(zero=True)
        c_ai[a][0] = -1.
        self.vbar.add_derivative(a, v, v1_g, c_axi=c_ai, q=self.q)

        # Store potential for the evaluation of the energy derivative
        self.v1_g = v1_g.copy()

        # Transfer to coarse grid
        v1_G = self.gd.zeros(dtype=self.dtype)
        self.restrictor.apply(v1_g, v1_G, phases=self.phase_cd)

        self.v1_G = v1_G

    def apply(self, psi_nG, y_nG, wfs, k, kplusq):
        """Apply perturbation to unperturbed wave-functions.

        The local and non-local contributions are added to ``y_nG`` in place.

        Parameters
        ----------
        psi_nG: ndarray
            Set of grid vectors to which the perturbation is applied.
        y_nG: ndarray
            Output vectors.
        wfs: WaveFunctions
            Instance of class ``WaveFunctions``.
        k: int
            Index of the k-point for the vectors.
        kplusq: int
            Index of the k+q vector.

        """
        assert self.a is not None
        assert self.v is not None
        assert self.q is not None
        assert psi_nG.ndim in (3, 4)
        assert tuple(self.gd.n_c) == psi_nG.shape[-3:]

        # Local part. Broadcasting covers both a single grid vector (3D)
        # and a stack of vectors (4D).
        y_nG += self.v1_G * psi_nG

        self.apply_nonlocal_potential(psi_nG, y_nG, wfs, k, kplusq)

    def apply_nonlocal_potential(self, psi_nG, y_nG, wfs, k, kplusq):
        """Derivative of the non-local PAW potential wrt an atomic displacement.

        The result is added to ``y_nG`` in place. Only the first spin
        component of ``dH_asp`` is used.

        Parameters
        ----------
        psi_nG: ndarray
            Set of grid vectors (3D single vector or 4D stack).
        y_nG: ndarray
            Output vectors.
        wfs: WaveFunctions
            Provides ``kpt_u[k].P_ani``/``dP_aniv`` and the projectors ``pt``.
        k: int
            Index of the k-point being operated on.
        kplusq: int
            Index of the k+q vector.

        """
        assert self.a is not None
        assert self.v is not None
        assert psi_nG.ndim in (3, 4)
        assert tuple(self.gd.n_c) == psi_nG.shape[-3:]

        n = 1 if psi_nG.ndim == 3 else psi_nG.shape[0]

        a = self.a
        v = self.v

        P_ani = wfs.kpt_u[k].P_ani
        dP_aniv = wfs.kpt_u[k].dP_aniv
        pt = wfs.pt

        # < p_a^i | Psi_nk >
        P_ni = P_ani[a]
        # < dp_av^i | Psi_nk > - remember the sign convention of the derivative
        dP_ni = -1 * dP_aniv[a][..., v]

        # Expansion coefficients for the projectors on atom a
        dH_ii = unpack(self.dH_asp[a][0])

        # The derivative of the non-local PAW potential has two contributions
        # 1) Sum over projectors
        c_ni = np.dot(dP_ni, dH_ii)
        c_ani = pt.dict(shape=n, zero=True)
        c_ani[a] = c_ni
        # Projectors are added at k+q, not k
        pt.add(y_nG, c_ani, q=kplusq)

        # 2) Sum over derivatives of the projectors
        dc_ni = np.dot(P_ni, dH_ii)
        dc_ani = pt.dict(shape=n, zero=True)
        # Take care of sign of derivative in the coefficients
        dc_ani[a] = -1 * dc_ni
        # Projector derivatives are added at k+q, not k
        pt.add_derivative(a, v, y_nG, dc_ani, q=kplusq)
345