import pytest

from ase.build import molecule

from gpaw import GPAW
from gpaw.tddft import TDDFT, DipoleMomentWriter
from gpaw.mpi import world, serial_comm
from gpaw.utilities import compiled_with_sl

from ..lcaotddft.test_molecule import only_on_master


pytestmark = pytest.mark.usefixtures('module_tmp_path')


def calculate_time_propagation(gpw_fpath, *,
                               iterations=3,
                               kick=(1e-5, 1e-5, 1e-5),
                               propagator='SICN',
                               communicator=world,
                               write_and_continue=False,
                               force_new_dm_file=False,
                               parallel=None):
    """Run a short TDDFT time propagation starting from a gpw file.

    The dipole moment is written to ``dm.dat`` and the calculator log to
    ``td.out`` (both relative to the current working directory).

    Parameters
    ----------
    gpw_fpath
        Path of the gpw file to start from (ground state or restart file).
    iterations
        Number of steps per propagation segment; each segment calls
        ``td_calc.propagate(20, iterations)``.
    kick
        Absorption kick strength passed to ``absorption_kick``; ``None``
        skips the kick (used when restarting an already kicked run).
    propagator
        Name of the TDDFT propagator.
    communicator
        MPI communicator used by the calculation.
    write_and_continue
        If true, after the first segment write ``td.gpw`` (mode 'all'),
        redirect the dipole moment output to ``dm2.dat`` and propagate
        another ``iterations`` steps.
    force_new_dm_file
        Forwarded as ``force_new_file`` to the writer of ``dm.dat``.
    parallel
        Parallelization keywords passed to ``TDDFT``; ``None`` means no
        extra keywords (``{}``).
    """
    # A fresh dict per call: a mutable default would be shared between
    # every call of this function.
    if parallel is None:
        parallel = {}
    td_calc = TDDFT(gpw_fpath,
                    propagator=propagator,
                    communicator=communicator,
                    parallel=parallel,
                    txt='td.out')
    DipoleMomentWriter(td_calc, 'dm.dat',
                       force_new_file=force_new_dm_file)
    if kick is not None:
        td_calc.absorption_kick(kick)
    td_calc.propagate(20, iterations)
    if write_and_continue:
        td_calc.write('td.gpw', mode='all')
        # Switch dipole moment writer and output.
        # NOTE(review): assumes the writer attached above is the last
        # entry in td_calc.observers.
        td_calc.observers.pop()
        dm = DipoleMomentWriter(td_calc, 'dm2.dat', force_new_file=True)
        # Write the current state as the first line of dm2.dat so it
        # overlaps with the last line of dm.dat.
        dm._update(td_calc)
        td_calc.propagate(20, iterations)
    communicator.barrier()


def check_dm(ref_fpath, fpath, rtol=1e-8, atol=1e-12):
    """Assert that two dipole moment files agree.

    The time columns must match exactly. The dipole moments must agree
    within ``rtol``/``atol`` in the sense of :func:`pytest.approx`.

    Parameters
    ----------
    ref_fpath
        Reference dipole moment file.
    fpath
        Dipole moment file to check.
    rtol, atol
        Relative and absolute tolerances for the dipole moments.
    """
    from gpaw.tddft.spectrum import read_dipole_moment_file

    # Synchronize all ranks before any of them reads the files.
    world.barrier()
    _, time_ref_t, _, dm_ref_tv = read_dipole_moment_file(ref_fpath)
    _, time_t, _, dm_tv = read_dipole_moment_file(fpath)
    assert time_t == pytest.approx(time_ref_t, abs=0)
    assert dm_tv == pytest.approx(dm_ref_tv, rel=rtol, abs=atol)


# Generate different parallelization options
parallel_i = [{}]
if world.size > 1:
    parallel_i.append({'band': 2})
if compiled_with_sl():
    parallel_i.append({'sl_auto': True})
    if world.size > 1:
        parallel_i.append({'sl_auto': True, 'band': 2})


@pytest.fixture(scope='module')
@only_on_master(world)
def ground_state():
    """Compute the SiH4 LDA ground state serially and write ``gs.gpw``.

    Wrapped in ``only_on_master(world)``; presumably only the master rank
    runs the calculation (confirm against the decorator).
    """
    atoms = molecule('SiH4')
    atoms.center(vacuum=4.0)

    calc = GPAW(nbands=6, h=0.4,
                convergence={'density': 1e-8},
                communicator=serial_comm,
                xc='LDA',
                txt='gs.out')
    atoms.calc = calc
    atoms.get_potential_energy()
    calc.write('gs.gpw', mode='all')


@pytest.fixture(scope='module')
@only_on_master(world)
def time_propagation_reference(ground_state):
    """Run the serial SICN reference propagation.

    It produces ``dm.dat``, ``td.gpw`` and ``dm2.dat``, which the tests
    below read from ``module_tmp_path``.
    """
    calculate_time_propagation('gs.gpw',
                               communicator=serial_comm,
                               write_and_continue=True)


def test_dipole_moment_values(time_propagation_reference,
                              module_tmp_path, in_tmp_dir):
    """Compare the reference run's dipole moments to hard-coded values."""
    with open('dm.dat', 'w') as fd:
        fd.write('''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
0.00000000 6.92701356e-16 -3.798602757097e-08 -3.850923113536e-10 -2.506988148420e-10
# Kick = [ 1.000000000000e-05, 1.000000000000e-05, 1.000000000000e-05]; Time = 0.00000000
0.00000000 6.78612525e-16 -3.806745480191e-08 -5.880500044945e-10 -4.683685533214e-10
0.82682747 -3.11611967e-16 6.011432043389e-05 6.015251317290e-05 6.015179500177e-05
1.65365493 1.71405522e-15 1.075009677567e-04 1.075385921602e-04 1.075337414463e-04
2.48048240 1.55070479e-15 1.388363880650e-04 1.388662733804e-04 1.388701173331e-04
'''.strip())  # noqa: E501

    with open('dm2.dat', 'w') as fd:
        fd.write('''
# DipoleMomentWriter[version=1](center=False, density='comp')
# time norm dmx dmy dmz
2.48048240 1.55070479e-15 1.388363880650e-04 1.388662733804e-04 1.388701173331e-04
3.30730987 -1.85697397e-16 1.528174640313e-04 1.528352677280e-04 1.528428543052e-04
4.13413733 -6.23799730e-17 1.497979692345e-04 1.498097215567e-04 1.498150038226e-04
4.96096480 1.44537040e-15 1.324352983945e-04 1.324326482531e-04 1.324473352117e-04
'''.strip())  # noqa: E501

    rtol = 5e-4
    atol = 1e-8
    check_dm('dm.dat', module_tmp_path / 'dm.dat', rtol=rtol, atol=atol)
    check_dm('dm2.dat', module_tmp_path / 'dm2.dat', rtol=rtol, atol=atol)


@pytest.mark.parametrize('parallel', parallel_i)
@pytest.mark.parametrize('propagator', [
    'SICN', 'ECN', 'ETRSCN', 'SIKE'])
def test_propagation(time_propagation_reference,
                     parallel, propagator,
                     module_tmp_path, in_tmp_dir):
    """Propagate from ``gs.gpw`` and compare against the SICN reference."""
    calculate_time_propagation(module_tmp_path / 'gs.gpw',
                               propagator=propagator,
                               parallel=parallel)
    atol = 1e-12
    if propagator == 'SICN':
        # This is the same propagator as the reference;
        # error comes only from parallelization
        rtol = 1e-8
        if 'band' in parallel:
            # XXX band parallelization is inaccurate!
            rtol = 7e-4
            atol = 3e-8
    else:
        # Other propagators match qualitatively
        rtol = 5e-2
        if 'band' in parallel:
            # XXX band parallelization is inaccurate!
            atol = 3e-8
    check_dm(module_tmp_path / 'dm.dat', 'dm.dat', rtol=rtol, atol=atol)


@pytest.mark.parametrize('parallel', parallel_i)
def test_restart(time_propagation_reference,
                 parallel,
                 module_tmp_path, in_tmp_dir):
    """Restart from ``td.gpw`` and compare with the reference ``dm2.dat``."""
    calculate_time_propagation(module_tmp_path / 'td.gpw',
                               kick=None,
                               force_new_dm_file=True,
                               parallel=parallel)
    rtol = 1e-8
    if 'band' in parallel:
        rtol = 5e-4
    check_dm(module_tmp_path / 'dm2.dat', 'dm.dat', rtol=rtol)