1import numpy as np
2from ase.units import Bohr
3
4from gpaw.lfc import BasisFunctions
5from gpaw.utilities import unpack
6from gpaw.utilities.tools import tri2full
7# from gpaw import debug
8# from gpaw.lcao.overlap import NewTwoCenterIntegrals as NewTCI
9from gpaw.lcao.tci import TCIExpansions
10from gpaw.utilities.blas import gemm, gemmdot
11from gpaw.wavefunctions.base import WaveFunctions
12from gpaw.lcao.atomic_correction import (DenseAtomicCorrection,
13                                         SparseAtomicCorrection)
14from gpaw.wavefunctions.mode import Mode
15
16
class LCAO(Mode):
    """Mode object for calculations in the LCAO basis.

    Stores the LCAO-specific settings and acts as a factory: calling the
    object creates an LCAOWaveFunctions instance.
    """

    name = 'lcao'

    def __init__(self, atomic_correction=None, interpolation=3,
                 force_complex_dtype=False):
        """Remember the settings and initialize the base Mode.

        atomic_correction is handed on to LCAOWaveFunctions when the mode
        is called; interpolation is only stored and reported by todict();
        force_complex_dtype is passed to Mode.__init__().
        """
        self.atomic_correction = atomic_correction
        self.interpolation = interpolation
        Mode.__init__(self, force_complex_dtype)

    def __call__(self, *args, **kwargs):
        """Build an LCAOWaveFunctions object.

        All positional and keyword arguments are forwarded; the stored
        atomic_correction is added as a keyword argument.
        """
        return LCAOWaveFunctions(
            *args, atomic_correction=self.atomic_correction, **kwargs)

    def __repr__(self):
        """Return 'LCAO(...)' with the settings dictionary inside."""
        settings = self.todict()
        return 'LCAO({})'.format(settings)

    def todict(self):
        """Return the base-class settings extended by 'interpolation'."""
        settings = Mode.todict(self)
        settings['interpolation'] = self.interpolation
        return settings
38
39
def update_phases(C_unM, q_u, ibzk_qc, spos_ac, oldspos_ac, setups, Mstart):
    """Complex-rotate coefficients compensating discontinuous phase shift.

    This changes the coefficients to counteract the phase discontinuity
    of overlaps when atoms move across a cell boundary.

    Parameters:

    C_unM: sequence of arrays or None
        Coefficient arrays, modified in place.  Entries that are None
        are skipped.
    q_u: sequence of int
        For each entry of C_unM, the row of ibzk_qc it belongs to.
    ibzk_qc: array
        k-point vectors; dotted with the rounded position change.
    spos_ac, oldspos_ac: arrays
        New and previous scaled atomic positions.
    setups:
        Indexable by atom; provides ``M_a[a]`` (first global basis
        function index of atom a) and ``setups[a].nao``.
    Mstart: int
        Global basis function index corresponding to column 0 of each
        coefficient array.
    """

    # We don't want to apply any phase shift unless we crossed a cell
    # boundary.  So we round the shift to either 0 or 1.
    #
    # Example: spos_ac goes from 0.01 to 0.99 -- this rounds to 1 and
    # we apply the phase.  If someone moves an atom by half a cell
    # without crossing a boundary, then we are out of luck.  But they
    # should have reinitialized from LCAO anyway.
    phase_qa = np.exp(2j * np.pi *
                      np.dot(ibzk_qc, (spos_ac - oldspos_ac).T.round()))

    for q, C_nM in zip(q_u, C_unM):
        if C_nM is None:
            continue
        for a in range(len(spos_ac)):
            M1 = setups.M_a[a] - Mstart
            M2 = M1 + setups[a].nao
            if M2 <= 0:
                # All basis functions of this atom lie before Mstart.
                # Without this check the negative stop index would wrap
                # around and rotate columns belonging to other atoms.
                continue
            # A stop index beyond the last column is truncated by the
            # slice itself.
            C_nM[:, max(0, M1):M2] *= phase_qa[q, a]
64
65
66# replace by class to make data structure perhaps a bit less confusing
def get_r_and_offsets(nl, spos_ac, cell_cv):
    """Collect displacement vectors and cell offsets of neighbor pairs.

    For every atom a1 the neighbors (a2, offset) reported by
    ``nl.get_neighbors(a1)`` are visited.  The displacement is
    ``dot(spos_ac[a2] + offset - spos_ac[a1], cell_cv)``.  Each pair is
    stored under the key (a1, a2) and, unless it is an atom paired with
    itself at zero offset, also mirrored under (a2, a1) with negated
    displacement and offset.

    Returns a dictionary mapping (a1, a2) to a list of
    (displacement, offset) tuples.
    """
    pairs = {}

    def register(key, displacement_c, shift_c):
        pairs.setdefault(key, []).append((displacement_c, shift_c))

    for a1, spos1_c in enumerate(spos_ac):
        neighbors_a, shifts = nl.get_neighbors(a1)
        for a2, shift_c in zip(neighbors_a, shifts):
            R_c = np.dot(spos_ac[a2] + shift_c - spos1_c, cell_cv)
            register((a1, a2), R_c, shift_c)

            # An atom paired with itself in the same cell has no
            # distinct mirror entry.
            if a1 == a2 and not shift_c.any():
                continue
            register((a2, a1), -R_c, -shift_c)

    return pairs
86
87
88class LCAOWaveFunctions(WaveFunctions):
89    mode = 'lcao'
90
91    def __init__(self, ksl, gd, nvalence, setups, bd,
92                 dtype, world, kd, kptband_comm, timer,
93                 atomic_correction=None, collinear=True):
94        WaveFunctions.__init__(self, gd, nvalence, setups, bd,
95                               dtype, collinear, world, kd,
96                               kptband_comm, timer)
97        self.ksl = ksl
98        self.S_qMM = None
99        self.T_qMM = None
100        self.P_aqMi = None
101        self.debug_tci = False
102
103        if atomic_correction is None:
104            atomic_correction = 'sparse' if ksl.using_blacs else 'dense'
105
106        if atomic_correction == 'sparse':
107            self.atomic_correction_cls = SparseAtomicCorrection
108        else:
109            assert atomic_correction == 'dense'
110            self.atomic_correction_cls = DenseAtomicCorrection
111
112        # self.tci = NewTCI(gd.cell_cv, gd.pbc_c, setups, kd.ibzk_qc, kd.gamma)
113        with self.timer('TCI: Evaluate splines'):
114            self.tciexpansions = TCIExpansions.new_from_setups(setups)
115
116        self.basis_functions = BasisFunctions(gd,
117                                              [setup.phit_j
118                                               for setup in setups],
119                                              kd,
120                                              dtype=dtype,
121                                              cut=True)
122
123    def set_orthonormalized(self, o):
124        pass
125
126    def empty(self, n=(), global_array=False, realspace=False):
127        if realspace:
128            return self.gd.empty(n, self.dtype, global_array)
129        else:
130            if isinstance(n, int):
131                n = (n,)
132            nao = self.setups.nao
133            return np.empty(n + (nao,), self.dtype)
134
135    def __str__(self):
136        s = 'Wave functions: LCAO\n'
137        s += '  Diagonalizer: %s\n' % self.ksl.get_description()
138        s += ('  Atomic Correction: %s\n'
139              % self.atomic_correction_cls.description)
140        s += '  Datatype: %s\n' % self.dtype.__name__
141        return s
142
143    def set_eigensolver(self, eigensolver):
144        WaveFunctions.set_eigensolver(self, eigensolver)
145        if eigensolver:
146            eigensolver.initialize(self.gd, self.dtype, self.setups.nao,
147                                   self.ksl)
148
149    def set_positions(self, spos_ac, atom_partition=None, move_wfs=False):
150        oldspos_ac = self.spos_ac
151        with self.timer('Basic WFS set positions'):
152            WaveFunctions.set_positions(self, spos_ac, atom_partition)
153
154        with self.timer('Basis functions set positions'):
155            self.basis_functions.set_positions(spos_ac)
156
157        if self.ksl is not None:
158            self.basis_functions.set_matrix_distribution(self.ksl.Mstart,
159                                                         self.ksl.Mstop)
160
161        nq = len(self.kd.ibzk_qc)
162        nao = self.setups.nao
163        Mstop = self.ksl.Mstop
164        Mstart = self.ksl.Mstart
165        mynao = Mstop - Mstart
166
167        # if self.ksl.using_blacs:  # XXX
168        #     S and T have been distributed to a layout with blacs, so
169        #     discard them to force reallocation from scratch.
170        #
171        #     TODO: evaluate S and T when they *are* distributed, thus saving
172        #     memory and avoiding this problem
173        for kpt in self.kpt_u:
174            kpt.S_MM = None
175            kpt.T_MM = None
176
177        # Free memory in case of old matrices:
178        self.S_qMM = self.T_qMM = self.P_aqMi = None
179
180        if self.dtype == complex and oldspos_ac is not None:
181            update_phases([kpt.C_nM for kpt in self.kpt_u],
182                          [kpt.q for kpt in self.kpt_u],
183                          self.kd.ibzk_qc, spos_ac, oldspos_ac,
184                          self.setups, Mstart)
185
186        if 0:  # self.debug_tci:
187            # if self.ksl.using_blacs:
188            #     self.tci.set_matrix_distribution(Mstart, mynao)
189            oldS_qMM = np.empty((nq, mynao, nao), self.dtype)
190            oldT_qMM = np.empty((nq, mynao, nao), self.dtype)
191
192            oldP_aqMi = {}
193            for a in self.basis_functions.my_atom_indices:
194                ni = self.setups[a].ni
195                oldP_aqMi[a] = np.empty((nq, nao, ni), self.dtype)
196
197            # Calculate lower triangle of S and T matrices:
198            self.timer.start('tci calculate')
199            # self.tci.calculate(spos_ac, oldS_qMM, oldT_qMM,
200            #                   oldP_aqMi)
201            self.timer.stop('tci calculate')
202
203        self.timer.start('mktci')
204        manytci = self.tciexpansions.get_manytci_calculator(
205            self.setups, self.gd, spos_ac, self.kd.ibzk_qc, self.dtype,
206            self.timer)
207        self.timer.stop('mktci')
208        self.manytci = manytci
209        self.newtci = manytci.tci
210
211        my_atom_indices = self.basis_functions.my_atom_indices
212        self.timer.start('ST tci')
213        newS_qMM, newT_qMM = manytci.O_qMM_T_qMM(self.gd.comm,
214                                                 Mstart, Mstop,
215                                                 self.ksl.using_blacs)
216        self.timer.stop('ST tci')
217        self.timer.start('P tci')
218        P_qIM = manytci.P_qIM(my_atom_indices)
219        self.timer.stop('P tci')
220        self.P_aqMi = newP_aqMi = manytci.P_aqMi(my_atom_indices)
221        self.P_qIM = P_qIM  # XXX atomic correction
222
223        self.atomic_correction = self.atomic_correction_cls.new_from_wfs(self)
224
225        # TODO
226        #   OK complex/conj, periodic images
227        #   OK scalapack
228        #   derivatives/forces
229        #   sparse
230        #   use symmetry/conj tricks to reduce calculations
231        #   enable caching of spherical harmonics
232
233        # if self.atomic_correction.name != 'dense':
234        # from gpaw.lcao.newoverlap import newoverlap
235        # self.P_neighbors_a, self.P_aaqim = newoverlap(self, spos_ac)
236
237        # if self.atomic_correction.name == 'scipy':
238        #    Pold_qIM = self.atomic_correction.Psparse_qIM
239        #    for q in range(nq):
240        #        maxerr = abs(Pold_qIM[q] - P_qIM[q]).max()
241        #        print('sparse maxerr', maxerr)
242        #        assert maxerr == 0
243
244        self.atomic_correction.add_overlap_correction(newS_qMM)
245        if self.debug_tci:
246            self.atomic_correction.add_overlap_correction(oldS_qMM)
247
248        self.allocate_arrays_for_projections(my_atom_indices)
249
250        # S_MM = None  # allow garbage collection of old S_qMM after redist
251        if self.debug_tci:
252            oldS_qMM = self.ksl.distribute_overlap_matrix(oldS_qMM, root=-1)
253            oldT_qMM = self.ksl.distribute_overlap_matrix(oldT_qMM, root=-1)
254
255        newS_qMM = self.ksl.distribute_overlap_matrix(newS_qMM, root=-1)
256        newT_qMM = self.ksl.distribute_overlap_matrix(newT_qMM, root=-1)
257
258        # if (debug and self.bd.comm.size == 1 and self.gd.comm.rank == 0 and
259        #     nao > 0 and not self.ksl.using_blacs):
260        #     S and T are summed only on comm master, so check only there
261        #     from numpy.linalg import eigvalsh
262        #     self.timer.start('Check positive definiteness')
263        #     for S_MM in S_qMM:
264        #         tri2full(S_MM, UL='L')
265        #         smin = eigvalsh(S_MM).real.min()
266        #         if smin < 0:
267        #             raise RuntimeError('Overlap matrix has negative '
268        #                               'eigenvalue: %e' % smin)
269        #     self.timer.stop('Check positive definiteness')
270        self.positions_set = True
271
272        if self.debug_tci:
273            Serr = np.abs(newS_qMM - oldS_qMM).max()
274            Terr = np.abs(newT_qMM - oldT_qMM).max()
275            print('S maxerr', Serr)
276            print('T maxerr', Terr)
277            try:
278                assert Terr < 1e-15, Terr
279            except AssertionError:
280                np.set_printoptions(precision=6)
281                if self.world.rank == 0:
282                    print(newT_qMM)
283                    print(oldT_qMM)
284                    print(newT_qMM - oldT_qMM)
285                raise
286            assert Serr < 1e-15, Serr
287
288            assert len(oldP_aqMi) == len(newP_aqMi)
289            for a in oldP_aqMi:
290                Perr = np.abs(oldP_aqMi[a] - newP_aqMi[a]).max()
291                assert Perr < 1e-15, (a, Perr)
292
293        for kpt in self.kpt_u:
294            q = kpt.q
295            kpt.S_MM = newS_qMM[q]
296            kpt.T_MM = newT_qMM[q]
297        self.S_qMM = newS_qMM
298        self.T_qMM = newT_qMM
299
300        # Elpa wants to reuse the decomposed form of S_qMM.
301        # We need to keep track of the existence of that object here,
302        # since this is where we change S_qMM.  Hence, expect this to
303        # become arrays after the first diagonalization:
304        self.decomposed_S_qMM = [None] * len(self.S_qMM)
305
306    def initialize(self, density, hamiltonian, spos_ac):
307        # Note: The above line exists also in set_positions.
308        # This is guaranteed to be correct, but we can probably remove one.
309        # Of course no human can understand the initialization process,
310        # so this will be some other day.
311        self.timer.start('LCAO WFS Initialize')
312        if density.nt_sG is None:
313            if self.kpt_u[0].f_n is None or self.kpt_u[0].C_nM is None:
314                density.initialize_from_atomic_densities(self.basis_functions)
315            else:
316                # We have the info we need for a density matrix, so initialize
317                # from that instead of from scratch.  This will be the case
318                # after set_positions() during a relaxation
319                density.initialize_from_wavefunctions(self)
320            # Initialize GLLB-potential from basis function orbitals
321            if hamiltonian.xc.type == 'GLLB':
322                hamiltonian.xc.initialize_from_atomic_orbitals(
323                    self.basis_functions)
324
325        else:
326            # After a restart, nt_sg doesn't exist yet, so we'll have to
327            # make sure it does.  Of course, this should have been taken care
328            # of already by this time, so we should improve the code elsewhere
329            density.calculate_normalized_charges_and_mix()
330
331        hamiltonian.update(density)
332        self.timer.stop('LCAO WFS Initialize')
333
334        return 0, 0
335
336    def initialize_wave_functions_from_lcao(self):
337        """Fill the calc.wfs.kpt_[u].psit_nG arrays with useful data.
338
339        Normally psit_nG is NOT used in lcao mode, but some extensions
340        (like ase.dft.wannier) want to have it.
341        This code is adapted from fd.py / initialize_from_lcao_coefficients()
342        and fills psit_nG with data constructed from the current lcao
343        coefficients (kpt.C_nM).
344
345        (This may or may not work in band-parallel case!)
346        """
347        from gpaw.wavefunctions.arrays import UniformGridWaveFunctions
348        bfs = self.basis_functions
349        for kpt in self.kpt_u:
350            kpt.psit = UniformGridWaveFunctions(
351                self.bd.nbands, self.gd, self.dtype, kpt=kpt.q, dist=None,
352                spin=kpt.s, collinear=True)
353            kpt.psit_nG[:] = 0.0
354            bfs.lcao_to_grid(kpt.C_nM, kpt.psit_nG[:self.bd.mynbands], kpt.q)
355
356    def initialize_wave_functions_from_restart_file(self):
357        """Dummy function to ensure compatibility to fd mode"""
358        self.initialize_wave_functions_from_lcao()
359
360    def add_orbital_density(self, nt_G, kpt, n):
361        rank, q = self.kd.get_rank_and_index(kpt.k)
362        u = q * self.nspins + kpt.s
363        assert rank == self.kd.comm.rank
364        assert self.kpt_u[u] is kpt
365        psit_G = self._get_wave_function_array(u, n, realspace=True)
366        self.add_realspace_orbital_to_density(nt_G, psit_G)
367
368    def calculate_density_matrix(self, f_n, C_nM, rho_MM=None):
369        self.timer.start('Calculate density matrix')
370        rho_MM = self.ksl.calculate_density_matrix(f_n, C_nM, rho_MM)
371        self.timer.stop('Calculate density matrix')
372        return rho_MM
373
374        if 1:
375            # XXX Should not conjugate, but call gemm(..., 'c')
376            # Although that requires knowing C_Mn and not C_nM.
377            # that also conforms better to the usual conventions in literature
378            Cf_Mn = C_nM.T.conj() * f_n
379            self.timer.start('gemm')
380            gemm(1.0, C_nM, Cf_Mn, 0.0, rho_MM, 'n')
381            self.timer.stop('gemm')
382            self.timer.start('band comm sum')
383            self.bd.comm.sum(rho_MM)
384            self.timer.stop('band comm sum')
385        else:
386            # Alternative suggestion. Might be faster. Someone should test this
387            from gpaw.utilities.blas import r2k
388            C_Mn = C_nM.T.copy()
389            r2k(0.5, C_Mn, f_n * C_Mn, 0.0, rho_MM)
390            tri2full(rho_MM)
391
392    def calculate_atomic_density_matrices_with_occupation(self, D_asp, f_un):
393        # ac = self.atomic_correction
394        # if ac.implements_distributed_projections():
395        #     D2_asp = ac.redistribute(self, D_asp, type='asp', op='forth')
396        #     WaveFunctions.calculate_atomic_density_matrices_with_occupation(
397        #         self, D2_asp, f_un)
398        #     D3_asp = ac.redistribute(self, D2_asp, type='asp', op='back')
399        #     for a in D_asp:
400        #         D_asp[a][:] = D3_asp[a]
401        # else:
402        WaveFunctions.calculate_atomic_density_matrices_with_occupation(
403            self, D_asp, f_un)
404
405    def calculate_density_matrix_delta(self, d_nn, C_nM, rho_MM=None):
406        self.timer.start('Calculate density matrix')
407        rho_MM = self.ksl.calculate_density_matrix_delta(d_nn, C_nM, rho_MM)
408        self.timer.stop('Calculate density matrix')
409        return rho_MM
410
411    def add_to_density_from_k_point_with_occupation(self, nt_sG, kpt, f_n):
412        """Add contribution to pseudo electron-density. Do not use the standard
413        occupation numbers, but ones given with argument f_n."""
414        # Custom occupations are used in calculation of response potential
415        # with GLLB-potential
416        if kpt.rho_MM is None:
417            rho_MM = self.calculate_density_matrix(f_n, kpt.C_nM)
418            if hasattr(kpt, 'c_on'):
419                assert self.bd.comm.size == 1
420                d_nn = np.zeros((self.bd.mynbands, self.bd.mynbands),
421                                dtype=kpt.C_nM.dtype)
422                for ne, c_n in zip(kpt.ne_o, kpt.c_on):
423                    assert abs(c_n.imag).max() < 1e-14
424                    d_nn += ne * np.outer(c_n.conj(), c_n).real
425                rho_MM += self.calculate_density_matrix_delta(d_nn, kpt.C_nM)
426        else:
427            rho_MM = kpt.rho_MM
428        self.timer.start('Construct density')
429        self.basis_functions.construct_density(rho_MM, nt_sG[kpt.s], kpt.q)
430        self.timer.stop('Construct density')
431
432    def add_to_kinetic_density_from_k_point(self, taut_G, kpt):
433        raise NotImplementedError('Kinetic density calculation for LCAO '
434                                  'wavefunctions is not implemented.')
435
436    def calculate_forces(self, hamiltonian, F_av):
437        self.timer.start('LCAO forces')
438
439        ksl = self.ksl
440        nao = ksl.nao
441        mynao = ksl.mynao
442        dtype = self.dtype
443        # tci = self.tci
444        newtci = self.newtci
445        gd = self.gd
446        bfs = self.basis_functions
447
448        Mstart = ksl.Mstart
449        Mstop = ksl.Mstop
450
451        from gpaw.kohnsham_layouts import BlacsOrbitalLayouts
452        isblacs = isinstance(ksl, BlacsOrbitalLayouts)  # XXX
453
454        if not isblacs:
455            self.timer.start('TCI derivative')
456
457            dThetadR_qvMM, dTdR_qvMM = self.manytci.O_qMM_T_qMM(
458                gd.comm, Mstart, Mstop, False, derivative=True)
459
460            dPdR_aqvMi = self.manytci.P_aqMi(
461                self.basis_functions.my_atom_indices, derivative=True)
462
463            gd.comm.sum(dThetadR_qvMM)
464            gd.comm.sum(dTdR_qvMM)
465            self.timer.stop('TCI derivative')
466
467            my_atom_indices = bfs.my_atom_indices
468            atom_indices = bfs.atom_indices
469
470            def _slices(indices):
471                for a in indices:
472                    M1 = bfs.M_a[a] - Mstart
473                    M2 = M1 + self.setups[a].nao
474                    if M2 > 0:
475                        yield a, max(0, M1), M2
476
477            def slices():
478                return _slices(atom_indices)
479
480            def my_slices():
481                return _slices(my_atom_indices)
482
483        dH_asp = hamiltonian.dH_asp
484        vt_sG = hamiltonian.vt_sG
485
486        #
487        #         -----                    -----
488        #          \    -1                  \    *
489        # E      =  )  S     H    rho     =  )  c     eps  f  c
490        #  mu nu   /    mu x  x z    z nu   /    n mu    n  n  n nu
491        #         -----                    -----
492        #          x z                       n
493        #
494        # We use the transpose of that matrix.  The first form is used
495        # if rho is given, otherwise the coefficients are used.
496        self.timer.start('Initial')
497
498        rhoT_uMM = []
499        ET_uMM = []
500
501        if not isblacs:
502            if self.kpt_u[0].rho_MM is None:
503                self.timer.start('Get density matrix')
504                for kpt in self.kpt_u:
505                    rhoT_MM = ksl.get_transposed_density_matrix(kpt.f_n,
506                                                                kpt.C_nM)
507                    rhoT_uMM.append(rhoT_MM)
508                    ET_MM = ksl.get_transposed_density_matrix(kpt.f_n *
509                                                              kpt.eps_n,
510                                                              kpt.C_nM)
511                    ET_uMM.append(ET_MM)
512
513                    if hasattr(kpt, 'c_on'):
514                        # XXX does this work with BLACS/non-BLACS/etc.?
515                        assert self.bd.comm.size == 1
516                        d_nn = np.zeros((self.bd.mynbands, self.bd.mynbands),
517                                        dtype=kpt.C_nM.dtype)
518                        for ne, c_n in zip(kpt.ne_o, kpt.c_on):
519                            d_nn += ne * np.outer(c_n.conj(), c_n)
520                        rhoT_MM += ksl.get_transposed_density_matrix_delta(
521                            d_nn, kpt.C_nM)
522                        ET_MM += ksl.get_transposed_density_matrix_delta(
523                            d_nn * kpt.eps_n, kpt.C_nM)
524                self.timer.stop('Get density matrix')
525            else:
526                rhoT_uMM = []
527                ET_uMM = []
528                for kpt in self.kpt_u:
529                    H_MM = self.eigensolver.calculate_hamiltonian_matrix(
530                        hamiltonian, self, kpt)
531                    tri2full(H_MM)
532                    S_MM = kpt.S_MM.copy()
533                    tri2full(S_MM)
534                    ET_MM = np.linalg.solve(S_MM, gemmdot(H_MM,
535                                                          kpt.rho_MM)).T.copy()
536                    del S_MM, H_MM
537                    rhoT_MM = kpt.rho_MM.T.copy()
538                    rhoT_uMM.append(rhoT_MM)
539                    ET_uMM.append(ET_MM)
540        self.timer.stop('Initial')
541
542        if isblacs:  # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
543            from gpaw.blacs import BlacsGrid, Redistributor
544
545            def get_density_matrix(f_n, C_nM, redistributor):
546                rho1_mm = ksl.calculate_blocked_density_matrix(f_n,
547                                                               C_nM).conj()
548                rho_mm = redistributor.redistribute(rho1_mm)
549                return rho_mm
550
551            # pcutoff_a = [max([pt.get_cutoff() for pt in setup.pt_j])
552            #              for setup in self.setups]
553            # phicutoff_a = [max([phit.get_cutoff() for phit in setup.phit_j])
554            #                for setup in self.setups]
555
556            # XXX should probably use bdsize x gdsize instead
557            # That would be consistent with some existing grids
558            grid = BlacsGrid(ksl.block_comm, self.gd.comm.size,
559                             self.bd.comm.size)
560
561            blocksize1 = -(-nao // grid.nprow)
562            blocksize2 = -(-nao // grid.npcol)
563            # XXX what are rows and columns actually?
564            desc = grid.new_descriptor(nao, nao, blocksize1, blocksize2)
565
566            rhoT_umm = []
567            ET_umm = []
568            redistributor = Redistributor(grid.comm, ksl.mmdescriptor, desc)
569            Fpot_av = np.zeros_like(F_av)
570            for u, kpt in enumerate(self.kpt_u):
571                self.timer.start('Get density matrix')
572                rhoT_mm = get_density_matrix(kpt.f_n, kpt.C_nM, redistributor)
573                rhoT_umm.append(rhoT_mm)
574                self.timer.stop('Get density matrix')
575
576                self.timer.start('Potential')
577                rhoT_mM = ksl.distribute_to_columns(rhoT_mm, desc)
578
579                vt_G = vt_sG[kpt.s]
580                Fpot_av += bfs.calculate_force_contribution(vt_G, rhoT_mM,
581                                                            kpt.q)
582                del rhoT_mM
583                self.timer.stop('Potential')
584
585            self.timer.start('Get density matrix')
586            for kpt in self.kpt_u:
587                ET_mm = get_density_matrix(kpt.f_n * kpt.eps_n, kpt.C_nM,
588                                           redistributor)
589                ET_umm.append(ET_mm)
590            self.timer.stop('Get density matrix')
591
592            M1start = blocksize1 * grid.myrow
593            M2start = blocksize2 * grid.mycol
594
595            M1stop = min(M1start + blocksize1, nao)
596            M2stop = min(M2start + blocksize2, nao)
597
598            m1max = M1stop - M1start
599            m2max = M2stop - M2start
600
601        if not isblacs:
602            # Kinetic energy contribution
603            #
604            #           ----- d T
605            #  a         \       mu nu
606            # F += 2 Re   )   -------- rho
607            #            /    d R         nu mu
608            #           -----    mu nu
609            #        mu in a; nu
610            #
611            Fkin_av = np.zeros_like(F_av)
612            for u, kpt in enumerate(self.kpt_u):
613                dEdTrhoT_vMM = (dTdR_qvMM[kpt.q] *
614                                rhoT_uMM[u][np.newaxis]).real
615                # XXX load distribution!
616                for a, M1, M2 in my_slices():
617                    Fkin_av[a, :] += \
618                        2.0 * dEdTrhoT_vMM[:, M1:M2].sum(-1).sum(-1)
619            del dEdTrhoT_vMM
620
621            # Density matrix contribution due to basis overlap
622            #
623            #            ----- d Theta
624            #  a          \           mu nu
625            # F  += -2 Re  )   ------------  E
626            #             /        d R        nu mu
627            #            -----        mu nu
628            #         mu in a; nu
629            #
630            Ftheta_av = np.zeros_like(F_av)
631            for u, kpt in enumerate(self.kpt_u):
632                dThetadRE_vMM = (dThetadR_qvMM[kpt.q] *
633                                 ET_uMM[u][np.newaxis]).real
634                for a, M1, M2 in my_slices():
635                    Ftheta_av[a, :] += \
636                        -2.0 * dThetadRE_vMM[:, M1:M2].sum(-1).sum(-1)
637            del dThetadRE_vMM
638
639        if isblacs:
640            # from gpaw.lcao.overlap import TwoCenterIntegralCalculator
641            self.timer.start('Prepare TCI loop')
642            M_a = bfs.M_a
643            Fkin2_av = np.zeros_like(F_av)
644            Ftheta2_av = np.zeros_like(F_av)
645            atompairs = self.newtci.a1a2.get_atompairs()
646
647            self.timer.start('broadcast dH')
648            alldH_asp = {}
649            for a in range(len(self.setups)):
650                gdrank = bfs.sphere_a[a].rank
651                if gdrank == gd.rank:
652                    dH_sp = dH_asp[a]
653                else:
654                    ni = self.setups[a].ni
655                    dH_sp = np.empty((self.nspins, ni * (ni + 1) // 2))
656                gd.comm.broadcast(dH_sp, gdrank)
657                # okay, now everyone gets copies of dH_sp
658                alldH_asp[a] = dH_sp
659            self.timer.stop('broadcast dH')
660
661            # This will get sort of hairy.  We need to account for some
662            # three-center overlaps, such as:
663            #
664            #         a1
665            #      Phi   ~a3    a3  ~a3     a2     a2,a1
666            #   < ----  |p  > dH   <p   |Phi  > rho
667            #      dR
668            #
669            # To this end we will loop over all pairs of atoms (a1, a3),
670            # and then a sub-loop over (a3, a2).
671
672            self.timer.stop('Prepare TCI loop')
673            self.timer.start('Not so complicated loop')
674
675            for (a1, a2) in atompairs:
676                if a1 >= a2:
677                    # Actually this leads to bad load balance.
678                    # We should take a1 > a2 or a1 < a2 equally many times.
679                    # Maybe decide which of these choices
680                    # depending on whether a2 % 1 == 0
681                    continue
682
683                m1start = M_a[a1] - M1start
684                m2start = M_a[a2] - M2start
685                if m1start >= blocksize1 or m2start >= blocksize2:
686                    continue  # (we have only one block per CPU)
687
688                nm1 = self.setups[a1].nao
689                nm2 = self.setups[a2].nao
690
691                m1stop = min(m1start + nm1, m1max)
692                m2stop = min(m2start + nm2, m2max)
693
694                if m1stop <= 0 or m2stop <= 0:
695                    continue
696
697                m1start = max(m1start, 0)
698                m2start = max(m2start, 0)
699                J1start = max(0, M1start - M_a[a1])
700                J2start = max(0, M2start - M_a[a2])
701                M1stop = J1start + m1stop - m1start
702                J2stop = J2start + m2stop - m2start
703
704                dThetadR_qvmm, dTdR_qvmm = newtci.dOdR_dTdR(a1, a2)
705
706                for u, kpt in enumerate(self.kpt_u):
707                    rhoT_mm = rhoT_umm[u][m1start:m1stop, m2start:m2stop]
708                    ET_mm = ET_umm[u][m1start:m1stop, m2start:m2stop]
709                    Fkin_v = 2.0 * (dTdR_qvmm[kpt.q][:, J1start:M1stop,
710                                                     J2start:J2stop] *
711                                    rhoT_mm[np.newaxis]).real.sum(-1).sum(-1)
712                    Ftheta_v = 2.0 * (dThetadR_qvmm[kpt.q][:, J1start:M1stop,
713                                                           J2start:J2stop] *
714                                      ET_mm[np.newaxis]).real.sum(-1).sum(-1)
715                    Fkin2_av[a1] += Fkin_v
716                    Fkin2_av[a2] -= Fkin_v
717                    Ftheta2_av[a1] -= Ftheta_v
718                    Ftheta2_av[a2] += Ftheta_v
719
720            Fkin_av = Fkin2_av
721            Ftheta_av = Ftheta2_av
722            self.timer.stop('Not so complicated loop')
723
724            dHP_and_dSP_aauim = {}
725
726            a2values = {}
727            for (a2, a3) in atompairs:
728                if a3 not in a2values:
729                    a2values[a3] = []
730                a2values[a3].append(a2)
731
732            Fatom_av = np.zeros_like(F_av)
733            Frho_av = np.zeros_like(F_av)
734            self.timer.start('Complicated loop')
735            for a1, a3 in atompairs:
736                if a1 == a3:
737                    # Functions reside on same atom, so their overlap
738                    # does not change when atom is displaced
739                    continue
740                m1start = M_a[a1] - M1start
741                if m1start >= blocksize1:
742                    continue
743
744                nm1 = self.setups[a1].nao
745                m1stop = min(m1start + nm1, m1max)
746                if m1stop <= 0:
747                    continue
748
749                dPdR_qvim = newtci.dPdR(a3, a1)
750                if dPdR_qvim is None:
751                    continue
752
753                dPdR_qvmi = -dPdR_qvim.transpose(0, 1, 3, 2).conj()
754
755                m1start = max(m1start, 0)
756                J1start = max(0, M1start - M_a[a1])
757                J1stop = J1start + m1stop - m1start
758                dPdR_qvmi = dPdR_qvmi[:, :, J1start:J1stop, :].copy()
759                for a2 in a2values[a3]:
760                    m2start = M_a[a2] - M2start
761                    if m2start >= blocksize2:
762                        continue
763
764                    nm2 = self.setups[a2].nao
765                    m2stop = min(m2start + nm2, m2max)
766                    if m2stop <= 0:
767                        continue
768
769                    m2start = max(m2start, 0)
770                    J2start = max(0, M2start - M_a[a2])
771                    J2stop = J2start + m2stop - m2start
772
773                    if (a2, a3) in dHP_and_dSP_aauim:
774                        dHP_uim, dSP_uim = dHP_and_dSP_aauim[(a2, a3)]
775                    else:
776                        P_qim = newtci.P(a3, a2)
777                        if P_qim is None:
778                            continue
779                        P_qmi = P_qim.transpose(0, 2, 1).conj()
780                        P_qmi = P_qmi[:, J2start:J2stop].copy()
781                        dH_sp = alldH_asp[a3]
782                        dS_ii = self.setups[a3].dO_ii
783
784                        dHP_uim = []
785                        dSP_uim = []
786                        for u, kpt in enumerate(self.kpt_u):
787                            dH_ii = unpack(dH_sp[kpt.s])
788                            dHP_im = np.dot(P_qmi[kpt.q], dH_ii).T.conj()
789                            # XXX only need nq of these,
790                            # but the looping is over all u
791                            dSP_im = np.dot(P_qmi[kpt.q], dS_ii).T.conj()
792                            dHP_uim.append(dHP_im)
793                            dSP_uim.append(dSP_im)
794                            dHP_and_dSP_aauim[(a2, a3)] = dHP_uim, dSP_uim
795
796                    for u, kpt in enumerate(self.kpt_u):
797                        rhoT_mm = rhoT_umm[u][m1start:m1stop, m2start:m2stop]
798                        ET_mm = ET_umm[u][m1start:m1stop, m2start:m2stop]
799                        dPdRdHP_vmm = np.dot(dPdR_qvmi[kpt.q], dHP_uim[u])
800                        dPdRdSP_vmm = np.dot(dPdR_qvmi[kpt.q], dSP_uim[u])
801
802                        Fatom_c = 2.0 * (dPdRdHP_vmm *
803                                         rhoT_mm).real.sum(-1).sum(-1)
804                        Frho_c = 2.0 * (dPdRdSP_vmm *
805                                        ET_mm).real.sum(-1).sum(-1)
806                        Fatom_av[a1] += Fatom_c
807                        Fatom_av[a3] -= Fatom_c
808
809                        Frho_av[a1] -= Frho_c
810                        Frho_av[a3] += Frho_c
811
812            self.timer.stop('Complicated loop')
813
814        if not isblacs:
815            # Potential contribution
816            #
817            #           -----      /  d Phi  (r)
818            #  a         \        |        mu    ~
819            # F += -2 Re  )       |   ---------- v (r)  Phi  (r) dr rho
820            #            /        |     d R                nu          nu mu
821            #           -----    /         a
822            #        mu in a; nu
823            #
824            self.timer.start('Potential')
825            Fpot_av = np.zeros_like(F_av)
826
827            for u, kpt in enumerate(self.kpt_u):
828                vt_G = vt_sG[kpt.s]
829                Fpot_av += bfs.calculate_force_contribution(vt_G, rhoT_uMM[u],
830                                                            kpt.q)
831            self.timer.stop('Potential')
832
833            # Density matrix contribution from PAW correction
834            #
835            #           -----                        -----
836            #  a         \      a                     \     b
837            # F +=  2 Re  )    Z      E        - 2 Re  )   Z      E
838            #            /      mu nu  nu mu          /     mu nu  nu mu
839            #           -----                        -----
840            #           mu nu                    b; mu in a; nu
841            #
842            # with
843            #                  b*
844            #         -----  dP
845            #   b      \       i mu    b   b
846            #  Z     =  )   -------- dS   P
847            #   mu nu  /     dR        ij  j nu
848            #         -----    b mu
849            #           ij
850            #
851            self.timer.start('Paw correction')
852            Frho_av = np.zeros_like(F_av)
853            for u, kpt in enumerate(self.kpt_u):
854                work_MM = np.zeros((mynao, nao), dtype)
855                ZE_MM = None
856                for b in my_atom_indices:
857                    setup = self.setups[b]
858                    dO_ii = np.asarray(setup.dO_ii, dtype)
859                    dOP_iM = np.zeros((setup.ni, nao), dtype)
860                    gemm(1.0, self.P_aqMi[b][kpt.q], dO_ii, 0.0, dOP_iM, 'c')
861                    for v in range(3):
862                        gemm(1.0, dOP_iM,
863                             dPdR_aqvMi[b][kpt.q][v][Mstart:Mstop],
864                             0.0, work_MM, 'n')
865                        ZE_MM = (work_MM * ET_uMM[u]).real
866                        for a, M1, M2 in slices():
867                            dE = 2 * ZE_MM[M1:M2].sum()
868                            Frho_av[a, v] -= dE  # the "b; mu in a; nu" term
869                            Frho_av[b, v] += dE  # the "mu nu" term
870            del work_MM, ZE_MM
871            self.timer.stop('Paw correction')
872
873            # Atomic density contribution
874            #            -----                         -----
875            #  a          \     a                       \     b
876            # F  += -2 Re  )   A      rho       + 2 Re   )   A      rho
877            #             /     mu nu    nu mu          /     mu nu    nu mu
878            #            -----                         -----
879            #            mu nu                     b; mu in a; nu
880            #
881            #                  b*
882            #         ----- d P
883            #  b       \       i mu   b   b
884            # A     =   )   ------- dH   P
885            #  mu nu   /    d R       ij  j nu
886            #         -----    b mu
887            #           ij
888            #
889            self.timer.start('Atomic Hamiltonian force')
890            Fatom_av = np.zeros_like(F_av)
891            for u, kpt in enumerate(self.kpt_u):
892                for b in my_atom_indices:
893                    H_ii = np.asarray(unpack(dH_asp[b][kpt.s]), dtype)
894                    HP_iM = gemmdot(H_ii,
895                                    np.ascontiguousarray(
896                                        self.P_aqMi[b][kpt.q].T.conj()))
897                    for v in range(3):
898                        dPdR_Mi = dPdR_aqvMi[b][kpt.q][v][Mstart:Mstop]
899                        ArhoT_MM = (gemmdot(dPdR_Mi, HP_iM) * rhoT_uMM[u]).real
900                        for a, M1, M2 in slices():
901                            dE = 2 * ArhoT_MM[M1:M2].sum()
902                            Fatom_av[a, v] += dE  # the "b; mu in a; nu" term
903                            Fatom_av[b, v] -= dE  # the "mu nu" term
904            self.timer.stop('Atomic Hamiltonian force')
905
906        F_av += Fkin_av + Fpot_av + Ftheta_av + Frho_av + Fatom_av
907        self.timer.start('Wait for sum')
908        ksl.orbital_comm.sum(F_av)
909        if self.bd.comm.rank == 0:
910            self.kd.comm.sum(F_av, 0)
911        self.timer.stop('Wait for sum')
912        self.timer.stop('LCAO forces')
913
914    def _get_wave_function_array(self, u, n, realspace=True, periodic=False):
915        # XXX Taking kpt is better than taking u
916        kpt = self.kpt_u[u]
917        C_M = kpt.C_nM[n]
918
919        if realspace:
920            psit_G = self.gd.zeros(dtype=self.dtype)
921            self.basis_functions.lcao_to_grid(C_M, psit_G, kpt.q)
922            if periodic and self.dtype == complex:
923                k_c = self.kd.ibzk_kc[kpt.k]
924                return self.gd.plane_wave(-k_c) * psit_G
925            return psit_G
926        else:
927            return C_M
928
929    def write(self, writer, write_wave_functions=False):
930        WaveFunctions.write(self, writer)
931        if write_wave_functions:
932            self.write_wave_functions(writer)
933
934    def write_wave_functions(self, writer):
935        writer.add_array(
936            'coefficients',
937            (self.nspins, self.kd.nibzkpts, self.bd.nbands, self.setups.nao),
938            dtype=self.dtype)
939        for s in range(self.nspins):
940            for k in range(self.kd.nibzkpts):
941                C_nM = self.collect_array('C_nM', k, s)
942                writer.fill(C_nM * Bohr**-1.5)
943
944    def read(self, reader):
945        WaveFunctions.read(self, reader)
946        r = reader.wave_functions
947        if 'coefficients' in r:
948            self.read_wave_functions(r)
949
950    def read_wave_functions(self, reader):
951        for kpt in self.kpt_u:
952            C_nM = reader.proxy('coefficients', kpt.s, kpt.k)
953            kpt.C_nM = self.bd.empty(self.setups.nao, dtype=self.dtype)
954            for myn, C_M in enumerate(kpt.C_nM):
955                n = self.bd.global_index(myn)
956                # XXX number of bands could have been rounded up!
957                if n >= len(C_nM):
958                    break
959                C_M[:] = C_nM[n] * Bohr**1.5
960
961    def estimate_memory(self, mem):
962        nq = len(self.kd.ibzk_qc)
963        nao = self.setups.nao
964        ni_total = sum([setup.ni for setup in self.setups])
965        itemsize = mem.itemsize[self.dtype]
966        mem.subnode('C [qnM]', nq * self.bd.mynbands * nao * itemsize)
967        nM1, nM2 = self.ksl.get_overlap_matrix_shape()
968        mem.subnode('S, T [2 x qmm]', 2 * nq * nM1 * nM2 * itemsize)
969        mem.subnode('P [aqMi]', nq * nao * ni_total // self.gd.comm.size)
970        # self.tci.estimate_memory(mem.subnode('TCI'))
971        self.basis_functions.estimate_memory(mem.subnode('BasisFunctions'))
972        self.eigensolver.estimate_memory(mem.subnode('Eigensolver'),
973                                         self.dtype)
974