1import numpy as np
2
3from gpaw import debug
4from gpaw.transformers import Transformer
5from gpaw.lfc import BasisFunctions
6from gpaw.lcaotddft.observer import TDDFTObserver
7from gpaw.utilities import unpack2, is_contiguous
8
9from gpaw.inducedfield.inducedfield_base import BaseInducedField, \
10    sendreceive_dict
11
12
13class TDDFTInducedField(BaseInducedField, TDDFTObserver):
14    """Induced field class for time propagation TDDFT.
15
16    Attributes (see also ``BaseInducedField``):
17    -------------------------------------------
18    time: float
19        Current time
20    Fnt_wsG: ndarray (complex)
21        Fourier transform of induced pseudo density
22    n0t_sG: ndarray (float)
23        Ground state pseudo density
24    FD_awsp: dict of ndarray (complex)
25        Fourier transform of induced D_asp
26    D0_asp: dict of ndarray (float)
27        Ground state D_asp
28    """
29
30    def __init__(self, filename=None, paw=None,
31                 frequencies=None, folding='Gauss', width=0.08,
32                 interval=1, restart_file=None
33                 ):
34        """
35        Parameters (see also ``BaseInducedField``):
36        -------------------------------------------
37        paw: TDDFT object
38            TDDFT object for time propagation
39        width: float
40            Width in eV for the Gaussian (sigma) or Lorentzian (eta) folding
41            Gaussian   = exp(- (1/2) * sigma^2 * t^2)
42            Lorentzian = exp(- eta * t)
43        interval: int
44            Number of timesteps between calls (used when attaching)
45        restart_file: string
46            Name of the restart file
47        """
48
49        TDDFTObserver.__init__(self, paw, interval)
50        # From observer:
51        # self.niter
52        # self.interval
53        # self.timer
54        # Observer does also paw.attach(self, ...)
55
56        # Restart file
57        self.restart_file = restart_file
58
59        # These are allocated in allocate()
60        self.Fnt_wsG = None
61        self.n0t_sG = None
62        self.FD_awsp = None
63        self.D0_asp = None
64
65        self.readwritemode_str_to_list = \
66            {'': ['Fnt', 'n0t', 'FD', 'D0', 'atoms'],
67             'all': ['Fnt', 'n0t', 'FD', 'D0',
68                     'Frho', 'Fphi', 'Fef', 'Ffe', 'atoms'],
69             'field': ['Frho', 'Fphi', 'Fef', 'Ffe', 'atoms']}
70
71        BaseInducedField.__init__(self, filename, paw,
72                                  frequencies, folding, width)
73
74    def initialize(self, paw, allocate=True):
75        BaseInducedField.initialize(self, paw, allocate)
76
77        if self.has_paw:
78            assert hasattr(paw, 'time') and hasattr(paw, 'niter'), 'Use TDDFT!'
79            self.time = paw.time                # !
80            self.niter = paw.niter
81
82    def set_folding(self, folding, width):
83        BaseInducedField.set_folding(self, folding, width)
84
85        if self.folding is None:
86            self.envelope = lambda t: 1.0
87        else:
88            if self.folding == 'Gauss':
89                self.envelope = lambda t: np.exp(- 0.5 * self.width**2 * t**2)
90            elif self.folding == 'Lorentz':
91                self.envelope = lambda t: np.exp(- self.width * t)
92            else:
93                raise RuntimeError('unknown folding "' + self.folding + '"')
94
95    def allocate(self):
96        if not self.allocated:
97
98            # Ground state pseudo density
99            self.n0t_sG = self.gd.empty((self.nspins, )) + np.nan
100
101            # Fourier transformed pseudo density
102            self.Fnt_wsG = self.gd.zeros((self.nw, self.nspins),
103                                         dtype=self.dtype)
104
105            # Ground state D_asp
106            self.D0_asp = {}
107            for a, D_sp in self.density.D_asp.items():
108                self.D0_asp[a] = D_sp.copy()
109
110            # Size of D_p for each atom
111            self.np_a = {}
112            for a, D_sp in self.D0_asp.items():
113                self.np_a[a] = np.array([len(D_sp[0])])
114
115            # Fourier transformed D_asp
116            self.FD_awsp = {}
117            for a, np_ in self.np_a.items():
118                self.FD_awsp[a] = np.zeros((self.nw, self.nspins, np_[0]),
119                                           dtype=self.dtype)
120
121            self.allocated = True
122
123        if debug:
124            assert is_contiguous(self.Fnt_wsG, self.dtype)
125
126    def deallocate(self):
127        BaseInducedField.deallocate(self)
128        self.n0t_sG = None
129        self.Fnt_wsG = None
130        self.D0_asp = None
131        self.FD_awsp = None
132
133    def _update(self, paw):
134        if paw.action == 'init':
135            if paw.niter == 0:
136                self.n0t_sG[:] = paw.density.nt_sG
137            return
138        elif paw.action == 'kick':
139            # Background electric field
140            self.Fbgef_v = paw.kick_strength
141            return
142        elif paw.action != 'propagate':
143            return
144
145        assert (self.Fbgef_v is not None
146                and not np.any(np.isnan(self.n0t_sG))), \
147            f'Attach {self.__class__.__name__} before absorption kick'
148
149        # Update time
150        self.time = paw.time
151        time_step = paw.time_step
152
153        # Complex exponential with envelope
154        f_w = np.exp(1.0j * self.omega_w * self.time) * \
155            self.envelope(self.time) * time_step
156
157        # Time-dependent quantities
158        nt_sG = self.density.nt_sG
159        D_asp = self.density.D_asp
160
161        # Update Fourier transforms
162        for w in range(self.nw):
163            self.Fnt_wsG[w] += (nt_sG - self.n0t_sG) * f_w[w]
164            for a, D_sp in D_asp.items():
165                self.FD_awsp[a][w] += (D_sp - self.D0_asp[a]) * f_w[w]
166
167        # Restart file
168        # XXX remove this once deprecated dump_interval is removed,
169        # but keep write_restart() as it'll be still used
170        # (see TDDFTObserver class)
171        if (paw.restart_file is not None
172                and self.niter % paw.dump_interval == 0):
173            self.write_restart()
174
175    def write_restart(self):
176        if self.restart_file is not None:
177            self.write(self.restart_file)
178            self.log(f'{self.__class__.__name__}: Wrote restart file')
179
180    def interpolate_pseudo_density(self, gridrefinement=2):
181
182        gd = self.gd
183        Fnt_wsg = self.Fnt_wsG.copy()
184
185        # Find m for
186        # gridrefinement = 2**m
187        m1 = np.log(gridrefinement) / np.log(2.)
188        m = int(np.round(m1))
189
190        # Check if m is really integer
191        if np.absolute(m - m1) < 1e-8:
192            for i in range(m):
193                gd2 = gd.refine()
194
195                # Interpolate
196                interpolator = Transformer(gd, gd2, self.stencil,
197                                           dtype=self.dtype)
198                Fnt2_wsg = gd2.empty((self.nw, self.nspins), dtype=self.dtype)
199                for w in range(self.nw):
200                    for s in range(self.nspins):
201                        interpolator.apply(Fnt_wsg[w][s], Fnt2_wsg[w][s],
202                                           np.ones((3, 2), dtype=complex))
203
204                gd = gd2
205                Fnt_wsg = Fnt2_wsg
206        else:
207            raise NotImplementedError
208
209        return Fnt_wsg, gd
210
211    def comp_charge_correction(self, gridrefinement=2):
212
213        # TODO: implement for gr==1 also
214        assert gridrefinement == 2
215
216        # Density
217        Fnt_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
218        Frhot_wg = Fnt_wsg.sum(axis=1)
219
220        tmp_g = gd.empty(dtype=float)
221        for w in range(self.nw):
222            # Determine compensation charge coefficients:
223            FQ_aL = {}
224            for a, FD_wsp in self.FD_awsp.items():
225                FQ_aL[a] = np.dot(FD_wsp[w].sum(axis=0),
226                                  self.setups[a].Delta_pL)
227
228            # Add real part of compensation charges
229            tmp_g[:] = 0
230            FQ2_aL = {}
231            for a, FQ_L in FQ_aL.items():
232                # Take copy to make array contiguous
233                FQ2_aL[a] = FQ_L.real.copy()
234#                print is_contiguous(FQ2_aL[a])
235#                print is_contiguous(FQ_L.real)
236            self.density.ghat.add(tmp_g, FQ2_aL)
237            Frhot_wg[w] += tmp_g
238
239            # Add imag part of compensation charges
240            tmp_g[:] = 0
241            FQ2_aL = {}
242            for a, FQ_L in FQ_aL.items():
243                FQ2_aL[a] = FQ_L.imag.copy()
244            self.density.ghat.add(tmp_g, FQ2_aL)
245            Frhot_wg[w] += 1.0j * tmp_g
246
247        return Frhot_wg, gd
248
249    def paw_corrections(self, gridrefinement=2):
250
251        Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
252
253        # Splines
254        splines = {}
255        phi_aj = []
256        phit_aj = []
257        for a, id in enumerate(self.setups.id_a):
258            if id in splines:
259                phi_j, phit_j = splines[id]
260            else:
261                # Load splines:
262                phi_j, phit_j = self.setups[a].get_partial_waves()[:2]
263                splines[id] = (phi_j, phit_j)
264            phi_aj.append(phi_j)
265            phit_aj.append(phit_j)
266
267        # Create localized functions from splines
268        phi = BasisFunctions(gd, phi_aj, dtype=float)
269        phit = BasisFunctions(gd, phit_aj, dtype=float)
270#        phi = BasisFunctions(gd, phi_aj, dtype=complex)
271#        phit = BasisFunctions(gd, phit_aj, dtype=complex)
272        spos_ac = self.atoms.get_scaled_positions()
273        phi.set_positions(spos_ac)
274        phit.set_positions(spos_ac)
275
276        tmp_g = gd.empty(dtype=float)
277        rho_MM = np.zeros((phi.Mmax, phi.Mmax), dtype=self.dtype)
278        rho2_MM = np.zeros_like(rho_MM)
279        for w in range(self.nw):
280            for s in range(self.nspins):
281                rho_MM[:] = 0
282                M1 = 0
283                for a, setup in enumerate(self.setups):
284                    ni = setup.ni
285                    FD_wsp = self.FD_awsp.get(a)
286                    if FD_wsp is None:
287                        FD_p = np.empty((ni * (ni + 1) // 2), dtype=self.dtype)
288                    else:
289                        FD_p = FD_wsp[w][s]
290                    if gd.comm.size > 1:
291                        gd.comm.broadcast(FD_p, self.rank_a[a])
292                    D_ij = unpack2(FD_p)
293                    # unpack does complex conjugation that we don't want so
294                    # remove conjugation
295                    D_ij = np.triu(D_ij, 1) + np.conj(np.tril(D_ij))
296
297#                    if FD_wsp is None:
298#                        FD_wsp = np.empty((self.nw, self.nspins,
299#                                           ni * (ni + 1) // 2),
300#                                          dtype=self.dtype)
301#                    if gd.comm.size > 1:
302#                        gd.comm.broadcast(FD_wsp, self.rank_a[a])
303#                    D_ij = unpack2(FD_wsp[w][s])
304#                    D_ij = np.triu(D_ij, 1) + np.conj(np.tril(D_ij))
305
306                    M2 = M1 + ni
307                    rho_MM[M1:M2, M1:M2] = D_ij
308                    M1 = M2
309
310                # Add real part of AE corrections
311                tmp_g[:] = 0
312                rho2_MM[:] = rho_MM.real
313                # TODO: use ae_valence_density_correction
314                phi.construct_density(rho2_MM, tmp_g, q=-1)
315                phit.construct_density(-rho2_MM, tmp_g, q=-1)
316#                phi.lfc.ae_valence_density_correction(rho2_MM, tmp_g,
317#                                                      np.zeros(len(phi.M_W),
318#                                                               np.intc),
319#                                                      np.zeros(self.na))
320#                phit.lfc.ae_valence_density_correction(-rho2_MM, tmp_g,
321#                                                      np.zeros(len(phi.M_W),
322#                                                               np.intc),
323#                                                      np.zeros(self.na))
324                Fn_wsg[w][s] += tmp_g
325
326                # Add imag part of AE corrections
327                tmp_g[:] = 0
328                rho2_MM[:] = rho_MM.imag
329                # TODO: use ae_valence_density_correction
330                phi.construct_density(rho2_MM, tmp_g, q=-1)
331                phit.construct_density(-rho2_MM, tmp_g, q=-1)
332#                phi.lfc.ae_valence_density_correction(rho2_MM, tmp_g,
333#                                                      np.zeros(len(phi.M_W),
334#                                                               np.intc),
335#                                                      np.zeros(self.na))
336#                phit.lfc.ae_valence_density_correction(-rho2_MM, tmp_g,
337#                                                      np.zeros(len(phi.M_W),
338#                                                               np.intc),
339#                                                      np.zeros(self.na))
340                Fn_wsg[w][s] += 1.0j * tmp_g
341
342        return Fn_wsg, gd
343
344    def get_induced_density(self, from_density, gridrefinement):
345        # Return charge density (electrons = negative charge)
346        if from_density == 'pseudo':
347            Fn_wsg, gd = self.interpolate_pseudo_density(gridrefinement)
348            Frho_wg = - Fn_wsg.sum(axis=1)
349            return Frho_wg, gd
350        elif from_density == 'comp':
351            Frho_wg, gd = self.comp_charge_correction(gridrefinement)
352            Frho_wg = - Frho_wg
353            return Frho_wg, gd
354        elif from_density == 'ae':
355            Fn_wsg, gd = self.paw_corrections(gridrefinement)
356            Frho_wg = - Fn_wsg.sum(axis=1)
357            return Frho_wg, gd
358        else:
359            raise RuntimeError('unknown from_density "' + from_density + '"')
360
361    def _read(self, reader, reads):
362        BaseInducedField._read(self, reader, reads)
363
364        r = reader
365        time = r.time
366        if self.has_paw:
367            # Test time
368            if abs(time - self.time) >= 1e-9:
369                raise IOError('Timestamp is incompatible with calculator.')
370        else:
371            self.time = time
372
373        # Allocate
374        self.allocate()
375
376        # Dimensions for D_p for all atoms
377        self.np_a = r.np_a
378
379        def readarray(name):
380            if name.split('_')[0] in reads:
381                self.gd.distribute(r.get(name), getattr(self, name))
382
383        # Read arrays
384        readarray('n0t_sG')
385        readarray('Fnt_wsG')
386
387        if 'D0' in reads:
388            D0_asp = r.D0_asp
389            self.D0_asp = {}
390            for a in range(self.na):
391                if self.domain_comm.rank == self.rank_a[a]:
392                    self.D0_asp[a] = D0_asp[a]
393
394        if 'FD' in reads:
395            FD_awsp = r.FD_awsp
396            self.FD_awsp = {}
397            for a in range(self.na):
398                if self.domain_comm.rank == self.rank_a[a]:
399                    self.FD_awsp[a] = FD_awsp[a]
400
401    def _write(self, writer, writes):
402        BaseInducedField._write(self, writer, writes)
403
404        # Collect np_a to master
405        if self.kpt_comm.rank == 0 and self.band_comm.rank == 0:
406
407            # Create empty dict on domain master
408            if self.domain_comm.rank == 0:
409                np_a = {}
410                for a in range(self.na):
411                    np_a[a] = np.empty(1, dtype=int)
412            else:
413                np_a = {}
414            # Collect dict to master
415            sendreceive_dict(self.domain_comm, np_a, 0,
416                             self.np_a, self.rank_a, range(self.na))
417
418        # Write time propagation status
419        writer.write(time=self.time, np_a=np_a)
420
421        def writearray(name, shape, dtype):
422            if name.split('_')[0] in writes:
423                writer.add_array(name, shape, dtype)
424            a_wxg = getattr(self, name)
425            for w in range(self.nw):
426                writer.fill(self.gd.collect(a_wxg[w]))
427
428        ng = tuple(self.gd.get_size_of_global_array())
429
430        # Write time propagation arrays
431        if 'n0t' in writes:
432            writer.write(n0t_sG=self.gd.collect(self.n0t_sG))
433        writearray('Fnt_wsG', (self.nw, self.nspins) + ng, self.dtype)
434
435        if 'D0' in writes:
436            # Collect D0_asp to world master
437            if self.kpt_comm.rank == 0 and self.band_comm.rank == 0:
438                # Create empty dict on domain master
439                if self.domain_comm.rank == 0:
440                    D0_asp = {}
441                    for a in range(self.na):
442                        npa = np_a[a]
443                        D0_asp[a] = np.empty((self.nspins, npa[0]),
444                                             dtype=float)
445                else:
446                    D0_asp = {}
447                # Collect dict to master
448                sendreceive_dict(self.domain_comm, D0_asp, 0,
449                                 self.D0_asp, self.rank_a, range(self.na))
450            # Write
451            writer.write(D0_asp=D0_asp)
452
453        if 'FD' in writes:
454            # Collect FD_awsp to world master
455            if self.kpt_comm.rank == 0 and self.band_comm.rank == 0:
456                # Create empty dict on domain master
457                if self.domain_comm.rank == 0:
458                    FD_awsp = {}
459                    for a in range(self.na):
460                        npa = np_a[a]
461                        FD_awsp[a] = np.empty((self.nw, self.nspins, npa[0]),
462                                              dtype=complex)
463                else:
464                    FD_awsp = {}
465                # Collect dict to master
466                sendreceive_dict(self.domain_comm, FD_awsp, 0,
467                                 self.FD_awsp, self.rank_a, range(self.na))
468            # Write
469            writer.write(FD_awsp=FD_awsp)
470