from functools import wraps
import numpy as np
import pytest

from ase.build import molecule

from gpaw import GPAW
from gpaw.mpi import world, serial_comm, broadcast_float, broadcast
from gpaw.lcaotddft import LCAOTDDFT
from gpaw.lcaotddft.dipolemomentwriter import DipoleMomentWriter
from gpaw.lcaotddft.wfwriter import WaveFunctionWriter, WaveFunctionReader
from gpaw.lcaotddft.densitymatrix import DensityMatrix
from gpaw.lcaotddft.frequencydensitymatrix import FrequencyDensityMatrix
from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
from gpaw.tddft.folding import frequencies
from gpaw.utilities import compiled_with_sl

# NOTE(review): files are written/read with relative paths ('gs.gpw',
# 'wf.ulm', ...) and later accessed as module_tmp_path / '<name>', so this
# fixture presumably makes module_tmp_path the working directory — confirm.
pytestmark = pytest.mark.usefixtures('module_tmp_path')


def only_on_master(comm, broadcast=None):
    """Decorator for executing the function only on the rank 0.

    All ranks of ``comm`` meet at a barrier after the call, so the
    decorated function behaves collectively even though only rank 0
    does the work.

    Parameters
    ----------
    comm
        communicator
    broadcast
        function for broadcasting the return value or
        `None` for no broadcasting; when `None`, ranks other than 0
        receive `None` as the return value
    """
    def wrap(func):
        @wraps(func)
        def wrapped_func(*args, **kwargs):
            if comm.rank == 0:
                ret = func(*args, **kwargs)
            else:
                ret = None
            comm.barrier()
            if broadcast is not None:
                ret = broadcast(ret, comm=comm)
            return ret
        return wrapped_func
    return wrap


def calculate_error(a, ref_a):
    """Return the maximum absolute deviation ``max|a - ref_a|``.

    The error is computed (and printed) on rank 0 of ``world`` only and then
    broadcast to all ranks, so every rank sees the same value. The inputs
    passed on the other ranks are not used.
    """
    if world.rank == 0:
        err = np.abs(a - ref_a).max()
        print()
        print('ERR', err)
    else:
        err = np.nan
    err = broadcast_float(err, world)
    return err


def calculate_time_propagation(gs_fpath, kick,
                               communicator=world, parallel=None,
                               do_fdm=False):
    """Run a short LCAO-TDDFT propagation after an absorption kick.

    Writes 'dm.dat' (dipole moment) and 'wf.ulm' (wave functions) to the
    current directory, and additionally 'fdm.ulm' when ``do_fdm`` is true.

    Parameters
    ----------
    gs_fpath
        path to the ground-state file used to start the propagation
    kick
        kick strength passed to ``absorption_kick``
    communicator
        communicator of the calculation; a barrier is executed on it
        before returning
    parallel
        GPAW parallelization options; ``None`` (default) means an empty
        dict. A fresh dict is created on every call so that no state is
        shared between calls through a mutable default.
    do_fdm
        if true, also accumulate and write a frequency-domain density matrix

    Returns
    -------
    The ``FrequencyDensityMatrix`` if ``do_fdm`` is true, otherwise `None`.
    """
    if parallel is None:
        parallel = {}
    td_calc = LCAOTDDFT(gs_fpath,
                        communicator=communicator,
                        parallel=parallel,
                        txt='td.out')
    fdm = None
    if do_fdm:
        dmat = DensityMatrix(td_calc)
        # NOTE(review): frequencies 0, 5, ..., 30 with Gaussian folding of
        # width 0.1 — units presumably eV; confirm in gpaw.tddft.folding.
        ffreqs = frequencies(range(0, 31, 5), 'Gauss', 0.1)
        fdm = FrequencyDensityMatrix(td_calc, dmat, frequencies=ffreqs)
    # Observers are attached before the kick so the whole run is recorded
    DipoleMomentWriter(td_calc, 'dm.dat')
    WaveFunctionWriter(td_calc, 'wf.ulm')
    td_calc.absorption_kick(kick)
    td_calc.propagate(20, 3)
    if do_fdm:
        fdm.write('fdm.ulm')

    communicator.barrier()

    return fdm


def check_wfs(wf_ref_fpath, wf_fpath, atol=1e-12):
    """Assert that two wave-function files agree entry by entry.

    Both files must contain the same number of entries; the coefficient
    arrays of each entry must agree to within ``atol`` (max abs error).
    """
    wfr_ref = WaveFunctionReader(wf_ref_fpath)
    wfr = WaveFunctionReader(wf_fpath)
    assert len(wfr) == len(wfr_ref)
    # NOTE(review): entry 0 is not compared; the reason is not visible
    # here — confirm against WaveFunctionWriter's output layout.
    for i in range(1, len(wfr)):
        ref = wfr_ref[i].wave_functions.coefficients
        coeff = wfr[i].wave_functions.coefficients
        err = calculate_error(coeff, ref)
        assert err < atol, f'error at i={i}'


# Generate different parallelization options
parallel_i = [{}]
if compiled_with_sl():
    if world.size == 1:
        # Choose BLACS grid manually as the one given by sl_auto
        # doesn't work well for the small test system and 1 process
        parallel_i.append({'sl_default': (1, 1, 8)})
    else:
        parallel_i.append({'sl_auto': True})
        parallel_i.append({'sl_auto': True, 'band': 2})


@pytest.fixture(scope='module')
@only_on_master(world)
def initialize_system():
    """Run the serial reference calculations on rank 0.

    Produces 'gs.gpw' (NaCl ground state), the reference propagation files
    ('dm.dat', 'wf.ulm', 'fdm.ulm') and 'unocc.gpw' (fixed-density
    calculation with ``nbands='nao'``).

    Returns ``(unocc_calc, fdm)`` on rank 0 and `None` on all other ranks
    (no broadcast is requested).
    """
    comm = serial_comm

    # Ground-state calculation
    atoms = molecule('NaCl')
    atoms.center(vacuum=4.0)
    calc = GPAW(nbands=6,
                h=0.4,
                setups=dict(Na='1'),
                basis='dzp',
                mode='lcao',
                convergence={'density': 1e-8},
                communicator=comm,
                txt='gs.out')
    atoms.calc = calc
    atoms.get_potential_energy()
    calc.write('gs.gpw', mode='all')

    # Time-propagation calculation
    fdm = calculate_time_propagation('gs.gpw',
                                     kick=np.ones(3) * 1e-5,
                                     communicator=comm,
                                     do_fdm=True)

    # Calculate ground state with full unoccupied space
    unocc_calc = calc.fixed_density(nbands='nao',
                                    communicator=comm,
                                    txt='unocc.out')
    unocc_calc.write('unocc.gpw', mode='all')
    return unocc_calc, fdm


def test_propagated_wave_function(initialize_system, module_tmp_path):
    """Compare selected final-step coefficients against hard-coded values."""
    wfr = WaveFunctionReader(module_tmp_path / 'wf.ulm')
    coeff = wfr[-1].wave_functions.coefficients
    # Pick a few coefficients corresponding to non-degenerate states;
    # degenerate states should be normalized so that they can be compared
    coeff = coeff[np.ix_([0], [0], [0, 1, 4], [0, 1, 2])]
    # Normalize the wave function sign: flip each selected state so that
    # the real part of its first selected coefficient is positive
    coeff = np.sign(coeff.real[..., 0, np.newaxis]) * coeff
    ref = [[[[1.6564776755628504e-02 + 1.2158943340143986e-01j,
              4.7464497657284752e-03 + 3.4917799444496286e-02j,
              8.2152048273399657e-07 - 1.6344333784831069e-06j],
             [1.5177089239371724e-01 + 7.6502712023931621e-02j,
              8.0497556154952932e-01 + 4.0573839188792121e-01j,
              -5.1505952970811632e-06 - 1.1507918955641119e-05j],
             [2.5116252101774323e+00 + 3.6776360873471503e-01j,
              1.9024613198566329e-01 + 2.7843314959952882e-02j,
              -1.3848736953929574e-05 - 2.6402210145403184e-05j]]]]
    err = calculate_error(coeff, ref)
    assert err < 2e-12


@pytest.mark.parametrize('parallel', parallel_i)
def test_propagation(initialize_system, module_tmp_path, parallel, in_tmp_dir):
    """Re-run the propagation in parallel and compare with the serial run."""
    calculate_time_propagation(module_tmp_path / 'gs.gpw',
                               kick=np.ones(3) * 1e-5,
                               parallel=parallel)
    check_wfs(module_tmp_path / 'wf.ulm', 'wf.ulm', atol=1e-12)


@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def dipole_moment_reference(initialize_system):
    """Fourier-transformed induced dipole moment from the reference run.

    Computed on rank 0 and broadcast to all ranks. The static dipole
    (first time step) is subtracted before the transform, which uses the
    folding parameters of the reference ``fdm``.
    """
    from gpaw.tddft.spectrum import \
        read_dipole_moment_file, calculate_fourier_transform

    unocc_calc, fdm = initialize_system
    _, time_t, _, dm_tv = read_dipole_moment_file('dm.dat')
    dm_tv = dm_tv - dm_tv[0]
    dm_wv = calculate_fourier_transform(time_t, dm_tv,
                                        fdm.foldedfreqs_f[0])
    return dm_wv


@pytest.fixture(scope='module')
@only_on_master(world)
def ksd_reference(initialize_system):
    """Serial KohnShamDecomposition of the reference unocc calculation.

    Returns ``(ksd, fdm)`` on rank 0 and `None` on all other ranks.
    """
    unocc_calc, fdm = initialize_system
    ksd = KohnShamDecomposition(unocc_calc)
    ksd.initialize(unocc_calc)
    return ksd, fdm


def ksd_transform_fdm(ksd, fdm):
    """Transform the frequency-domain density matrices with ``ksd``.

    Returns a complex array of shape
    ``(2, len(fdm.freq_w), len(ksd.w_p))``. Index 0 along the first axis
    holds the transform of ``fdm.FReDrho_wuMM`` and index 1 that of
    ``fdm.FImDrho_wuMM``. Only the ``u = 0`` entry of each transform is
    kept.
    """
    rho_iwp = np.empty((2, len(fdm.freq_w), len(ksd.w_p)), dtype=complex)
    # NaN-fill so any slot left unwritten makes the comparison fail loudly
    rho_iwp[:] = np.nan + 1j * np.nan
    for i, rho_wuMM in enumerate([fdm.FReDrho_wuMM, fdm.FImDrho_wuMM]):
        for w in range(len(fdm.freq_w)):
            rho_uMM = rho_wuMM[w]
            rho_up = ksd.transform(rho_uMM)
            rho_iwp[i, w, :] = rho_up[0]
    return rho_iwp


@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def ksd_transform_reference(ksd_reference):
    """Serial reference of ``ksd_transform_fdm``, broadcast to all ranks."""
    ksd, fdm = ksd_reference
    ref_rho_iwp = ksd_transform_fdm(ksd, fdm)
    return ref_rho_iwp


@pytest.fixture(scope='module', params=parallel_i)
def build_ksd(initialize_system, request):
    """Build a KohnShamDecomposition with the given parallelization.

    Writes it to 'ksd.ulm'.
    """
    calc = GPAW('unocc.gpw', parallel=request.param, txt=None)
    ksd = KohnShamDecomposition(calc)
    ksd.initialize(calc)
    ksd.write('ksd.ulm')


@pytest.fixture(scope='module', params=parallel_i)
def load_ksd(build_ksd, request):
    """Read 'ksd.ulm' and 'fdm.ulm' back with the given parallelization.

    Returns ``(ksd, fdm)``.
    """
    calc = GPAW('unocc.gpw', parallel=request.param, txt=None)
    # Initialize positions in order to calculate density
    calc.initialize_positions()
    ksd = KohnShamDecomposition(calc, 'ksd.ulm')
    dmat = DensityMatrix(calc)
    fdm = FrequencyDensityMatrix(calc, dmat, 'fdm.ulm')
    return ksd, fdm


@pytest.fixture(scope='module')
def ksd_transform(load_ksd):
    """``ksd_transform_fdm`` evaluated on the reloaded (parallel) data."""
    ksd, fdm = load_ksd
    rho_iwp = ksd_transform_fdm(ksd, fdm)
    return rho_iwp


def test_ksd_transform(ksd_transform, ksd_transform_reference):
    """Parallel KSD transform must match the serial reference."""
    ref_iwp = ksd_transform_reference
    rho_iwp = ksd_transform
    err = calculate_error(rho_iwp, ref_iwp)
    atol = 1e-18
    assert err < atol


def test_ksd_transform_real_only(load_ksd, ksd_transform_reference):
    """Transforming real and imaginary parts separately gives the same result.

    The separate real and imaginary transforms (with ``broadcast=True``)
    are recombined and must reproduce the full complex transform.
    """
    ksd, fdm = load_ksd
    ref_iwp = ksd_transform_reference
    rho_iwp = np.empty((2, len(fdm.freq_w), len(ksd.w_p)), dtype=complex)
    rho_iwp[:] = np.nan + 1j * np.nan
    for i, rho_wuMM in enumerate([fdm.FReDrho_wuMM, fdm.FImDrho_wuMM]):
        for w in range(len(fdm.freq_w)):
            rho_uMM = rho_wuMM[w]
            rho_p = ksd.transform([rho_uMM[0].real], broadcast=True)[0] \
                + 1j * ksd.transform([rho_uMM[0].imag], broadcast=True)[0]
            rho_iwp[i, w, :] = rho_p
    err = calculate_error(rho_iwp, ref_iwp)
    atol = 1e-18
    assert err < atol


def test_dipole_moment_from_ksd(ksd_transform, load_ksd,
                                dipole_moment_reference):
    """Dipole moment from the KSD representation matches the time-domain one."""
    ksd, fdm = load_ksd
    dm_wv = np.empty((len(fdm.freq_w), 3), dtype=complex)
    dm_wv[:] = np.nan + 1j * np.nan
    # Only the FReDrho_wuMM part (index 0) is used here
    rho_wp = ksd_transform[0]
    for w in range(len(fdm.freq_w)):
        dm_v = ksd.get_dipole_moment([rho_wp[w]])
        dm_wv[w, :] = dm_v

    ref_wv = dipole_moment_reference
    err = calculate_error(dm_wv, ref_wv)
    atol = 1e-7
    assert err < atol


def get_density_fdm(ksd, fdm, kind):
    """Compute the complex density on the fine grid for each frequency.

    Parameters
    ----------
    ksd
        KohnShamDecomposition (used only when ``kind == 'ksd'``)
    fdm
        FrequencyDensityMatrix; only ``fdm.FReDrho_wuMM`` is used
    kind
        'dmat' to evaluate the density directly from the density matrix,
        'ksd' to evaluate it via the KSD representation

    Returns
    -------
    Array from ``fdm.dmat.density.finegd.empty(len(fdm.freq_w))`` holding
    one density per frequency.
    """
    assert kind in ['dmat', 'ksd']
    rho_wg = fdm.dmat.density.finegd.empty(len(fdm.freq_w), dtype=complex)
    rho_wg[:] = np.nan + 1j * np.nan
    for w in range(len(fdm.freq_w)):
        rho_uMM = fdm.FReDrho_wuMM[w]
        # Real and imaginary parts are evaluated separately and recombined
        if kind == 'dmat':
            rho_g = fdm.dmat.get_density([rho_uMM[0].real]) \
                + 1j * fdm.dmat.get_density([rho_uMM[0].imag])
        elif kind == 'ksd':
            rho_up = ksd.transform(rho_uMM, broadcast=True)
            rho_g = ksd.get_density(fdm.dmat.wfs, [rho_up[0].real]) \
                + 1j * ksd.get_density(fdm.dmat.wfs, [rho_up[0].imag])
        rho_wg[w, :] = rho_g
    return rho_wg


@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def density_reference(ksd_reference):
    """Serial 'dmat' and 'ksd' densities, broadcast to all ranks."""
    ksd, fdm = ksd_reference
    dmat_rho_wg = get_density_fdm(ksd, fdm, 'dmat')
    ksd_rho_wg = get_density_fdm(ksd, fdm, 'ksd')
    return dict(dmat=dmat_rho_wg, ksd=ksd_rho_wg)


def test_ksd_vs_dmat_density(density_reference):
    """The two density routes agree with each other (serial reference)."""
    ref_wg = density_reference['dmat']
    rho_wg = density_reference['ksd']
    err = calculate_error(rho_wg, ref_wg)
    atol = 2e-10
    assert err < atol


@pytest.fixture(scope='module')
def density(load_ksd):
    """'dmat' and 'ksd' densities computed with the reloaded data.

    Expected to fail (xfail) when the KSD uses BLACS/ScaLAPACK.
    """
    ksd, fdm = load_ksd
    if ksd.ksl.using_blacs:
        pytest.xfail('Scalapack is not supported')
    dmat_rho_wg = get_density_fdm(ksd, fdm, 'dmat')
    ksd_rho_wg = get_density_fdm(ksd, fdm, 'ksd')
    return dict(dmat=dmat_rho_wg, ksd=ksd_rho_wg)


@pytest.mark.parametrize('kind', ['ksd', 'dmat'])
def test_density(kind, density, load_ksd, density_reference):
    """Collected density matches the serial reference of the same kind."""
    ksd, fdm = load_ksd
    ref_wg = density_reference[kind]
    # Gather the domain-distributed grid so it can be compared with the
    # serial reference on rank 0
    rho_wg = fdm.dmat.density.finegd.collect(density[kind])
    err = calculate_error(rho_wg, ref_wg)
    atol = 3e-19
    assert err < atol


@pytest.mark.parametrize('kind', ['ksd', 'dmat'])
def test_dipole_moment_from_density(kind, density, load_ksd,
                                    dipole_moment_reference):
    """Dipole moment integrated from the density matches the reference."""
    ksd, fdm = load_ksd
    rho_wg = density[kind]
    dm_wv = np.empty((len(fdm.freq_w), 3), dtype=complex)
    dm_wv[:] = np.nan + 1j * np.nan
    for w in range(len(fdm.freq_w)):
        dm_v = ksd.density.finegd.calculate_dipole_moment(rho_wg[w])
        dm_wv[w, :] = dm_v

    ref_wv = dipole_moment_reference
    err = calculate_error(dm_wv, ref_wv)
    atol = 5e-7
    assert err < atol