"""Module defining the RMM-DIIS eigensolver (``RMMDIIS``)."""
2from functools import partial
3
4import numpy as np
5
6from gpaw.utilities.blas import axpy
7from gpaw.eigensolvers.eigensolver import Eigensolver
8
9
class RMMDIIS(Eigensolver):
    """RMM-DIIS eigensolver

    It is expected that the trial wave functions ``kpt.psit`` are
    orthonormal and that their projections ``kpt.projections`` are
    already calculated.

    Solution steps are:

    * Subspace diagonalization
    * Calculation of residuals
    * Improvement of wave functions:  psi' = psi + lambda PR + lambda PR'
    * Orthonormalization (not performed by this class; presumably done
      by the caller)"""

    def __init__(self, keep_htpsit=True, blocksize=None, niter=3, rtol=1e-16,
                 limit_lambda=False, use_rayleigh=False, trial_step=0.1):
        """Initialize RMM-DIIS eigensolver.

        Parameters:

        keep_htpsit: bool
            passed on to ``Eigensolver``.  When true, the residuals of all
            bands are calculated at once from ``self.Htpsit_nG``
            (presumably H|psit> kept by the subspace diagonalization);
            otherwise H is applied block by block.
        blocksize: int or None
            number of bands updated together.  ``None`` lets
            ``initialize()`` choose (10, or in 'pw' mode the smallest
            multiple of the plane-wave communicator size that is >= 10).
        niter: int
            number of trial vectors kept per band.  Values larger than 1
            enable the DIIS step.
        rtol: float
            shown by ``__repr__`` as the DIIS threshold.  It is not used
            by the iteration in this class.
        limit_lambda: dictionary
            determines if step length should be limited
            supported keys:
            'absolute':True/False limit the absolute value
            'upper':float upper limit for lambda
            'lower':float lower limit for lambda
        use_rayleigh: bool
            not supported; passing True raises ``NotImplementedError``.
        trial_step: float or None
            step length of the final preconditioned step.  ``None``
            reuses the (possibly limited) optimal step length of each band.
        """

        Eigensolver.__init__(self, keep_htpsit, blocksize)
        self.niter = niter
        self.rtol = rtol
        self.limit_lambda = limit_lambda
        self.use_rayleigh = use_rayleigh
        if use_rayleigh:
            # Previously signalled with a deliberate ``1 / 0``.
            raise NotImplementedError(
                'RMMDIIS: use_rayleigh=True is not supported')
        self.trial_step = trial_step
        # Not read anywhere in this class; presumably used elsewhere
        # (NOTE(review): verify before removing).
        self.first = True

    def todict(self):
        """Return the parameters to store (only the name and ``niter``)."""
        return {'name': 'rmm-diis', 'niter': self.niter}

    def initialize(self, wfs):
        """Choose a default block size if none was given, then initialize."""
        if self.blocksize is None:
            if wfs.mode == 'pw':
                S = wfs.pd.comm.size
                # Use a multiple of S for maximum efficiency
                self.blocksize = int(np.ceil(10 / S)) * S
            else:
                self.blocksize = 10
        Eigensolver.initialize(self, wfs)

    def iterate_one_k_point(self, ham, wfs, kpt, weights):
        """Do a single RMM-DIIS iteration for the kpoint.

        The bands local to this process are processed in blocks of
        ``self.blocksize``.  The wave functions in ``kpt.psit`` are
        updated in place.

        Parameters:

        ham: Hamiltonian passed to ``wfs.apply_pseudo_hamiltonian``
        wfs: wave-function object
        kpt: k-point whose wave functions are improved
        weights: indexed by band number.  The squared residual norm of
            band n is multiplied by ``weights[n]``.

        Returns the weighted sum of the squared residual norms of the local
        bands.  The norms are evaluated before this iteration's update and
        summed over ``wfs.gd.comm``.
        """
        self.subspace_diagonalize(ham, wfs, kpt)

        psit = kpt.psit
        P = kpt.projections
        P2 = P.new()
        # The residuals live in the H|psit> work buffer.
        R = psit.new(buf=self.Htpsit_nG)
        niter = self.niter

        with self.timer('RMM-DIIS'):
            if self.keep_htpsit:
                with self.timer('Calculate residuals'):
                    self.calculate_residuals(kpt, wfs, ham, psit, P,
                                             kpt.eps_n, R, P2)

            def integrate(a_G, b_G):
                # Real part of the local (not yet communicator-summed)
                # integral; callers sum over ``comm`` themselves.
                return np.real(wfs.integrate(a_G, b_G,
                                             global_integral=False))

            comm = wfs.gd.comm

            B = self.blocksize
            dR = R.new(dist=None, nbands=B)
            dpsit = dR.new()
            P = P.new(bcomm=None, nbands=B)
            P2 = P.new()
            errors_x = np.zeros(B)

            # History arrays for the DIIS step.  The ``niter`` entries
            # belonging to one band are stored contiguously:
            # rows ib * niter ... ib * niter + niter - 1.
            if niter > 1:
                psit_diis_nxG = wfs.empty(B * niter, q=kpt.q)
                R_diis_nxG = wfs.empty(B * niter, q=kpt.q)

            Ht = partial(wfs.apply_pseudo_hamiltonian, kpt, ham)

            error = 0.0
            for n1 in range(0, wfs.bd.mynbands, B):
                n2 = n1 + B
                if n2 > wfs.bd.mynbands:
                    # Last block is smaller: shrink the work arrays.
                    n2 = wfs.bd.mynbands
                    B = n2 - n1
                    P = P.new(nbands=B)
                    P2 = P.new()
                    dR = dR.new(nbands=B, dist=None)
                    dpsit = dR.new()

                n_x = np.arange(n1, n2)
                psitb = psit.view(n1, n2)

                with self.timer('Calculate residuals'):
                    Rb = R.view(n1, n2)
                    if not self.keep_htpsit:
                        psitb.apply(Ht, out=Rb)
                        psitb.matrix_elements(wfs.pt, out=P)
                        self.calculate_residuals(kpt, wfs, ham, psitb,
                                                 P, kpt.eps_n[n_x], Rb, P2,
                                                 n_x)

                errors_x[:] = 0.0
                for n in range(n1, n2):
                    R_G = Rb.array[n - n1]
                    errors_x[n - n1] = weights[n] * integrate(R_G, R_G)
                comm.sum(errors_x)
                error += np.sum(errors_x)

                # Slot 0 of each band's history: the starting vectors.
                if niter > 1:
                    psit_diis_nxG[:B * niter:niter] = psitb.array
                    R_diis_nxG[:B * niter:niter] = Rb.array

                # Precondition the residual: dpsit = P R
                with self.timer('precondition'):
                    ekin_x = self.preconditioner.calculate_kinetic_energy(
                        psitb.array, kpt)
                    self.preconditioner(Rb.array, kpt, ekin_x,
                                        out=dpsit.array)

                # Calculate the residual of dpsit_G, dR_G = (H - e S) dpsit_G
                dpsit.apply(Ht, out=dR)
                with self.timer('projections'):
                    dpsit.matrix_elements(wfs.pt, out=P)

                with self.timer('Calculate residuals'):
                    self.calculate_residuals(kpt, wfs, ham, dpsit,
                                             P, kpt.eps_n[n_x], dR, P2, n_x,
                                             calculate_change=True)

                # Find lam that minimizes the norm of R'_G = R_G + lam dR_G,
                # i.e. lam = -<dR|R> / <dR|dR> (real parts).
                with self.timer('Find lambda'):
                    RdR_x = np.array([
                        integrate(dR_G, R_G)
                        for R_G, dR_G in zip(Rb.array, dR.array)])
                    dRdR_x = np.array([integrate(dR_G, dR_G)
                                       for dR_G in dR.array])
                    comm.sum(RdR_x)
                    comm.sum(dRdR_x)
                    lam_x = -RdR_x / dRdR_x

                lam_x = self._clip_step_lengths(lam_x)

                # New trial wavefunction and residual
                with self.timer('Update psi'):
                    for lam, psit_G, dpsit_G, R_G, dR_G in zip(
                            lam_x, psitb.array, dpsit.array, Rb.array,
                            dR.array):
                        axpy(lam, dpsit_G, psit_G)  # psit_G += lam * dpsit_G
                        axpy(lam, dR_G, R_G)  # R_G += lam * dR_G

                with self.timer('DIIS step'):
                    for nit in range(1, niter):
                        # Store the current vectors as slot ``nit`` of each
                        # band's history.
                        psit_diis_nxG[nit:B * niter:niter] = psitb.array
                        R_diis_nxG[nit:B * niter:niter] = Rb.array

                        # For nit == 1 with an unlimited lambda the combination
                        # is skipped.  Presumably this is because the
                        # two-vector DIIS minimum then coincides with the line
                        # search just applied.
                        if nit > 1 or self.limit_lambda:
                            for ib in range(B):
                                istart = ib * niter
                                iend = istart + nit + 1
                                self._diis_combine(
                                    wfs, psit_diis_nxG[istart:iend],
                                    R_diis_nxG[istart:iend],
                                    psitb.array[ib], Rb.array[ib])

                        if nit < niter - 1:
                            # Another preconditioned step from the combined
                            # vectors, reusing the step lengths lam_x.
                            with self.timer('precondition'):
                                self.preconditioner(Rb.array, kpt, ekin_x,
                                                    out=dpsit.array)

                            for psit_G, lam, dpsit_G in zip(
                                    psitb.array, lam_x, dpsit.array):
                                axpy(lam, dpsit_G, psit_G)

                            # Calculate the new residuals
                            with self.timer('Calculate residuals'):
                                psitb.apply(Ht, out=Rb)
                                psitb.matrix_elements(wfs.pt, out=P)
                                self.calculate_residuals(
                                    kpt, wfs, ham, psitb, P,
                                    kpt.eps_n[n_x], Rb, P2, n_x,
                                    calculate_change=True)

                # Final trial step
                with self.timer('precondition'):
                    self.preconditioner(Rb.array, kpt, ekin_x,
                                        out=dpsit.array)

                with self.timer('Update psi'):
                    if self.trial_step is not None:
                        lam_x[:] = self.trial_step
                    for lam, psit_G, dpsit_G in zip(lam_x, psitb.array,
                                                    dpsit.array):
                        axpy(lam, dpsit_G, psit_G)  # psit_G += lam * dpsit_G

        return error

    def _clip_step_lengths(self, lam_x):
        """Return the step lengths lam_x limited by ``self.limit_lambda``.

        Without ``limit_lambda`` the input array is returned unchanged.
        With ``'absolute': True``, the magnitudes are clipped to
        [lower, upper] and the sign is kept (a zero step stays zero).
        Otherwise the values themselves are clipped to [lower, upper].
        """
        if not self.limit_lambda:
            return lam_x
        upper = self.limit_lambda['upper']
        lower = self.limit_lambda['lower']
        if self.limit_lambda.get('absolute', False):
            lam_x = np.where(np.abs(lam_x) < lower,
                             lower * np.sign(lam_x), lam_x)
            lam_x = np.where(np.abs(lam_x) > upper,
                             upper * np.sign(lam_x), lam_x)
        else:
            lam_x = np.where(lam_x < lower, lower, lam_x)
            lam_x = np.where(lam_x > upper, upper, lam_x)
        return lam_x

    def _diis_combine(self, wfs, psit_xG, R_xG, psit_G, R_G):
        """Overwrite psit_G and R_G with the DIIS combination of a history.

        psit_xG and R_xG hold the trial vectors and residuals of one band,
        with the most recent entry last.  The coefficients alpha_x solve the
        bordered system

            [ <R_i|R_j>  -1 ] [ alpha ]   [  0 ]
            [    -1       0 ] [  mu   ] = [ -1 ]

        This is the standard DIIS system for minimizing the norm of
        sum_x alpha_x R_x subject to sum_x alpha_x = 1.  psit_G and R_G are
        then set to sum_x alpha_x psit_x and sum_x alpha_x R_x, in place.
        """
        nx = len(R_xG)
        with self.timer('Construct matrix'):
            R_xx = wfs.integrate(R_xG, R_xG, global_integral=True)
            A_yy = -np.ones((nx + 1, nx + 1), wfs.dtype)
            A_yy[:nx, :nx] = R_xx
            A_yy[-1, -1] = 0.0
            b_y = np.zeros(nx + 1, wfs.dtype)
            b_y[-1] = -1.0
        with self.timer('Linear solve'):
            alpha_x = np.linalg.solve(A_yy, b_y)[:-1]
        with self.timer('Update trial vectors'):
            # Start from the newest entry, then add the older ones.
            psit_G[:] = alpha_x[-1] * psit_xG[-1]
            R_G[:] = alpha_x[-1] * R_xG[-1]
            for alpha, p_G, r_G in zip(alpha_x[:-1], psit_xG[:-1],
                                       R_xG[:-1]):
                axpy(alpha, p_G, psit_G)
                axpy(alpha, r_G, R_G)

    def __repr__(self):
        repr_string = 'RMM-DIIS eigensolver\n'
        repr_string += '       keep_htpsit: %s\n' % self.keep_htpsit
        repr_string += '       DIIS iterations: %d\n' % self.niter
        repr_string += '       Threshold for DIIS: %5.1e\n' % self.rtol
        repr_string += '       Limit lambda: %s\n' % self.limit_lambda
        repr_string += '       use_rayleigh: %s\n' % self.use_rayleigh
        repr_string += '       trial_step: %s' % self.trial_step
        return repr_string
283