1from typing import Tuple, Dict
2
3import numpy as np
4
5from gpaw.xc import XC
6from .coulomb import coulomb_interaction
7from .forces import calculate_forces
8from .paw import calculate_paw_stuff
9from .scf import apply1, apply2
10from .symmetry import Symmetry
11
12
class HybridXC:
    """Hybrid exchange-correlation functional for plane-wave calculations.

    Combines a semi-local functional (``self.xc``) with a fraction
    ``exx_fraction`` of exact exchange (EXX) whose Coulomb interaction is
    built from ``omega`` via ``coulomb_interaction``.  The human-readable
    description has the form
    ``'<xc> + <fraction> * EXX(omega = <omega> bohr^-1)'``.
    """

    # Class-level flags; presumably inspected by the surrounding GPAW code
    # to select the orbital-dependent code path (confirm).
    orbital_dependent = True
    type = 'HYB'

    def __init__(self,
                 xcname: str,
                 fraction: float = None,
                 omega: float = None):
        """Create the functional.

        Parameters
        ----------
        xcname:
            Either one of the predefined hybrids ``'EXX'``, ``'PBE0'``,
            ``'HSE03'``, ``'HSE06'``, ``'B3LYP'`` -- in which case the
            semi-local part, *fraction* and *omega* are obtained from
            ``parse_name`` -- or the name of a semi-local functional to be
            combined with exact exchange.
        fraction:
            Fraction of exact exchange.  Must be omitted for predefined
            names and given otherwise.
        omega:
            Parameter passed to ``coulomb_interaction`` (reported in
            bohr^-1 in the description).  Must be omitted for predefined
            names and given otherwise.

        Raises
        ------
        ValueError
            If *fraction* or *omega* is given together with a predefined
            name, or if either is missing for a custom name.
        """
        if xcname in ('EXX', 'PBE0', 'HSE03', 'HSE06', 'B3LYP'):
            if fraction is not None or omega is not None:
                raise ValueError(
                    f'{xcname!r} fixes fraction and omega; '
                    'do not pass them explicitly')
            # Imported locally (presumably to avoid a circular import with
            # the package __init__) and only on the branch that needs it.
            from . import parse_name
            self.name = xcname
            xcname, fraction, omega = parse_name(xcname)
        else:
            if fraction is None or omega is None:
                raise ValueError(
                    f'both fraction and omega must be given for {xcname!r}')
            self.name = f'{xcname}-{fraction:.3f}-{omega:.3f}'

        self.xc = XC(xcname)
        self.exx_fraction = fraction
        self.omega = omega

        # For xcname 'null' the description lists only the EXX term.
        if xcname == 'null':
            self.description = ''
        else:
            self.description = f'{xcname} + '
        self.description += f'{fraction} * EXX(omega = {omega} bohr^-1)'

        # LDA exchange potential used while occupations are not yet known
        # (kpt.f_n is None); cleared as soon as they are.
        self.vlda_sR = None
        # Per-k arrays returned by apply1, keyed by (spin, k-point index);
        # filled for all k-points of one spin, then consumed one at a time.
        self.v_sknG: Dict[Tuple[int, int], np.ndarray] = {}

        # Energy contributions, NaN until computed.  ecc is set in
        # initialize() from setup.ExxC; evc, evv and ekin are accumulated
        # in apply_orbital_dependent_hamiltonian().
        self.ecc = np.nan
        self.evc = np.nan
        self.evv = np.nan
        self.ekin = np.nan

        # Created lazily by _ensure_coulomb() once self.wfs is available.
        self.sym = None
        self.coulomb = None

    def get_setup_name(self):
        """Return the name of the PAW setups to use (always ``'PBE'``)."""
        return 'PBE'

    def initialize(self, dens, ham, wfs):
        """Store the density and wave-function objects and set ``ecc``.

        ``ecc`` is the sum of ``setup.ExxC`` over all setups, scaled by the
        EXX fraction.  *ham* is accepted for interface compatibility but is
        not used.
        """
        self.dens = dens
        self.wfs = wfs
        self.ecc = sum(setup.ExxC for setup in wfs.setups) * self.exx_fraction
        # Every rank of the world communicator must belong to the grid
        # communicator, i.e. apparently only domain decomposition is
        # supported here (confirm).
        assert wfs.world.size == wfs.gd.comm.size

    def get_description(self):
        """Return the human-readable description of the functional."""
        return self.description

    def set_positions(self, spos_ac):
        """Remember the scaled atomic positions."""
        self.spos_ac = spos_ac

    def calculate(self, gd, nt_sr, vt_sr):
        """Return the total XC energy.

        Sum of the stored EXX energies (``ecc + evv + evc``) and the value
        returned by the semi-local ``self.xc.calculate(gd, nt_sr, vt_sr)``
        (which presumably also fills *vt_sr* in place -- confirm).
        """
        energy = self.ecc + self.evv + self.evc
        energy += self.xc.calculate(gd, nt_sr, vt_sr)
        return energy

    def calculate_paw_correction(self, setup, D_sp, dH_sp=None, a=None):
        """Delegate the PAW correction to the semi-local functional."""
        return self.xc.calculate_paw_correction(setup, D_sp, dH_sp, a=a)

    def get_kinetic_energy_correction(self):
        """Return the accumulated EXX kinetic-energy correction ``ekin``."""
        return self.ekin

    def _ensure_coulomb(self):
        """Create the Coulomb interaction and Symmetry objects on first use.

        Also appends the Coulomb description to ``self.description``.
        Requires ``initialize()`` to have set ``self.wfs``.
        """
        if self.coulomb is None:
            wfs = self.wfs
            self.coulomb = coulomb_interaction(self.omega, wfs.gd, wfs.kd)
            self.description += f'\n{self.coulomb.description}'
            self.sym = Symmetry(wfs.kd)

    def apply_orbital_dependent_hamiltonian(self, kpt, psit_xG,
                                            Htpsit_xG=None, dH_asp=None):
        """Add the EXX contribution for one k-point to *Htpsit_xG* in place.

        While ``kpt.f_n is None`` an LDA exchange potential, scaled by the
        EXX fraction, is applied instead.  Afterwards:

        * if *psit_xG* shares its base buffer with ``kpt.psit.array``
          (presumably: it is the k-point's own wave functions), ``apply1``
          is run once per spin for all k-points; ``evc``/``evv``/``ekin``
          are accumulated and the per-k results are cached and consumed
          one k-point at a time;
        * otherwise ``apply2`` computes the result for *psit_xG* directly.

        In both cases the result is scaled by the EXX fraction.  *dH_asp*
        is accepted for interface compatibility but is not used.
        """
        wfs = self.wfs
        self._ensure_coulomb()

        # NOTE(review): recomputed on every call, even on paths that do not
        # use it; kept unconditional in case it must run on every rank.
        paw_s = calculate_paw_stuff(wfs, self.dens)

        if kpt.f_n is None:
            # No occupations yet: just use LDA_X for the first step.
            if self.vlda_sR is None:
                # First time: compute and keep it for later calls.
                self.vlda_sR = self.calculate_lda_potential()
            pd = kpt.psit.pd
            for psit_G, Htpsit_G in zip(psit_xG, Htpsit_xG):
                Htpsit_G += pd.fft(self.vlda_sR[kpt.s] *
                                   pd.ifft(psit_G, kpt.k), kpt.q)
        else:
            self.vlda_sR = None
            if kpt.psit.array.base is psit_xG.base:
                if (kpt.s, kpt.k) not in self.v_sknG:
                    # Every cached result for this spin must have been used.
                    assert not any(s == kpt.s for s, k in self.v_sknG)
                    evc, evv, ekin, v_knG = apply1(
                        kpt, Htpsit_xG,
                        wfs,
                        self.coulomb, self.sym,
                        paw_s[kpt.s])
                    # Energies are summed over spins; restart the sums when
                    # spin 0 is (re)computed.
                    if kpt.s == 0:
                        self.evc = 0.0
                        self.evv = 0.0
                        self.ekin = 0.0
                    # 2 / nspins: factor 2 for spin-paired calculations.
                    scale = 2 / wfs.nspins * self.exx_fraction
                    self.evc += evc * scale
                    self.evv += evv * scale
                    self.ekin += ekin * scale
                    # Replaces (does not update) the whole cache.
                    self.v_sknG = {(kpt.s, k): v_nG
                                   for k, v_nG in v_knG.items()}
                v_nG = self.v_sknG.pop((kpt.s, kpt.k))
            else:
                v_nG = apply2(kpt, psit_xG, Htpsit_xG, wfs,
                              self.coulomb, self.sym,
                              paw_s[kpt.s])
            Htpsit_xG += v_nG * self.exx_fraction

    def calculate_lda_potential(self):
        """Return the LDA exchange potential on the coarse grid, per spin.

        LDA_X is evaluated from ``dens.nt_sg`` on ``dens.finegd``, each
        spin component is restricted with ``dens.pd3.restrict`` onto an
        array from ``dens.gd``, and the result is scaled by the EXX
        fraction.
        """
        # XC is the module-level import from gpaw.xc.
        lda = XC('LDA_X')
        nt_sr = self.dens.nt_sg
        vt_sr = np.zeros_like(nt_sr)
        vlda_sR = self.dens.gd.zeros(self.wfs.nspins)
        lda.calculate(self.dens.finegd, nt_sr, vt_sr)
        for vt_R, vt_r in zip(vlda_sR, vt_sr):
            # restrict() returns a pair; only the first element is kept.
            vt_R[:], _ = self.dens.pd3.restrict(vt_r, self.dens.pd2)
        return vlda_sR * self.exx_fraction

    def summary(self, log):
        """Write the description through the *log* callable."""
        log(self.description)

    def add_forces(self, F_av):
        """Add the EXX forces, scaled by the EXX fraction, to *F_av*.

        *F_av* is modified in place.  The Coulomb interaction and symmetry
        objects are created first if no Hamiltonian application has done
        so yet, so ``calculate_forces`` never receives ``None`` for them.
        """
        self._ensure_coulomb()
        paw_s = calculate_paw_stuff(self.wfs, self.dens)
        F_av += calculate_forces(self.wfs,
                                 self.coulomb,
                                 self.sym,
                                 paw_s) * self.exx_fraction

    def correct_hamiltonian_matrix(self, kpt, H_nn):
        """No correction to the subspace Hamiltonian matrix is applied."""
        return

    def rotate(self, kpt, U_nn):
        """Deliberately does nothing.

        NOTE(review): the author left a disabled ``1 / 0`` guard here;
        confirm that no cached state needs rotating.
        """
        pass

    def add_correction(self, kpt, psit_xG, Htpsit_xG, P_axi, c_axi, n_x,
                       calculate_change=False):
        """Deliberately does nothing.

        NOTE(review): the author left a disabled ``1 / 0`` guard here;
        confirm that no correction is required.
        """
        pass

    def read(self, reader):
        """Nothing is restored from restart files."""
        pass

    def write(self, writer):
        """Nothing is written to restart files."""
        pass

    def set_grid_descriptor(self, gd):
        """The grid descriptor is not needed; ignored."""
        pass
164