1"""This module provides an interface class for phonon calculations."""
2
# Public API of this module.
__all__ = ["PhononCalculator"]
4
5import time
6import os
7from math import pi, sqrt, sin
8
9import numpy as np
10import numpy.linalg as la
11
12import ase.units as units
13from ase.io import Trajectory
14
15from gpaw import GPAW
16from gpaw.mpi import serial_comm, rank
17from gpaw.kpt_descriptor import KPointDescriptor
18from gpaw.poisson import PoissonSolver, FFTPoissonSolver
19from gpaw.dfpt.responsecalculator import ResponseCalculator
20from gpaw.dfpt.phononperturbation import PhononPerturbation
21from gpaw.dfpt.wavefunctions import WaveFunctions
22from gpaw.dfpt.dynamicalmatrix import DynamicalMatrix
23from gpaw.dfpt.electronphononcoupling import ElectronPhononCoupling
24
25from gpaw.symmetry import Symmetry
26
27
class PhononCalculator:
    """This class defines the interface for phonon calculations."""

    def __init__(self, calc, gamma=True, symmetry=False, e_ph=False,
                 communicator=serial_comm):
        """Initialize the phonon calculator from a ground-state calculation.

        The calculator must contain a converged ground-state calculation.

        The set of q-vectors in which the dynamical matrix will be calculated
        is determined from the ``symmetry`` kwarg. For now, only time-reversal
        symmetry is used to generate the irreducible BZ.

        Parallelization: ``communicator`` groups the ranks that cooperate on a
        single (q-vector, atom, cartesian component) task in ``run``. A task
        is claimed by creating its force constant file; rank 0 of the
        communicator creates and writes that file.

        Parameters
        ----------
        calc: str or Calculator
            Calculator containing a ground-state calculation, or a filename
            from which a GPAW calculator is restored (serial communicator).
        gamma: bool
            Gamma-point calculation with respect to the q-vector of the
            dynamical matrix. When ``False``, the k-points of the
            ground-state calculation are used. Systems without any periodic
            boundary conditions are always treated as Gamma-point.
        symmetry: None or bool
            Use symmetries to reduce the q-vectors of the dynamical matrix.
            ``None`` disables time-reversal symmetry, ``False`` enables it;
            ``True`` (spatial symmetries) is not supported yet. The options
            are equivalent to the old style options in a ground-state
            calculation (see usesymm).
        e_ph: bool
            Save the derivative of the effective potential.
        communicator: Communicator
            Communicator for parallelization over k-points and real-space
            domain.
        """

        # XXX
        assert symmetry in [None, False], "Spatial symmetries not allowed yet"

        if isinstance(calc, str):
            self.calc = GPAW(calc, communicator=serial_comm, txt=None)
        else:
            self.calc = calc

        cell_cv = self.calc.atoms.get_cell()
        setups = self.calc.wfs.setups
        # XXX - no clue how to get magmom - ignore it for the moment
        # m_av = magmom_av.round(decimals=3)  # round off
        # id_a = zip(setups.id_a, *m_av.T)
        id_a = setups.id_a

        if symmetry is None:
            self.symmetry = Symmetry(id_a, cell_cv, point_group=False,
                                     time_reversal=False)
        else:
            self.symmetry = Symmetry(id_a, cell_cv, point_group=False,
                                     time_reversal=True)

        # Make sure localized functions are initialized
        self.calc.set_positions()
        # Note that this under some circumstances (e.g. when called twice)
        # allocates a new array for the P_ani coefficients !!

        # Store useful objects
        self.atoms = self.calc.get_atoms()
        # Get rid of ``calc`` attribute
        self.atoms.calc = None

        # Boundary conditions
        pbc_c = self.calc.atoms.get_pbc()

        if not pbc_c.any():
            self.gamma = True
            self.dtype = float
            kpts = None
            # Multigrid Poisson solver
            poisson_solver = PoissonSolver('fd')
        else:
            if gamma:
                self.gamma = True
                self.dtype = float
                kpts = None
            else:
                self.gamma = False
                self.dtype = complex
                # Get k-points from ground-state calculation
                kpts = self.calc.input_parameters.kpts

            # FFT Poisson solver
            poisson_solver = FFTPoissonSolver(dtype=self.dtype)

        # K-point descriptor for the q-vectors of the dynamical matrix
        # Note, no explicit parallelization here.
        self.kd = KPointDescriptor(kpts, 1)
        self.kd.set_symmetry(self.atoms, self.symmetry)
        self.kd.set_communicator(serial_comm)

        # Number of occupied bands (valence electrons / 2, rounded up)
        nvalence = self.calc.wfs.nvalence
        nbands = nvalence // 2 + nvalence % 2
        assert nbands <= self.calc.wfs.bd.nbands

        # Extract other useful objects
        # Ground-state k-point descriptor - used for the k-points in the
        # ResponseCalculator
        # XXX replace communicators when ready to parallelize
        kd_gs = self.calc.wfs.kd
        gd = self.calc.density.gd
        kpt_u = self.calc.wfs.kpt_u
        # Looked up again after ``set_positions`` above
        setups = self.calc.wfs.setups
        dtype_gs = self.calc.wfs.dtype

        # WaveFunctions
        wfs = WaveFunctions(nbands, kpt_u, setups, kd_gs, gd, dtype=dtype_gs)

        # Linear response calculator
        self.response_calc = ResponseCalculator(self.calc, wfs,
                                                dtype=self.dtype)

        # Phonon perturbation
        self.perturbation = PhononPerturbation(self.calc, self.kd,
                                               poisson_solver,
                                               dtype=self.dtype)

        # Dynamical matrix
        self.dyn = DynamicalMatrix(self.atoms, self.kd, dtype=self.dtype)

        # Electron-phonon couplings
        if e_ph:
            self.e_ph = ElectronPhononCoupling(self.atoms, gd, self.kd,
                                               dtype=self.dtype)
        else:
            self.e_ph = None

        # Initialization flag
        self.initialized = False

        # Parallel communicator for parallelization over kpts and domain
        self.comm = communicator

    def initialize(self):
        """Initialize response calculator and perturbation."""
        self.perturbation.initialize(self.calc.spos_ac)
        self.response_calc.initialize(self.calc.spos_ac)
        self.initialized = True

    def __getstate__(self):
        """Method used when pickling.

        Bound method attributes cannot be pickled and must therefore be deleted
        before an instance is dumped to file.

        """

        # Get state of object and take care of troublesome attributes.
        # NOTE: this resets ``comm`` on the live ``kd`` object, not a copy.
        state = dict(self.__dict__)
        state['kd'].__dict__['comm'] = serial_comm
        state.pop('calc')
        state.pop('perturbation')
        state.pop('response_calc')

        return state

    def run(self, qpts_q=None, clean=False, name=None, path=None):
        """Run calculation for atomic displacements and update matrix.

        Each (q-vector, atom, cartesian component) task writes its row of
        force constants to its own file. Tasks whose file already exists are
        skipped, so interrupted calculations can be restarted.

        Parameters
        ----------
        qpts_q: list
            List of q-point indices for which the dynamical matrix will be
            calculated (only temporary). Ignored for Gamma-point
            calculations; defaults to all irreducible q-points.
        clean: bool
            Delete the force constant files after the calculation.
        name: str
            Base name of the force constant files (see
            ``set_name_and_path``).
        path: str
            Directory of the force constant files (see
            ``set_name_and_path``).

        """

        if not self.initialized:
            self.initialize()

        if self.gamma:
            qpts_q = [0]
        elif qpts_q is None:
            qpts_q = range(self.kd.nibzkpts)
        else:
            assert isinstance(qpts_q, list)

        # Update name and path attributes
        self.set_name_and_path(name=name, path=path)
        # Get string template for filenames
        filename_str = self.get_filename_string()

        # Delay the ranks belonging to the same k-point/domain decomposition
        # equally
        time.sleep(rank // self.comm.size)

        # XXX Make a single ground_state_contributions member function
        # Ground-state contributions to the force constants
        self.dyn.density_ground_state(self.calc)
        # self.dyn.wfs_ground_state(self.calc, self.response_calc)

        components = ['x', 'y', 'z']
        symbols = self.atoms.get_chemical_symbols()

        # Calculate linear response wrt q-vectors and displacements of atoms
        for q in qpts_q:

            if not self.gamma:
                self.perturbation.set_q(q)

            # First-order contributions to the force constants
            for a in self.dyn.indices:
                for v in [0, 1, 2]:

                    filepath = os.path.join(self.path,
                                            filename_str % (q, a, v))
                    # Wait for all sub-ranks to enter
                    self.comm.barrier()

                    # Only rank 0 checks whether the calculation has already
                    # been done (its file exists), and claims it otherwise by
                    # creating the file. The decision is broadcast so that all
                    # sub-ranks agree: if every rank checked for itself, a
                    # sub-rank could see the file rank 0 has just created,
                    # skip the task and desynchronize the barriers below.
                    skip = np.zeros(1, dtype=int)
                    if self.comm.rank == 0:
                        skip[0] = os.path.isfile(filepath)
                        if not skip[0]:
                            fd = open(filepath, 'w')
                    self.comm.broadcast(skip, 0)
                    if skip[0]:
                        continue

                    # Wait for all sub-ranks here
                    self.comm.barrier()

                    print("q-vector index: %i" % q)
                    print("Atom index: %i" % a)
                    print("Atomic symbol: %s" % symbols[a])
                    print("Component: %s" % components[v])

                    try:
                        # Set atom and cartesian component of perturbation
                        self.perturbation.set_av(a, v)
                        # Calculate linear response
                        self.response_calc(self.perturbation)

                        # Calculate row of the matrix of force constants
                        self.dyn.calculate_row(self.perturbation,
                                               self.response_calc)

                        # Write force constants to file
                        if self.comm.rank == 0:
                            self.dyn.write(fd, q, a, v)
                    except BaseException:
                        # Do not leave an empty or partially written file
                        # behind: it would make later runs skip this task.
                        if self.comm.rank == 0:
                            fd.close()
                            os.remove(filepath)
                        raise

                    if self.comm.rank == 0:
                        fd.close()

                    # Store effective potential derivative
                    if self.e_ph is not None:
                        v1_eff_G = self.perturbation.v1_G + \
                            self.response_calc.vHXC1_G
                        self.e_ph.v1_eff_qavG.append(v1_eff_G)

                    # Wait for the file-writing rank here
                    self.comm.barrier()

        # XXX
        # Check that all files are valid and collect in a single file
        # Remove the files
        if clean:
            self.clean()

    def get_atoms(self):
        """Return atoms."""

        return self.atoms

    def get_dynamical_matrix(self):
        """Return reference to ``dyn`` attribute."""

        return self.dyn

    def get_filename_string(self):
        """Return string template for force constant filenames.

        The template takes the (q-point index, atom index, component index)
        tuple; q and atom indices are zero-padded to a fixed width.
        """

        name_str = (self.name + '.' +
                    'q_%%0%ii_' % len(str(self.kd.nibzkpts)) +
                    'a_%%0%ii_' % len(str(len(self.atoms))) +
                    'v_%i' + '.pckl')

        return name_str

    def set_atoms(self, atoms):
        """Set atoms to be included in the calculation.

        Parameters
        ----------
        atoms: list
            Either a list of chemical symbols (every atom of those species is
            included) or a list of integer atomic indices.
        """

        assert isinstance(atoms, list)

        if atoms and isinstance(atoms[0], str):
            assert all(isinstance(atom, str) for atom in atoms)
            sym_a = self.atoms.get_chemical_symbols()
            # List for atomic indices
            indices = []
            for symbol in atoms:
                indices.extend([a for a, sym in enumerate(sym_a)
                                if sym == symbol])
        else:
            # Accept numpy integers as well as builtin ints
            assert all(isinstance(atom, (int, np.integer)) for atom in atoms)
            indices = atoms

        self.dyn.set_indices(indices)

    def set_name_and_path(self, name=None, path=None):
        """Set name and path of the force constant files.

        name: str
            Base name for the files which the elements of the matrix of force
            constants will be written to. Defaults to
            ``'phonon.' + <chemical formula>``.
        path: str
            Path specifying the directory where the files will be dumped.
            Defaults to the current directory.
        """

        if name is None:
            self.name = 'phonon.' + self.atoms.get_chemical_formula()
        else:
            self.name = name
        # self.name += '.nibzkpts_%i' % self.kd.nibzkpts

        if path is None:
            self.path = '.'
        else:
            self.path = path

        # Set corresponding attributes in the ``dyn`` attribute
        filename_str = self.get_filename_string()
        self.dyn.set_name_and_path(filename_str, self.path)

    def clean(self):
        """Delete generated force constant files in ``self.path``."""

        filename_str = self.get_filename_string()

        for q in range(self.kd.nibzkpts):
            for a in range(len(self.atoms)):
                for v in [0, 1, 2]:
                    # Remove the file that was checked, i.e. the one inside
                    # ``self.path`` -- not a same-named file in the cwd.
                    filepath = os.path.join(self.path,
                                            filename_str % (q, a, v))
                    if os.path.isfile(filepath):
                        try:
                            os.remove(filepath)
                        except FileNotFoundError:
                            # Already removed concurrently by another rank
                            pass

    def band_structure(self, path_kc, modes=False, acoustic=True):
        """Calculate phonon dispersion along a path in the Brillouin zone.

        The dynamical matrix at arbitrary q-vectors is obtained by Fourier
        transforming the real-space matrix. In case of negative eigenvalues
        (squared frequency), the corresponding negative frequency is returned.

        Parameters
        ----------
        path_kc: ndarray
            List of k-point coordinates (in units of the reciprocal lattice
            vectors) specifying the path in the Brillouin zone for which the
            dynamical matrix will be calculated.
        modes: bool
            Returns both frequencies and modes (mass scaled) when True.
        acoustic: bool
            Restore the acoustic sum-rule in the calculated force constants.

        Returns
        -------
        omega_kn: ndarray
            Frequencies in eV, sorted in increasing order for each q-vector.
        u_kn: ndarray
            Only when ``modes`` is True: mass-scaled modes in Ang, with shape
            (len(path_kc), 3 * N, N, 3) where N is the number of included
            atoms.
        """

        for k_c in path_kc:
            assert np.all(np.asarray(k_c) <= 1.0), \
                "Scaled coordinates must be given"

        # Assemble the dynamical matrix from calculated force constants
        self.dyn.assemble(acoustic=acoustic)
        # Get the dynamical matrix in real-space
        DR_lmn, R_clmn = self.dyn.real_space()

        # Reshape for the evaluation of the fourier sums
        shape = DR_lmn.shape
        DR_m = DR_lmn.reshape((-1,) + shape[-2:])
        R_cm = R_clmn.reshape((3, -1))

        # Lists for frequencies and modes along path
        omega_kn = []
        u_kn = []
        # Number of atoms included
        N = len(self.dyn.get_indices())

        # Mass prefactor for the normal modes
        m_inv_av = self.dyn.get_mass_array()

        for q_c in path_kc:

            # Evaluate fourier transform
            phase_m = np.exp(-2.j * pi * np.dot(q_c, R_cm))
            # Dynamical matrix in unit of Ha / Bohr**2 / amu
            D_q = np.sum(phase_m[:, np.newaxis, np.newaxis] * DR_m, axis=0)

            if modes:
                omega2_n, u_avn = la.eigh(D_q, UPLO='L')
                # Sort eigenmodes according to eigenvalues (see below) and
                # multiply with mass prefactor
                u_nav = u_avn[:, omega2_n.argsort()].T.copy() * m_inv_av
                u_kn.append(u_nav.reshape((3 * N, -1, 3)))
            else:
                omega2_n = la.eigvalsh(D_q, UPLO='L')

            # Sort eigenvalues in increasing order
            omega2_n.sort()
            # Use dtype=complex to handle negative eigenvalues
            omega_n = np.sqrt(omega2_n.astype(complex))

            # Take care of imaginary frequencies
            if not np.all(omega2_n >= 0.):
                indices = np.where(omega2_n < 0)[0]
                print(("WARNING, %i imaginary frequencies at "
                       "q = (% 5.2f, % 5.2f, % 5.2f) ; (omega_q =% 5.3e*i)"
                       % (len(indices), q_c[0], q_c[1], q_c[2],
                          omega_n[indices][0].imag)))

                omega_n[indices] = -1 * np.sqrt(np.abs(omega2_n[indices].real))

            omega_kn.append(omega_n.real)

        # Conversion factor from sqrt(Ha / Bohr**2 / amu) -> eV
        s = units.Hartree**0.5 * units._hbar * 1.e10 / \
            (units._e * units._amu)**(0.5) / units.Bohr
        # Convert to eV and Ang
        omega_kn = s * np.asarray(omega_kn)
        if modes:
            u_kn = np.asarray(u_kn) * units.Bohr
            return omega_kn, u_kn

        return omega_kn

    def write_modes(self, q_c, branches=0, kT=units.kB * 300, repeat=(1, 1, 1),
                    nimages=30, acoustic=True):
        """Write mode to trajectory file.

        The classical equipartition theorem states that each normal mode has
        an average energy::

            <E> = 1/2 * k_B * T = 1/2 * omega^2 * Q^2

                =>

              Q = sqrt(k_B*T) / omega

        at temperature T. Here, Q denotes the normal coordinate of the mode.

        One trajectory file ``<name>.mode.<branch>.traj`` is written per
        branch.

        Parameters
        ----------
        q_c: ndarray
            q-vector of the modes.
        branches: int or list
            Branch index of calculated modes.
        kT: float
            Temperature in units of eV. Determines the amplitude of the atomic
            displacements in the modes.
        repeat: tuple
            Repeat atoms (l, m, n) times in the directions of the lattice
            vectors. Displacements of atoms in repeated cells carry a Bloch
            phase factor given by the q-vector and the cell lattice vector R_m.
        nimages: int
            Number of images in an oscillation.
        acoustic: bool
            Restore the acoustic sum-rule in the calculated force constants.

        """

        if isinstance(branches, (int, np.integer)):
            branch_n = [branches]
        else:
            branch_n = list(branches)

        # Calculate modes
        omega_n, u_n = self.band_structure([q_c], modes=True,
                                           acoustic=acoustic)

        # Repeat atoms
        atoms = self.atoms * repeat
        pos_mav = atoms.positions.copy()
        # Total number of unit cells
        M = np.prod(repeat)

        # Corresponding lattice vectors R_m
        R_cm = np.indices(repeat[::-1]).reshape(3, -1)[::-1]
        # Bloch phase
        phase_m = np.exp(2.j * pi * np.dot(q_c, R_cm))
        phase_ma = phase_m.repeat(len(self.atoms))

        indices = self.dyn.get_indices()

        for n in branch_n:
            omega = omega_n[0, n]
            # Mean displacement at high T ?
            # Scale a copy: scaling the view into ``u_n`` in place would
            # scale a branch listed twice twice.
            u_av = u_n[0, n] * sqrt(kT / abs(omega))

            # The modes are complex; keep the imaginary part until the Bloch
            # phase has been applied and only then take the real part.
            mode_av = np.zeros((len(self.atoms), 3), dtype=complex)
            mode_av[indices] = u_av
            mode_mav = (np.vstack([mode_av] * M) *
                        phase_ma[:, np.newaxis]).real

            traj = Trajectory('%s.mode.%d.traj' % (self.name, n), 'w')

            for x in np.linspace(0, 2 * pi, nimages, endpoint=False):
                # XXX Is it correct to take out the sine component here ?
                atoms.set_positions(pos_mav + sin(x) * mode_mav)
                traj.write(atoms)

            traj.close()