1import pytest
2
3from ase.build import molecule
4
5from gpaw import GPAW
6from gpaw.tddft import TDDFT, DipoleMomentWriter
7from gpaw.mpi import world, serial_comm
8from gpaw.utilities import compiled_with_sl
9
10from ..lcaotddft.test_molecule import only_on_master
11
12
# Apply the 'module_tmp_path' fixture to every test in this module.
# NOTE(review): the fixture is defined outside this file; presumably it
# provides a module-wide temporary directory where the shared fixtures
# below write their files -- confirm in conftest.
pytestmark = pytest.mark.usefixtures('module_tmp_path')
14
15
def calculate_time_propagation(gpw_fpath, *,
                               iterations=3,
                               kick=(1e-5, 1e-5, 1e-5),
                               propagator='SICN',
                               communicator=world,
                               write_and_continue=False,
                               force_new_dm_file=False,
                               parallel=None):
    """Run a short TDDFT time propagation and record the dipole moment.

    Parameters
    ----------
    gpw_fpath
        Path of the gpw file the TDDFT calculation is started from.
    iterations
        Number of propagation steps per call to ``propagate``.
    kick
        Strength of the absorption kick (three components), or ``None``
        to skip the kick (e.g. when restarting).
    propagator
        Name of the propagator passed to ``TDDFT``.
    communicator
        Communicator passed to ``TDDFT``; a barrier is called on it
        before returning.
    write_and_continue
        If true, write ``td.gpw`` after the first propagation, switch the
        dipole moment output to ``dm2.dat`` and propagate once more.
    force_new_dm_file
        Passed as ``force_new_file`` to the ``dm.dat`` writer.
    parallel
        Parallelization options passed to ``TDDFT``; ``None`` means an
        empty dictionary.
    """
    # Use a fresh dictionary per call instead of a shared mutable default
    if parallel is None:
        parallel = {}

    td_calc = TDDFT(gpw_fpath,
                    propagator=propagator,
                    communicator=communicator,
                    parallel=parallel,
                    txt='td.out')
    DipoleMomentWriter(td_calc, 'dm.dat',
                       force_new_file=force_new_dm_file)
    if kick is not None:
        # Hand over a new list so the (immutable) default is never shared
        td_calc.absorption_kick(list(kick))
    td_calc.propagate(20, iterations)
    if write_and_continue:
        td_calc.write('td.gpw', mode='all')
        # Switch dipole moment writer and output
        # NOTE(review): pop() drops the most recently attached observer,
        # presumably the 'dm.dat' writer created above -- confirm.
        td_calc.observers.pop()
        dm = DipoleMomentWriter(td_calc, 'dm2.dat', force_new_file=True)
        # NOTE(review): presumably records the current state as the first
        # entry of 'dm2.dat' before continuing -- confirm.
        dm._update(td_calc)
        td_calc.propagate(20, iterations)
    communicator.barrier()
42
43
def check_dm(ref_fpath, fpath, rtol=1e-8, atol=1e-12):
    """Assert that the dipole moment file agrees with a reference file.

    The time columns are compared with zero absolute tolerance and the
    dipole moments with the given relative (``rtol``) and absolute
    (``atol``) tolerances.
    """
    from gpaw.tddft.spectrum import read_dipole_moment_file

    # Synchronize all ranks before any file is read
    world.barrier()

    _, ref_times, _, ref_dipoles = read_dipole_moment_file(ref_fpath)
    _, times, _, dipoles = read_dipole_moment_file(fpath)

    expected_times = pytest.approx(ref_times, abs=0)
    assert times == expected_times

    expected_dipoles = pytest.approx(ref_dipoles, rel=rtol, abs=atol)
    assert dipoles == expected_dipoles
52
53
# Parallelization options used to parametrize the tests below:
# always the default; band parallelization only with more than one rank;
# ScaLAPACK variants only when GPAW is compiled with ScaLAPACK
parallel_i = [{}]
parallel_i += ([{'band': 2}] if world.size > 1 else [])
if compiled_with_sl():
    parallel_i += [{'sl_auto': True}]
    parallel_i += ([{'sl_auto': True, 'band': 2}] if world.size > 1 else [])
62
63
@pytest.fixture(scope='module')
@only_on_master(world)
def ground_state():
    """Compute the SiH4 ground state and store it in 'gs.gpw'.

    The calculation uses the serial communicator and writes its text
    output to 'gs.out'.
    """
    sih4 = molecule('SiH4')
    sih4.center(vacuum=4.0)

    gs_parameters = dict(nbands=6, h=0.4,
                         convergence={'density': 1e-8},
                         communicator=serial_comm,
                         xc='LDA',
                         txt='gs.out')
    gs_calc = GPAW(**gs_parameters)
    sih4.calc = gs_calc

    # Trigger the self-consistent calculation before writing the file
    sih4.get_potential_energy()
    gs_calc.write('gs.gpw', mode='all')
78
79
@pytest.fixture(scope='module')
@only_on_master(world)
def time_propagation_reference(ground_state):
    """Run the serial reference propagation starting from 'gs.gpw'.

    With ``write_and_continue`` enabled the helper also writes 'td.gpw'
    and continues the propagation into 'dm2.dat'.
    """
    reference_options = dict(communicator=serial_comm,
                             write_and_continue=True)
    calculate_time_propagation('gs.gpw', **reference_options)
86
87
def test_dipole_moment_values(time_propagation_reference,
                              module_tmp_path, in_tmp_dir):
    """Compare the reference propagation against stored dipole moments.

    The stored values are written to local 'dm.dat' and 'dm2.dat' files
    and then compared with the files of the same name in
    ``module_tmp_path``.
    """
    stored_dm = '''
# DipoleMomentWriter[version=1](center=False, density='comp')
#            time            norm                    dmx                    dmy                    dmz
          0.00000000       6.92701356e-16    -3.798602757097e-08    -3.850923113536e-10    -2.506988148420e-10
# Kick = [    1.000000000000e-05,     1.000000000000e-05,     1.000000000000e-05]; Time = 0.00000000
          0.00000000       6.78612525e-16    -3.806745480191e-08    -5.880500044945e-10    -4.683685533214e-10
          0.82682747      -3.11611967e-16     6.011432043389e-05     6.015251317290e-05     6.015179500177e-05
          1.65365493       1.71405522e-15     1.075009677567e-04     1.075385921602e-04     1.075337414463e-04
          2.48048240       1.55070479e-15     1.388363880650e-04     1.388662733804e-04     1.388701173331e-04
'''.strip()  # noqa: E501

    stored_dm2 = '''
# DipoleMomentWriter[version=1](center=False, density='comp')
#            time            norm                    dmx                    dmy                    dmz
          2.48048240       1.55070479e-15     1.388363880650e-04     1.388662733804e-04     1.388701173331e-04
          3.30730987      -1.85697397e-16     1.528174640313e-04     1.528352677280e-04     1.528428543052e-04
          4.13413733      -6.23799730e-17     1.497979692345e-04     1.498097215567e-04     1.498150038226e-04
          4.96096480       1.44537040e-15     1.324352983945e-04     1.324326482531e-04     1.324473352117e-04
'''.strip()  # noqa: E501

    stored_files = (('dm.dat', stored_dm), ('dm2.dat', stored_dm2))

    # Write both stored files first, then run the comparisons
    for fname, content in stored_files:
        with open(fname, 'w') as fd:
            fd.write(content)

    tolerances = dict(rtol=5e-4, atol=1e-8)
    for fname, _ in stored_files:
        check_dm(fname, module_tmp_path / fname, **tolerances)
116
117
@pytest.mark.parametrize('parallel', parallel_i)
@pytest.mark.parametrize('propagator', [
    'SICN', 'ECN', 'ETRSCN', 'SIKE'])
def test_propagation(time_propagation_reference,
                     parallel, propagator,
                     module_tmp_path, in_tmp_dir):
    """Propagate with each propagator and parallelization option.

    The resulting 'dm.dat' is compared with the reference 'dm.dat' in
    ``module_tmp_path`` using tolerances that depend on the propagator
    and on whether band parallelization is used.
    """
    calculate_time_propagation(module_tmp_path / 'gs.gpw',
                               propagator=propagator,
                               parallel=parallel)

    band_parallel = 'band' in parallel

    # XXX band parallelization is inaccurate!
    atol = 3e-8 if band_parallel else 1e-12

    if propagator != 'SICN':
        # Other propagators match qualitatively
        rtol = 5e-2
    elif band_parallel:
        # XXX band parallelization is inaccurate!
        rtol = 7e-4
    else:
        # This is the same propagator as the reference;
        # error comes only from parallelization
        rtol = 1e-8

    check_dm(module_tmp_path / 'dm.dat', 'dm.dat', rtol=rtol, atol=atol)
143
144
@pytest.mark.parametrize('parallel', parallel_i)
def test_restart(time_propagation_reference,
                 parallel,
                 module_tmp_path, in_tmp_dir):
    """Restart from 'td.gpw' without a kick and continue propagating.

    The new 'dm.dat' is compared with the reference 'dm2.dat' in
    ``module_tmp_path``; the relative tolerance is loosened when band
    parallelization is used.
    """
    restart_options = dict(kick=None,
                           force_new_dm_file=True,
                           parallel=parallel)
    calculate_time_propagation(module_tmp_path / 'td.gpw',
                               **restart_options)

    rtol = 5e-4 if 'band' in parallel else 1e-8
    check_dm(module_tmp_path / 'dm2.dat', 'dm.dat', rtol=rtol)
157