"""Interpolate PS quantities from a GPAW calculation to a fine grid and
add PAW corrections in order to obtain all-electron (AE) quantities."""
from math import pi, sqrt
from warnings import warn
from typing import Optional, List, Dict

import numpy as np
from ase.units import Bohr, Ha

from gpaw import GPAW
from gpaw.atom.shapefunc import shape_functions
from gpaw.fftw import get_efficient_fft_size
from gpaw.grid_descriptor import GridDescriptor
from gpaw.lfc import LocalizedFunctionsCollection as LFC
from gpaw.utilities import h2gpts
from gpaw.wavefunctions.pw import PWDescriptor
from gpaw.mpi import serial_comm
from gpaw.setup import Setup
from gpaw.spline import Spline
from gpaw.typing import Array3D


class Interpolator:
    """Interpolate arrays from grid ``gd1`` to grid ``gd2``.

    Both grids are wrapped in plane-wave descriptors and the work is
    delegated to ``PWDescriptor.interpolate``.
    """
    def __init__(self, gd1, gd2, dtype=float):
        self.pd1 = PWDescriptor(0.0, gd1, dtype)
        self.pd2 = PWDescriptor(0.0, gd2, dtype)

    def interpolate(self, a_r):
        """Return ``a_r`` (given on ``gd1``) interpolated to ``gd2``."""
        return self.pd1.interpolate(a_r, self.pd2)[0]


# Number of points used when turning radial functions into splines.
POINTS = 200


def _fine_grid(cell_cv, grid_spacing: float, n: int) -> GridDescriptor:
    """Create a serial grid descriptor for the unit cell ``cell_cv``.

    grid_spacing: float
        Desired grid-spacing in Angstrom.
    n: int
        Force number of points along each axis to be a multiple of n
        (chosen with ``get_efficient_fft_size``).
    """
    N_c = h2gpts(grid_spacing / Bohr, cell_cv)
    N_c = np.array([get_efficient_fft_size(N, n) for N in N_c])
    return GridDescriptor(N_c, cell_cv, comm=serial_comm)


class PS2AE:
    """Transform PS to AE wave functions.

    Interpolates PS wave functions to a fine grid and adds PAW
    corrections in order to obtain true AE wave functions.
    """
    def __init__(self,
                 calc: GPAW,
                 grid_spacing: float = 0.05,
                 n: int = 2,
                 h: Optional[float] = None  # deprecated
                 ):
        """Create transformation object.

        calc: GPAW calculator object
            The calculator that has the wave functions.
        grid_spacing: float
            Desired grid-spacing in Angstrom.
        n: int
            Force number of points to be a multiple of n.
        h: float
            Deprecated alias for grid_spacing (overrides it if given).
        """
        if h is not None:
            # stacklevel=2 makes the warning point at the caller's line.
            warn('Please use grid_spacing=... instead of h=...',
                 stacklevel=2)
            grid_spacing = h

        self.calc = calc
        gd = calc.wfs.gd

        # Serial copy of the wave-function grid:
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)

        # Descriptor for the final grid:
        gd2 = self.gd = _fine_grid(gd.cell_cv, grid_spacing, n)
        self.interpolator = Interpolator(gd1, gd2, self.calc.wfs.dtype)

        self._dphi: Optional[LFC] = None  # PAW correction (built lazily)

        # Volume element of the fine grid in Ang**3:
        self.dv = self.gd.dv * Bohr**3

    @property
    def dphi(self) -> LFC:
        """PAW correction functions phi_j - phit_j on the fine grid.

        Built on first access and cached.  Splines are shared between
        atoms that use the same setup object.
        """
        if self._dphi is not None:
            return self._dphi

        splines: Dict[Setup, List[Spline]] = {}
        dphi_aj = []
        for setup in self.calc.wfs.setups:
            dphi_j = splines.get(setup)
            if dphi_j is None:
                # Cut off 10% beyond the largest radius in setup.rcut_j:
                rcut = max(setup.rcut_j) * 1.1
                gcut = setup.rgd.ceil(rcut)
                dphi_j = []
                for l, phi_g, phit_g in zip(setup.l_j,
                                            setup.data.phi_jg,
                                            setup.data.phit_jg):
                    dphi_g = (phi_g - phit_g)[:gcut]
                    dphi_j.append(setup.rgd.spline(dphi_g, rcut, l,
                                                   points=POINTS))
                splines[setup] = dphi_j
            dphi_aj.append(dphi_j)

        self._dphi = LFC(self.gd, dphi_aj, kd=self.calc.wfs.kd.copy(),
                         dtype=self.calc.wfs.dtype)
        self._dphi.set_positions(self.calc.spos_ac)

        return self._dphi

    def get_wave_function(self,
                          n: int,
                          k: int = 0,
                          s: int = 0,
                          ae: bool = True,
                          periodic: bool = False) -> Array3D:
        """Interpolate wave function.

        Returns 3-d array in units of Ang**-1.5.

        n: int
            Band index.
        k: int
            K-point index.
        s: int
            Spin index.
        ae: bool
            Add PAW correction to get an all-electron wave function.
        periodic: bool
            Return periodic part of wave-function, u(r), instead of
            psi(r)=exp(ikr)u(r).
        """
        u_r = self.calc.get_pseudo_wave_function(n, k, s,
                                                 pad=True, periodic=True)
        # Work in Bohr**-1.5; converted back to Ang**-1.5 on return.
        u_R = self.interpolator.interpolate(u_r * Bohr**1.5)

        k_c = self.calc.wfs.kd.ibzk_kc[k]
        gamma = np.isclose(k_c, 0.0).all()

        if gamma:
            eikr_R = 1.0
        else:
            eikr_R = self.gd.plane_wave(k_c)

        if ae:
            dphi = self.dphi
            wfs = self.calc.wfs
            P_nI = wfs.collect_projections(k, s)

            if wfs.world.rank == 0:
                # Add the correction to the full psi(r), then strip the
                # phase factor again:
                psi_R = u_R * eikr_R
                P_ai = {}
                I1 = 0
                for a, setup in enumerate(wfs.setups):
                    I2 = I1 + setup.ni
                    P_ai[a] = P_nI[n, I1:I2]
                    I1 = I2
                dphi.add(psi_R, P_ai, k)
                u_R = psi_R / eikr_R

            # Only rank 0 applied the correction; share its result.
            wfs.world.broadcast(u_R, 0)

        if periodic:
            return u_R * Bohr**-1.5
        else:
            return u_R * eikr_R * Bohr**-1.5

    def get_pseudo_density(self,
                           add_compensation_charges: bool = True) -> Array3D:
        """Interpolate pseudo density.

        add_compensation_charges: bool
            Add compensation charges (ghat functions weighted by the
            multipole moments Q_aL) to the interpolated density.

        Returns 3-d array divided by Bohr**3 (i.e. in units of Ang**-3).
        """
        dens = self.calc.density
        gd1 = dens.gd
        # The density grid must not be domain-decomposed:
        assert gd1.comm.size == 1
        interpolator = Interpolator(gd1, self.gd)
        # Sum over the spin channels:
        dens_r = dens.nt_sG[:dens.nspins].sum(axis=0)
        dens_R = interpolator.interpolate(dens_r)

        if add_compensation_charges:
            dens.calculate_multipole_moments()
            ghat = LFC(self.gd, [setup.ghat_l for setup in dens.setups],
                       integral=sqrt(4 * pi))
            ghat.set_positions(self.calc.spos_ac)
            Q_aL = {}
            for a, Q_L in dens.Q_aL.items():
                # Copy so that the density object's moments are untouched.
                Q_aL[a] = Q_L.copy()
                Q_aL[a][0] += dens.setups[a].Nv / (4 * pi)**0.5
            ghat.add(dens_R, Q_aL)

        return dens_R / Bohr**3

    def get_electrostatic_potential(self,
                                    ae: bool = True,
                                    rcgauss: float = 0.02) -> Array3D:
        """Interpolate electrostatic potential.

        Return value in eV.

        ae: bool
            Add PAW correction to get the all-electron potential.
        rcgauss: float
            Width of gaussian (in Angstrom) used to represent the nuclear
            charge.
        """
        gd = self.calc.hamiltonian.finegd
        # Convert eV -> Hartree for the interpolation and correction:
        v_r = self.calc.get_electrostatic_potential() / Ha
        gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
        interpolator = Interpolator(gd1, self.gd)
        v_R = interpolator.interpolate(v_r)

        if ae:
            self.add_potential_correction(v_R, rcgauss / Bohr)

        return v_R * Ha

    def add_potential_correction(self,
                                 v_R: Array3D,
                                 rcgauss: float) -> None:
        """Add the PAW correction to the potential ``v_R`` in place.

        v_R: 3-d array
            Potential on the fine grid in Hartree (as passed by
            get_electrostatic_potential).  Modified in place and then
            broadcast from rank 0 of the density grid communicator.
        rcgauss: float
            Width of gaussian (in Bohr) used to represent the nuclear
            charge.
        """
        dens = self.calc.density
        # Move the atomic data to the serial atom partition:
        dens.D_asp.redistribute(dens.atom_partition.as_serial())
        dens.Q_aL.redistribute(dens.atom_partition.as_serial())

        dv_a1 = []
        try:
            for a, D_sp in dens.D_asp.items():
                setup = dens.setups[a]
                c = setup.xc_correction
                rgd = c.rgd
                # l=0 compensation-charge shape function of this setup:
                params = setup.data.shape_function.copy()
                params['lmax'] = 0
                ghat_g = shape_functions(rgd, **params)[0]
                # Gaussian nuclear charge:
                Z_g = (shape_functions(rgd, 'gauss', rcgauss, lmax=0)[0] *
                       setup.Z)
                D_q = np.dot(D_sp.sum(0), c.B_pqL[:, :, 0])
                dn_g = np.dot(D_q, (c.n_qg - c.nt_qg)) * sqrt(4 * pi)
                dn_g += 4 * pi * (c.nc_g - c.nct_g)
                dn_g -= Z_g
                dn_g -= dens.Q_aL[a][0] * ghat_g * sqrt(4 * pi)
                dv_g = rgd.poisson(dn_g) / sqrt(4 * pi)
                # rgd.poisson appears to return r*v(r); divide by r,
                # avoiding r=0, and force the potential to zero at the
                # outermost grid point.
                dv_g[1:] /= rgd.r_g[1:]
                dv_g[0] = dv_g[1]
                dv_g[-1] = 0.0
                dv_a1.append([rgd.spline(dv_g, points=POINTS)])
        finally:
            # Restore the original distribution even if the loop fails.
            dens.D_asp.redistribute(dens.atom_partition)
            dens.Q_aL.redistribute(dens.atom_partition)

        if dv_a1:
            dv = LFC(self.gd, dv_a1)
            dv.set_positions(self.calc.spos_ac)
            dv.add(v_R)
        dens.gd.comm.broadcast(v_R, 0)


def interpolate_weight(calc, weight, h=0.05, n=2):
    """Interpolate a CDFT weight function to a fine grid.

    calc: GPAW calculator object
    weight: ndarray
        Weight function on ``calc.density.finegd`` (it is collected on
        all ranks before interpolation).
    h: float
        Desired grid-spacing in Angstrom.
    n: int
        Force number of points to be a multiple of n.

    Returns the interpolated array on a serial fine grid.
    """
    gd = calc.density.finegd

    weight = gd.collect(weight, broadcast=True)
    weight = gd.zero_pad(weight)

    w = np.zeros_like(weight)
    gd1 = GridDescriptor(gd.N_c, gd.cell_cv, comm=serial_comm)
    gd1.distribute(weight, w)

    gd2 = _fine_grid(gd.cell_cv, h, n)

    interpolator = Interpolator(gd1, gd2)
    return interpolator.interpolate(w)