1"""Module defining an eigensolver base-class."""
2from functools import partial
3
4import numpy as np
5from ase.dft.bandgap import _bandgap
6from ase.units import Ha
7from ase.utils.timing import timer
8
9from gpaw.matrix import matrix_matrix_multiply as mmm
10from gpaw.utilities.blas import axpy
11from gpaw.xc.hybrid import HybridXC
12
13
def reshape(a_x, shape):
    """Get an ndarray of size shape from a_x buffer.

    The buffer is flattened, its leading ``prod(shape)`` elements are
    selected and the selection is reshaped to *shape*.  NumPy's ``ravel``
    copies only when it has to, so for a contiguous buffer the result
    shares memory with *a_x*.

    a_x: ndarray
        Work buffer holding at least ``prod(shape)`` elements.
    shape: int or sequence of ints
        Requested shape.  The empty tuple ``()`` gives a 0-d array.

    Raises ValueError (from NumPy) if the buffer is too small for *shape*.
    """
    # np.prod(()) is the *float* 1.0, which is not a valid slice bound;
    # forcing an integer dtype keeps the element count usable as an index
    # also for the scalar shape ().
    size = int(np.prod(shape, dtype=int))
    return a_x.ravel()[:size].reshape(shape)
17
18
class Eigensolver:
    """Base class for iterative eigensolvers.

    The class provides the loop over k-points (*iterate*), the
    convergence weights (*weights*), the residual calculation
    (*calculate_residuals*) and the subspace diagonalization
    (*subspace_diagonalize*).  Concrete solvers implement
    *iterate_one_k_point*.

    The attribute *nbands_converge* is read by *weights* but is not set
    in the constructor; it has to be assigned on the instance before
    *iterate* is called.  Accepted values are an integer, 'occupied' or
    a string of the form 'CBM+<number>'.
    """

    def __init__(self, keep_htpsit=True, blocksize=1):
        """Store the solver options.

        keep_htpsit: bool
            Keep a work array (*Htpsit_nG*) for the Hamiltonian applied
            to the wave functions.  Switched off by *initialize* when
            the band communicator has more than one rank.
        blocksize: int
            Passed on to ``wfs.make_preconditioner`` in *initialize*.
        """
        self.keep_htpsit = keep_htpsit
        self.initialized = False
        self.Htpsit_nG = None
        self.error = np.inf
        self.blocksize = blocksize
        self.orthonormalization_required = True

    def initialize(self, wfs):
        """Pick up communicators and sizes from *wfs* and allocate arrays."""
        self.timer = wfs.timer
        self.world = wfs.world
        self.kpt_comm = wfs.kd.comm
        self.band_comm = wfs.bd.comm
        self.dtype = wfs.dtype
        self.bd = wfs.bd
        self.nbands = wfs.bd.nbands
        self.mynbands = wfs.bd.mynbands

        # Htpsit is not kept when the bands are distributed:
        if wfs.bd.comm.size > 1:
            self.keep_htpsit = False

        if self.keep_htpsit:
            self.Htpsit_nG = np.empty_like(wfs.work_array)

        # Preconditioner for the electronic gradients:
        self.preconditioner = wfs.make_preconditioner(self.blocksize)

        # Make sure every k-point has an eigenvalue array (uninitialized):
        for kpt in wfs.kpt_u:
            if kpt.eps_n is None:
                kpt.eps_n = np.empty(self.mynbands)

        self.initialized = True

    def reset(self):
        """Force *initialize* to be run again on the next *iterate* call."""
        self.initialized = False

    def weights(self, wfs):
        """Calculate convergence weights for all eigenstates.

        Returns an array with one row per k-point in ``wfs.kpt_u`` and
        one column per local band.  The meaning of the weights depends
        on *nbands_converge*:

        integer:
            Bands with global index below *nbands_converge* get the
            k-point weight, the rest get zero.
        'occupied':
            The absolute value of the occupation numbers, or infinity
            if there are no occupation numbers yet.
        'CBM+<number>':
            States with an eigenvalue below the conduction band minimum
            plus <number> (divided by ``ase.units.Ha``) get the k-point
            weight.  All weights are infinity if there are no occupation
            numbers yet or if the highest band is below the cutoff for
            some spin and k-point.

        Raises ValueError for any other value of *nbands_converge*.
        """
        weight_un = np.zeros((len(wfs.kpt_u), self.bd.mynbands))
        nbands_converge = self.nbands_converge

        # NumPy integers are accepted as well as plain Python integers:
        if isinstance(nbands_converge, (int, np.integer)):
            # Converge fixed number of bands:
            n = nbands_converge - self.bd.beg
            if n > 0:
                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[:n] = kpt.weight
        elif nbands_converge == 'occupied':
            # Converge occupied bands:
            for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                if kpt.f_n is None:  # no eigenvalues yet
                    weight_n[:] = np.inf
                else:
                    # Methfessel-Paxton distribution can give negative
                    # occupation numbers - so we take the absolute value:
                    weight_n[:] = np.abs(kpt.f_n)
        else:
            # Converge state with energy up to CBM + delta:
            if not (isinstance(nbands_converge, str) and
                    nbands_converge.startswith('CBM+')):
                # An explicit error instead of an assert, so that the
                # check is not stripped when running with -O:
                raise ValueError(
                    'Bad value for nbands_converge: {!r}'
                    .format(nbands_converge))
            delta = float(nbands_converge[4:]) / Ha

            if wfs.kpt_u[0].f_n is None:
                weight_un[:] = np.inf  # no eigenvalues yet
            else:
                # Collect all eigenvalues and calculate band gap:
                efermi = np.mean(wfs.fermi_levels)
                eps_skn = np.array(
                    [[wfs.collect_eigenvalues(k, spin) - efermi
                      for k in range(wfs.kd.nibzkpts)]
                     for spin in range(wfs.nspins)])
                # Only rank 0 is trusted to hold the collected values;
                # the other ranks allocate a buffer and receive them:
                if wfs.world.rank > 0:
                    eps_skn = np.empty((wfs.nspins,
                                        wfs.kd.nibzkpts,
                                        wfs.bd.nbands))
                wfs.world.broadcast(eps_skn, 0)
                try:
                    # Find bandgap + positions of CBM:
                    gap, _, (s, k, n) = _bandgap(eps_skn,
                                                 spin=None, direct=False)
                except ValueError:
                    gap = 0.0

                if gap == 0.0:
                    # No gap: use the Fermi level as reference.
                    # (s, k, n) are not used in this case.
                    cbm = efermi
                else:
                    # eps_skn is relative to the Fermi level:
                    cbm = efermi + eps_skn[s, k, n]

                ecut = cbm + delta

                for weight_n, kpt in zip(weight_un, wfs.kpt_u):
                    weight_n[kpt.eps_n < ecut] = kpt.weight

                if (eps_skn[:, :, -1] < ecut - efermi).any():
                    # We don't have enough bands!
                    weight_un[:] = np.inf

        return weight_un

    def iterate(self, ham, wfs):
        """Solves eigenvalue problem iteratively

        This method is inherited by the actual eigensolver which should
        implement *iterate_one_k_point* method for a single iteration of
        a single kpoint.

        The errors returned by *iterate_one_k_point* are added up and
        summed over the k-point and band communicators; the result is
        stored in *self.error*.  On exit ``wfs.orthonormalized`` is set
        to True.
        """

        if not self.initialized:
            if isinstance(ham.xc, HybridXC):
                # Use all local bands as one block for hybrid functionals:
                self.blocksize = wfs.bd.mynbands
            self.initialize(wfs)

        weight_un = self.weights(wfs)

        error = 0.0
        for kpt, weights in zip(wfs.kpt_u, weight_un):
            if not wfs.orthonormalized:
                wfs.orthonormalize(kpt)
            e = self.iterate_one_k_point(ham, wfs, kpt, weights)
            error += e
            if self.orthonormalization_required:
                wfs.orthonormalize(kpt)

        wfs.orthonormalized = True
        self.error = self.band_comm.sum(self.kpt_comm.sum(error))

    def iterate_one_k_point(self, ham, wfs, kpt, weights):
        """Implemented in subclasses.

        The signature matches the call made by *iterate*, so that a
        subclass that forgets to override this method gets a
        NotImplementedError and not a TypeError.  The return value is
        the error for *kpt*, which *iterate* adds to the total error.
        """
        raise NotImplementedError

    def calculate_residuals(self, kpt, wfs, ham, psit, P, eps_n,
                            R, C, n_x=None, calculate_change=False):
        """Calculate residual.

        From R=Ht*psit calculate R=H*psit-eps*S*psit.

        *R* is updated in place and *C* is used as work space for the
        projector coefficients.  *n_x* and *calculate_change* are passed
        on unchanged to ``ham.xc.add_correction``.
        """

        # R <- R - eps * psit, one state at a time:
        for R_G, eps, psit_G in zip(R.array, eps_n, psit.array):
            axpy(-eps, psit_G, R_G)

        # C <- dH * P - eps * P * dO for every atom:
        ham.dH(P, out=C)
        for a, I1, I2 in P.indices:
            dS_ii = ham.setups[a].dO_ii
            C.array[..., I1:I2] -= np.dot((P.array[..., I1:I2].T * eps_n).T,
                                          dS_ii)

        ham.xc.add_correction(kpt, psit.array, R.array,
                              {a: P_ni for a, P_ni in P.items()},
                              {a: C_ni for a, C_ni in C.items()},
                              n_x,
                              calculate_change)
        # Add the projector function part to the residuals:
        wfs.pt.add(R.array, {a: C_ni for a, C_ni in C.items()}, kpt.q)

    @timer('Subspace diag')
    def subspace_diagonalize(self, ham, wfs, kpt):
        """Diagonalize the Hamiltonian in the subspace of kpt.psit_nG

        *Htpsit_nG* is a work array of same size as psit_nG which contains
        the local part of the Hamiltonian times psit on exit

        First, the Hamiltonian (defined by *kin*, *vt_sG*, and
        *dH_asp*) is applied to the wave functions, then the *H_nn*
        matrix is calculated and diagonalized, and finally, the wave
        functions (and also Htpsit_nG) are rotated.  Also the
        projections *P_ani* are rotated.

        It is assumed that the wave functions *psit_nG* are orthonormal
        and that the integrals of projector functions and wave functions
        *P_ani* are already calculated.

        The results are stored in place: ``kpt.psit``,
        ``kpt.projections`` and ``kpt.eps_n`` are updated and, if
        self.keep_htpsit is True, H applied to the rotated wave
        functions is left in *self.Htpsit_nG*.  Nothing is returned.

        Raises NotImplementedError for strided band distribution over
        more than one rank.
        """

        if self.band_comm.size > 1 and wfs.bd.strided:
            raise NotImplementedError

        psit = kpt.psit
        tmp = psit.new(buf=wfs.work_array)
        H = wfs.work_matrix_nn
        P2 = kpt.projections.new()

        Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

        with self.timer('calc_h_matrix'):
            # We calculate the complex conjugate of H, because
            # that is what is most efficient with BLAS given the layout of
            # our matrices.
            psit.matrix_elements(operator=Ht, result=tmp, out=H,
                                 symmetric=True, cc=True)
            ham.dH(kpt.projections, out=P2)
            mmm(1.0, kpt.projections, 'N', P2, 'C', 1.0, H, symmetric=True)
            ham.xc.correct_hamiltonian_matrix(kpt, H.array)

        with wfs.timer('diagonalize'):
            slcomm, r, c, b = wfs.scalapack_parameters
            if r == c == 1:
                slcomm = None
            # Complex conjugate before diagonalizing:
            eps_n = H.eigh(cc=True, scalapack=(slcomm, r, c, b))
            # H.array[n, :] now contains the n'th eigenvector and eps_n[n]
            # the n'th eigenvalue
            kpt.eps_n = eps_n[wfs.bd.get_slice()]

        with self.timer('rotate_psi'):
            if self.keep_htpsit:
                # tmp still holds H applied to the unrotated psit here,
                # so this has to be done before tmp is overwritten:
                Htpsit = psit.new(buf=self.Htpsit_nG)
                mmm(1.0, H, 'N', tmp, 'N', 0.0, Htpsit)
            mmm(1.0, H, 'N', psit, 'N', 0.0, tmp)
            psit[:] = tmp
            mmm(1.0, H, 'N', kpt.projections, 'N', 0.0, P2)
            kpt.projections.matrix = P2.matrix
            # Rotate orbital dependent XC stuff:
            ham.xc.rotate(kpt, H.array.T)

    def estimate_memory(self, mem, wfs):
        """Add the memory estimates for this solver as subnodes of *mem*."""
        gridmem = wfs.bytes_per_wave_function()

        # Htpsit is only counted when this rank holds all the bands:
        keep_htpsit = self.keep_htpsit and (wfs.bd.mynbands == wfs.bd.nbands)

        if keep_htpsit:
            mem.subnode('Htpsit', wfs.bd.nbands * gridmem)
        else:
            mem.subnode('No Htpsit', 0)

        mem.subnode('eps_n', wfs.bd.mynbands * mem.floatsize)
        mem.subnode('eps_N', wfs.bd.nbands * mem.floatsize)
        mem.subnode('Preconditioner', 4 * gridmem)
        mem.subnode('Work', gridmem)
246