1import numpy as np
2from ase.units import Bohr
3
4from gpaw.fd_operators import Laplace, Gradient
5from gpaw.kpoint import KPoint
6from gpaw.kpt_descriptor import KPointDescriptor
7from gpaw.lfc import LocalizedFunctionsCollection as LFC
8from gpaw.mpi import serial_comm
9from gpaw.preconditioner import Preconditioner
10from gpaw.projections import Projections
11from gpaw.transformers import Transformer
12from gpaw.utilities.blas import axpy
13from gpaw.wavefunctions.arrays import UniformGridWaveFunctions
14from gpaw.wavefunctions.fdpw import FDPWWaveFunctions
15from gpaw.wavefunctions.mode import Mode
16import _gpaw
17
18
class FD(Mode):
    """Finite-difference mode: wave functions on a uniform real-space grid.

    Parameters
    ----------
    nn : int
        Stencil parameter handed to ``FDWaveFunctions`` as its first
        positional argument (used there for the kinetic-energy Laplacian).
    interpolation : int
        Stored and reported by ``todict``; not used elsewhere in this class.
    force_complex_dtype : bool
        Forwarded unchanged to ``Mode.__init__``.
    """

    name = 'fd'

    def __init__(self, nn=3, interpolation=3, force_complex_dtype=False):
        self.nn, self.interpolation = nn, interpolation
        Mode.__init__(self, force_complex_dtype)

    def __call__(self, *args, **kwargs):
        """Create an ``FDWaveFunctions`` object with this mode's stencil."""
        stencil = self.nn
        return FDWaveFunctions(stencil, *args, **kwargs)

    def todict(self):
        """Return ``Mode.todict()`` extended with the FD-specific settings."""
        dct = Mode.todict(self)
        dct.update(nn=self.nn, interpolation=self.interpolation)
        return dct
35
36
class FDWaveFunctions(FDPWWaveFunctions):
    """Wave functions stored on the uniform real-space grid ``self.gd``.

    The kinetic-energy operator is a finite-difference Laplacian
    (``self.kin``).  The PAW projector functions are a
    ``LocalizedFunctionsCollection`` (``self.pt``) on the same grid.
    """

    mode = 'fd'

    def __init__(self, stencil, parallel, initksl,
                 gd, nvalence, setups, bd,
                 dtype, world, kd, kptband_comm, timer, reuse_wfs_method=None,
                 collinear=True):
        # ``stencil`` is the finite-difference stencil parameter
        # (``FD.nn``).  All other arguments go to the base class.
        FDPWWaveFunctions.__init__(self, parallel, initksl,
                                   reuse_wfs_method=reuse_wfs_method,
                                   collinear=collinear,
                                   gd=gd, nvalence=nvalence, setups=setups,
                                   bd=bd, dtype=dtype, world=world, kd=kd,
                                   kptband_comm=kptband_comm, timer=timer)

        # Kinetic energy operator: -1/2 times the finite-difference Laplacian.
        self.kin = Laplace(self.gd, -0.5, stencil, self.dtype)

        # Gradient operators (one per direction v = 0, 1, 2) for the MGGA
        # kinetic-energy density.  Built lazily by _get_taugrad_operators().
        self.taugrad_v = None

    def empty(self, n=(), global_array=False, realspace=False, q=-1):
        """Return an uninitialized array from ``self.gd.empty``.

        ``realspace`` and ``q`` are ignored here (FD wave functions already
        live on the real-space grid).  They are presumably accepted so the
        signature matches other wave-function modes.
        """
        return self.gd.empty(n, self.dtype, global_array)

    def integrate(self, a_xg, b_yg=None, global_integral=True):
        """Integrate over the grid; thin wrapper around ``gd.integrate``."""
        return self.gd.integrate(a_xg, b_yg, global_integral)

    def bytes_per_wave_function(self):
        """Return ``self.gd.bytecount`` for the wave-function dtype."""
        return self.gd.bytecount(self.dtype)

    def set_setups(self, setups):
        """Build the projector LFC from ``setups``, then call the base class."""
        self.pt = LFC(self.gd, [setup.pt_j for setup in setups],
                      self.kd, dtype=self.dtype, forces=True)
        FDPWWaveFunctions.set_setups(self, setups)

    def set_positions(self, spos_ac, atom_partition=None):
        """Delegate to ``FDPWWaveFunctions.set_positions``."""
        FDPWWaveFunctions.set_positions(self, spos_ac, atom_partition)

    def __str__(self):
        s = 'Wave functions: Uniform real-space grid\n'
        s += '  Kinetic energy operator: %s\n' % self.kin.description
        return s + FDPWWaveFunctions.__str__(self)

    def make_preconditioner(self, block=1):
        """Return a ``Preconditioner`` built on the kinetic operator."""
        return Preconditioner(self.gd, self.kin, self.dtype, block)

    def apply_pseudo_hamiltonian(self, kpt, ham, psit_xG, Htpsit_xG):
        """Write ``H~ psit_xG`` into ``Htpsit_xG``.

        Applies the kinetic operator, then the local potential of spin
        ``kpt.s``, then any orbital-dependent XC term.
        """
        self.timer.start('Apply hamiltonian')
        self.kin.apply(psit_xG, Htpsit_xG, kpt.phase_cd)
        ham.apply_local_potential(psit_xG, Htpsit_xG, kpt.s)
        ham.xc.apply_orbital_dependent_hamiltonian(
            kpt, psit_xG, Htpsit_xG, ham.dH_asp)
        self.timer.stop('Apply hamiltonian')

    def get_pseudo_partial_waves(self):
        """Return an LFC of the setups' atomic-orbital partial waves."""
        phit_aj = [setup.get_partial_waves_for_atomic_orbitals()
                   for setup in self.setups]
        return LFC(self.gd, phit_aj, kd=self.kd, cut=True, dtype=self.dtype)

    def add_to_density_from_k_point_with_occupation(self, nt_sG, kpt, f_n):
        """Add ``sum_n f_n |psit_n|^2`` of ``kpt`` to ``nt_sG[kpt.s]``.

        ``nt_sG`` is modified in place.  If ``kpt`` carries delta-SCF
        expansion coefficients (``c_on``/``ne_o``), the corresponding
        orbital density is added as well.
        """
        # Used in calculation of response part of GLLB-potential
        nt_G = nt_sG[kpt.s]
        for f, psit_G in zip(f_n, kpt.psit_nG):
            # Same as nt_G += f * abs(psit_G)**2, but much faster:
            _gpaw.add_to_density(f, psit_G, nt_G)

        # Hack used in delta-scf calculations:
        if hasattr(kpt, 'c_on'):
            assert self.bd.comm.size == 1
            # d_nn = sum_o ne_o * outer(conj(c_on), c_on)
            d_nn = np.zeros((self.bd.mynbands, self.bd.mynbands),
                            dtype=complex)
            for ne, c_n in zip(kpt.ne_o, kpt.c_on):
                d_nn += ne * np.outer(c_n.conj(), c_n)
            for d_n, psi0_G in zip(d_nn, kpt.psit_nG):
                for d, psi_G in zip(d_n, kpt.psit_nG):
                    if abs(d) > 1.e-12:
                        nt_G += (psi0_G.conj() * d * psi_G).real

    def _get_taugrad_operators(self):
        """Return the three gradient operators, creating them on first use."""
        if self.taugrad_v is None:
            self.taugrad_v = [
                Gradient(self.gd, v, n=3, dtype=self.dtype).apply
                for v in range(3)]
        return self.taugrad_v

    def calculate_kinetic_energy_density(self):
        """Return ``taut_sG[s] = 1/2 sum_kn f_kn sum_v |d_v psit_kn|^2``.

        The result is summed over ``kptband_comm``.  Returns None when the
        wave functions are not plain NumPy arrays (presumably e.g. the
        file-backed proxies created by ``read``).
        """
        taugrad_v = self._get_taugrad_operators()

        assert not hasattr(self.kpt_u[0], 'c_on')
        if not isinstance(self.kpt_u[0].psit_nG, np.ndarray):
            return None

        taut_sG = self.gd.zeros(self.nspins)
        dpsit_G = self.gd.empty(dtype=self.dtype)
        for kpt in self.kpt_u:
            for f, psit_G in zip(kpt.f_n, kpt.psit_nG):
                for v in range(3):
                    taugrad_v[v](psit_G, dpsit_G, kpt.phase_cd)
                    axpy(0.5 * f, abs(dpsit_G)**2, taut_sG[kpt.s])

        self.kptband_comm.sum(taut_sG)
        return taut_sG

    def apply_mgga_orbital_dependent_hamiltonian(self, kpt, psit_xG,
                                                 Htpsit_xG, dH_asp,
                                                 dedtaut_G):
        """Add the MGGA term ``-1/2 sum_v D_v(dedtaut_G D_v psit)``.

        The term is added to each ``Htpsit_G`` in place.  ``D_v`` is the
        gradient operator for direction ``v``.  ``dH_asp`` is accepted
        but not used here.
        """
        taugrad_v = self._get_taugrad_operators()
        a_G = self.gd.empty(dtype=psit_xG.dtype)
        for psit_G, Htpsit_G in zip(psit_xG, Htpsit_xG):
            for v in range(3):
                taugrad_v[v](psit_G, a_G, kpt.phase_cd)
                taugrad_v[v](dedtaut_G * a_G, a_G, kpt.phase_cd)
                axpy(-0.5, a_G, Htpsit_G)

    def ibz2bz(self, atoms):
        """Transform wave functions in IBZ to the full BZ.

        Requires ``self.kd.comm.size == 1``.  Replaces ``self.kd``,
        ``self.pt``, ``self.kpt_qs`` and ``self.kpt_u`` with full-BZ
        versions.  Each BZ point gets a copy of its IBZ partner's
        eigenvalues, rescaled occupations, symmetry-transformed wave
        functions and recomputed PAW projections.
        """

        assert self.kd.comm.size == 1

        # New k-point descriptor for full BZ:
        kd = KPointDescriptor(self.kd.bzk_kc, nspins=self.nspins)
        kd.set_communicator(serial_comm)

        self.pt = LFC(self.gd, [setup.pt_j for setup in self.setups],
                      kd, dtype=self.dtype)
        self.pt.set_positions(atoms.get_scaled_positions())

        self.initialize_wave_functions_from_restart_file()

        weight = 2.0 / kd.nspins / kd.nbzkpts

        # Number of projectors per atom; the same for every new k-point:
        nproj_a = [setup.ni for setup in self.setups]

        # Build new list of k-points:
        kpt_qs = []
        kpt_u = []
        for k in range(kd.nbzkpts):
            # Index of symmetry related point in the IBZ (spin independent):
            ik = self.kd.bz2ibz_k[k]
            r, q = self.kd.get_rank_and_index(ik)
            assert r == 0

            kpt_s = []
            for s in range(self.nspins):
                kpt = self.kpt_qs[q][s]

                phase_cd = np.exp(2j * np.pi * self.gd.sdisp_cd *
                                  kd.bzk_kc[k, :, np.newaxis])

                # New k-point:
                kpt2 = KPoint(1.0 / kd.nbzkpts, weight, s, k, k, phase_cd)
                kpt2.f_n = kpt.f_n / kpt.weight / kd.nbzkpts * 2 / self.nspins
                kpt2.eps_n = kpt.eps_n.copy()

                # Transform wave functions using symmetry operation:
                Psit_nG = self.gd.collect(kpt.psit_nG)
                if Psit_nG is not None:
                    Psit_nG = Psit_nG.copy()
                    for Psit_G in Psit_nG:
                        Psit_G[:] = self.kd.transform_wave_function(Psit_G, k)
                kpt2.psit = UniformGridWaveFunctions(
                    self.bd.nbands, self.gd, self.dtype,
                    kpt=k, dist=(self.bd.comm, self.bd.comm.size),
                    spin=kpt.s, collinear=True)
                self.gd.distribute(Psit_nG, kpt2.psit_nG)

                # Calculate PAW projections:
                kpt2.projections = Projections(
                    self.bd.nbands, nproj_a,
                    kpt.projections.atom_partition,
                    self.bd.comm,
                    collinear=True, spin=s, dtype=self.dtype)

                kpt2.psit.matrix_elements(self.pt, out=kpt2.projections)
                kpt_s.append(kpt2)
                kpt_u.append(kpt2)
            kpt_qs.append(kpt_s)

        self.kd = kd
        self.kpt_qs = kpt_qs
        self.kpt_u = kpt_u

    def _get_wave_function_array(self, u, n, realspace=True, periodic=False):
        """Return band ``n`` of local k-point ``u`` as a grid array.

        Only ``realspace=True`` is supported.  With ``periodic=True`` and
        complex dtype, the result is multiplied by ``gd.plane_wave(-k_c)``
        (presumably giving the cell-periodic part of the Bloch function).
        """
        assert realspace
        kpt = self.kpt_u[u]
        psit_G = kpt.psit_nG[n]
        if periodic and self.dtype == complex:
            k_c = self.kd.ibzk_kc[kpt.k]
            return self.gd.plane_wave(-k_c) * psit_G
        return psit_G

    def write(self, writer, write_wave_functions=False):
        """Write base-class data and, optionally, the wave functions.

        The wave functions go to an array ``'values'`` of shape
        ``(nspins, nibzkpts, nbands) + global grid shape``.  It is filled one
        band at a time and scaled by ``Bohr**-1.5`` (presumably converting
        atomic units to Angstrom**-1.5).
        """
        FDPWWaveFunctions.write(self, writer)

        if not write_wave_functions:
            return

        writer.add_array(
            'values',
            (self.nspins, self.kd.nibzkpts, self.bd.nbands) +
            tuple(self.gd.get_size_of_global_array()),
            self.dtype)

        for s in range(self.nspins):
            for k in range(self.kd.nibzkpts):
                for n in range(self.bd.nbands):
                    psit_G = self.get_wave_function_array(n, k, s)
                    writer.fill(psit_G * Bohr**-1.5)

    def read(self, reader):
        """Read base-class data and, if present, the stored wave functions.

        The wave functions are attached as file proxies scaled by
        ``reader.bohr**1.5`` (by 1 for files with a negative version).
        They are loaded into memory only when running on more than one rank.
        """
        FDPWWaveFunctions.read(self, reader)

        if 'values' not in reader.wave_functions:
            return

        c = reader.bohr**1.5
        if reader.version < 0:
            c = 1  # old gpw file

        for kpt in self.kpt_u:
            # We may not be able to keep all the wave
            # functions in memory - so psit_nG will be a special type of
            # array that is really just a reference to a file:
            psit_nG = reader.wave_functions.proxy('values', kpt.s, kpt.k)
            psit_nG.scale = c

            kpt.psit = UniformGridWaveFunctions(
                self.bd.nbands, self.gd, self.dtype, psit_nG,
                kpt=kpt.q, dist=(self.bd.comm, self.bd.comm.size),
                spin=kpt.s, collinear=True)

        if self.world.size > 1:
            # Read to memory:
            for kpt in self.kpt_u:
                kpt.psit.read_from_file()

    def initialize_from_lcao_coefficients(self, basis_functions):
        """Project each k-point's LCAO coefficients ``C_nM`` onto the grid.

        Bands beyond ``len(kpt.C_nM)`` are zeroed.  ``kpt.C_nM`` is released
        (set to None) afterwards.
        """
        for kpt in self.kpt_u:
            kpt.psit = UniformGridWaveFunctions(
                self.bd.nbands, self.gd, self.dtype, kpt=kpt.q,
                dist=(self.bd.comm, self.bd.comm.size, 1),
                spin=kpt.s, collinear=True)
            kpt.psit_nG[:] = 0.0
            mynbands = len(kpt.C_nM)
            basis_functions.lcao_to_grid(kpt.C_nM,
                                         kpt.psit_nG[:mynbands], kpt.q)
            kpt.C_nM = None

    def random_wave_functions(self, nao):
        """Fill bands ``nao:`` of every k-point with random wave functions.

        Uniform random numbers in ``[-scale/2, scale/2)`` are drawn on a
        grid and interpolated up to ``self.gd``.  The grid is coarsened twice
        when ``nbands < gpts / 64``, once when ``nbands < gpts / 8``, and
        not at all otherwise.  NumPy's global generator is seeded with
        ``4 + world.rank`` for reproducibility.  Its previous state is
        restored afterwards, also when an exception is raised.
        """
        gpts = self.gd.N_c[0] * self.gd.N_c[1] * self.gd.N_c[2]

        if self.bd.nbands < gpts / 64:
            ncoarsen = 2
        elif self.bd.nbands < gpts / 8:
            ncoarsen = 1
        else:
            ncoarsen = 0

        # Grids ordered from finest (self.gd) to coarsest:
        gd_i = [self.gd]
        for _ in range(ncoarsen):
            gd_i.append(gd_i[-1].coarsen())
        coarse_gd = gd_i[-1]

        # Work arrays on the coarsened grids gd_i[1:]:
        work_iG = [gd.empty(dtype=self.dtype) for gd in gd_i[1:]]

        # interpolate_i[i] maps an array on gd_i[i + 1] onto gd_i[i]:
        interpolate_i = [Transformer(gd_i[i + 1], gd_i[i], 1,
                                     self.dtype).apply
                         for i in range(ncoarsen)]

        shape = tuple(coarse_gd.n_c)
        # A uniform deviate on [-1/2, 1/2) has variance 1/12; this scale is
        # presumably chosen to give a norm of order one on the coarse grid.
        scale = np.sqrt(12 / abs(np.linalg.det(coarse_gd.cell_cv)))

        old_state = np.random.get_state()
        try:
            np.random.seed(4 + self.world.rank)
            for kpt in self.kpt_u:
                for psit_G in kpt.psit_nG[nao:]:
                    # a_iG[i] is the array living on gd_i[i]:
                    a_iG = [psit_G] + work_iG
                    a_G = a_iG[-1]
                    if self.dtype == float:
                        a_G[:] = (np.random.random(shape) - 0.5) * scale
                    else:
                        a_G.real = (np.random.random(shape) - 0.5) * scale
                        a_G.imag = (np.random.random(shape) - 0.5) * scale

                    # Interpolate from the coarsest grid up to self.gd:
                    for i in reversed(range(ncoarsen)):
                        interpolate_i[i](a_iG[i + 1], a_iG[i], kpt.phase_cd)
        finally:
            np.random.set_state(old_state)

    def estimate_memory(self, mem):
        """Delegate to ``FDPWWaveFunctions.estimate_memory``."""
        FDPWWaveFunctions.estimate_memory(self, mem)
354