1from functools import wraps
2import numpy as np
3import pytest
4
5from ase.build import molecule
6
7from gpaw import GPAW
8from gpaw.mpi import world, serial_comm, broadcast_float, broadcast
9from gpaw.lcaotddft import LCAOTDDFT
10from gpaw.lcaotddft.dipolemomentwriter import DipoleMomentWriter
11from gpaw.lcaotddft.wfwriter import WaveFunctionWriter, WaveFunctionReader
12from gpaw.lcaotddft.densitymatrix import DensityMatrix
13from gpaw.lcaotddft.frequencydensitymatrix import FrequencyDensityMatrix
14from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
15from gpaw.tddft.folding import frequencies
16from gpaw.utilities import compiled_with_sl
17
# Every test in this module requests the ``module_tmp_path`` fixture;
# the relative file names used below ('gs.gpw', 'wf.ulm', ...) presumably
# resolve inside that directory — confirm against the conftest fixture.
pytestmark = [pytest.mark.usefixtures('module_tmp_path')]
19
20
def only_on_master(comm, broadcast=None):
    """Decorator that runs the wrapped function on rank 0 of `comm` only.

    All ranks meet at ``comm.barrier()`` after rank 0 has (possibly)
    executed the function, so no rank returns before rank 0 is done.

    Parameters
    ----------
    comm
        communicator providing ``rank`` and ``barrier()``
    broadcast
        callable ``broadcast(value, comm=comm)`` used to distribute the
        rank-0 return value to every rank, or `None` to skip
        broadcasting (non-master ranks then get `None`)
    """
    def decorator(func):
        @wraps(func)
        def on_master(*args, **kwargs):
            result = func(*args, **kwargs) if comm.rank == 0 else None
            comm.barrier()
            if broadcast is None:
                return result
            return broadcast(result, comm=comm)
        return on_master
    return decorator
45
46
def calculate_error(a, ref_a):
    """Return the max absolute deviation between `a` and `ref_a` on all ranks.

    Only rank 0 of ``world`` evaluates ``max(|a - ref_a|)`` and prints it.
    Other ranks contribute NaN, and the rank-0 value is then distributed
    with ``broadcast_float``, so every rank returns the same number.
    """
    err = np.nan
    if world.rank == 0:
        err = np.abs(a - ref_a).max()
        print()
        print('ERR', err)
    return broadcast_float(err, world)
56
57
def calculate_time_propagation(gs_fpath, kick,
                               communicator=world, parallel=None,
                               do_fdm=False):
    """Run a short LCAO real-time propagation restarted from `gs_fpath`.

    Parameters
    ----------
    gs_fpath
        path of the ground-state ``.gpw`` file to restart from
    kick
        kick vector passed to ``LCAOTDDFT.absorption_kick``
    communicator
        communicator for the time-propagation calculator
    parallel
        GPAW parallelization keywords; `None` (default) means ``{}``
    do_fdm
        if true, also attach a ``FrequencyDensityMatrix`` and write it
        to ``fdm.ulm`` after the propagation

    Side effects: writes ``td.out``, ``dm.dat`` and ``wf.ulm`` (plus
    ``fdm.ulm`` when `do_fdm`) relative to the current working directory.

    Returns
    -------
    The ``FrequencyDensityMatrix`` if `do_fdm` is true, otherwise `None`.
    """
    # Use a fresh dict per call; a mutable default argument would be a
    # single object shared by every call of this function.
    if parallel is None:
        parallel = {}
    td_calc = LCAOTDDFT(gs_fpath,
                        communicator=communicator,
                        parallel=parallel,
                        txt='td.out')
    fdm = None
    if do_fdm:
        dmat = DensityMatrix(td_calc)
        # Frequencies 0, 5, ..., 30 folded with a Gaussian of width 0.1
        # (units presumably eV; confirm in gpaw.tddft.folding.frequencies)
        ffreqs = frequencies(range(0, 31, 5), 'Gauss', 0.1)
        fdm = FrequencyDensityMatrix(td_calc, dmat, frequencies=ffreqs)
    # The writers are not stored; presumably they register themselves as
    # observers of td_calc — confirm in gpaw.lcaotddft.
    DipoleMomentWriter(td_calc, 'dm.dat')
    WaveFunctionWriter(td_calc, 'wf.ulm')
    td_calc.absorption_kick(kick)
    td_calc.propagate(20, 3)
    if do_fdm:
        fdm.write('fdm.ulm')

    # Synchronize all ranks before callers read the written files back.
    communicator.barrier()

    return fdm
80
81
def check_wfs(wf_ref_fpath, wf_fpath, atol=1e-12):
    """Assert that two wave-function files hold matching LCAO coefficients.

    Both files must contain the same number of entries. Entries from index
    1 onward are compared with `calculate_error` against `atol`; index 0 is
    skipped (presumably the pre-kick state — confirm with WaveFunctionWriter).
    """
    reader_ref = WaveFunctionReader(wf_ref_fpath)
    reader = WaveFunctionReader(wf_fpath)
    nentries = len(reader)
    assert nentries == len(reader_ref)
    for i in range(1, nentries):
        coeff_ref = reader_ref[i].wave_functions.coefficients
        coeff = reader[i].wave_functions.coefficients
        assert calculate_error(coeff, coeff_ref) < atol, f'error at i={i}'
91
92
# Parallelization settings over which the propagation and KSD tests below
# are parametrized; the empty dict is always tested, and ScaLAPACK
# variants are added only when GPAW was compiled with ScaLAPACK.
parallel_i = [{}]
if compiled_with_sl():
    if world.size == 1:
        # Choose BLACS grid manually as the one given by sl_auto
        # doesn't work well for the small test system and 1 process
        _extra_parallel = [{'sl_default': (1, 1, 8)}]
    else:
        _extra_parallel = [{'sl_auto': True},
                           {'sl_auto': True, 'band': 2}]
    parallel_i.extend(_extra_parallel)
103
104
@pytest.fixture(scope='module')
@only_on_master(world)
def initialize_system():
    """Produce the reference data for the whole module on rank 0.

    Runs serially (``serial_comm``) on rank 0 only; other ranks receive
    `None` because no broadcast is requested. Writes ``gs.out``,
    ``gs.gpw``, the outputs of `calculate_time_propagation` (with FDM),
    ``unocc.out`` and ``unocc.gpw`` relative to the working directory.

    Returns ``(unocc_calc, fdm)`` on rank 0: the fixed-density calculator
    with ``nbands='nao'`` and the frequency-domain density matrix.
    """
    serial = serial_comm

    # NaCl molecule in a box with 4 Å vacuum, LCAO ground state
    atoms = molecule('NaCl')
    atoms.center(vacuum=4.0)
    gs_calc = GPAW(nbands=6,
                   h=0.4,
                   setups=dict(Na='1'),
                   basis='dzp',
                   mode='lcao',
                   convergence={'density': 1e-8},
                   communicator=serial,
                   txt='gs.out')
    atoms.calc = gs_calc
    atoms.get_potential_energy()
    gs_calc.write('gs.gpw', mode='all')

    # Real-time propagation starting from the stored ground state
    kick_v = np.ones(3) * 1e-5
    fdm = calculate_time_propagation('gs.gpw',
                                     kick=kick_v,
                                     communicator=serial,
                                     do_fdm=True)

    # Same density, but with as many bands as basis functions ('nao')
    unocc_calc = gs_calc.fixed_density(nbands='nao',
                                       communicator=serial,
                                       txt='unocc.out')
    unocc_calc.write('unocc.gpw', mode='all')
    return unocc_calc, fdm
137
138
def test_propagated_wave_function(initialize_system, module_tmp_path):
    """Compare selected final-step LCAO coefficients to hard-coded values."""
    reader = WaveFunctionReader(module_tmp_path / 'wf.ulm')
    coeff_all = reader[-1].wave_functions.coefficients
    # Keep only a few coefficients of non-degenerate states: coefficients
    # of degenerate states are not unique (any rotation within the
    # degenerate subspace is valid), so they would first need a common
    # normalization before they could be compared.
    selection = np.ix_([0], [0], [0, 1, 4], [0, 1, 2])
    coeff = coeff_all[selection]
    # Fix the arbitrary overall sign using the real part of the first
    # coefficient along the last axis.
    sign = np.sign(coeff.real[..., 0, np.newaxis])
    coeff = sign * coeff
    ref = [[[[1.6564776755628504e-02 + 1.2158943340143986e-01j,
              4.7464497657284752e-03 + 3.4917799444496286e-02j,
              8.2152048273399657e-07 - 1.6344333784831069e-06j],
             [1.5177089239371724e-01 + 7.6502712023931621e-02j,
              8.0497556154952932e-01 + 4.0573839188792121e-01j,
              -5.1505952970811632e-06 - 1.1507918955641119e-05j],
             [2.5116252101774323e+00 + 3.6776360873471503e-01j,
              1.9024613198566329e-01 + 2.7843314959952882e-02j,
              -1.3848736953929574e-05 - 2.6402210145403184e-05j]]]]
    assert calculate_error(coeff, ref) < 2e-12
158
159
@pytest.mark.parametrize('parallel', parallel_i)
def test_propagation(initialize_system, module_tmp_path, parallel,
                     in_tmp_dir):
    """Repeat the propagation with `parallel` and compare to the serial run.

    Output goes to the per-test directory (``in_tmp_dir``); the reference
    ``wf.ulm`` is the one written into ``module_tmp_path``.
    """
    gs_fpath = module_tmp_path / 'gs.gpw'
    ref_wf_fpath = module_tmp_path / 'wf.ulm'
    kick_v = np.ones(3) * 1e-5
    calculate_time_propagation(gs_fpath, kick=kick_v, parallel=parallel)
    check_wfs(ref_wf_fpath, 'wf.ulm', atol=1e-12)
166
167
@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def dipole_moment_reference(initialize_system):
    """Fourier-transformed dipole-moment response from ``dm.dat``.

    Evaluated on rank 0 and broadcast to all ranks. The dipole moment at
    the first time point is subtracted before transforming at the folded
    frequencies of the reference FDM (``fdm.foldedfreqs_f[0]``).
    """
    from gpaw.tddft.spectrum import (read_dipole_moment_file,
                                     calculate_fourier_transform)

    _unocc_calc, fdm = initialize_system
    _first, time_t, _third, dm_tv = read_dipole_moment_file('dm.dat')
    ddm_tv = dm_tv - dm_tv[0]
    return calculate_fourier_transform(time_t, ddm_tv,
                                       fdm.foldedfreqs_f[0])
180
181
@pytest.fixture(scope='module')
@only_on_master(world)
def ksd_reference(initialize_system):
    """Serial Kohn-Sham decomposition of the reference unoccupied calc.

    Runs on rank 0 only (other ranks get `None`); returns ``(ksd, fdm)``.
    """
    calc, fdm = initialize_system
    decomposition = KohnShamDecomposition(calc)
    decomposition.initialize(calc)
    return decomposition, fdm
189
190
def ksd_transform_fdm(ksd, fdm):
    """Transform the FDM into the Kohn-Sham electron-hole basis.

    Returns a complex array of shape ``(2, len(fdm.freq_w), len(ksd.w_p))``:
    index 0 along the first axis comes from ``fdm.FReDrho_wuMM`` and index 1
    from ``fdm.FImDrho_wuMM``. For each frequency, only the first entry of
    ``ksd.transform(...)`` is kept. Entries are pre-filled with NaN.
    """
    nfreq = len(fdm.freq_w)
    shape = (2, nfreq, len(ksd.w_p))
    rho_iwp = np.full(shape, np.nan + 1j * np.nan, dtype=complex)
    sources = (fdm.FReDrho_wuMM, fdm.FImDrho_wuMM)
    for i, rho_wuMM in enumerate(sources):
        for w in range(nfreq):
            rho_iwp[i, w, :] = ksd.transform(rho_wuMM[w])[0]
    return rho_iwp
200
201
@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def ksd_transform_reference(ksd_reference):
    """Serial KS-basis transform of the reference FDM, broadcast to all."""
    return ksd_transform_fdm(*ksd_reference)
208
209
@pytest.fixture(scope='module', params=parallel_i)
def build_ksd(initialize_system, request):
    """Build a KSD from ``unocc.gpw`` with the given parallelization.

    The result is written to ``ksd.ulm``; the fixture itself yields `None`.
    """
    unocc_calc = GPAW('unocc.gpw', parallel=request.param, txt=None)
    decomposition = KohnShamDecomposition(unocc_calc)
    decomposition.initialize(unocc_calc)
    decomposition.write('ksd.ulm')
216
217
@pytest.fixture(scope='module', params=parallel_i)
def load_ksd(build_ksd, request):
    """Reload ``ksd.ulm`` and ``fdm.ulm`` under the given parallelization.

    Returns ``(ksd, fdm)``.
    """
    unocc_calc = GPAW('unocc.gpw', parallel=request.param, txt=None)
    # Initialize positions in order to calculate density
    unocc_calc.initialize_positions()
    decomposition = KohnShamDecomposition(unocc_calc, 'ksd.ulm')
    density_matrix = DensityMatrix(unocc_calc)
    freq_dmat = FrequencyDensityMatrix(unocc_calc, density_matrix,
                                       'fdm.ulm')
    return decomposition, freq_dmat
227
228
@pytest.fixture(scope='module')
def ksd_transform(load_ksd):
    """KS-basis transform of the reloaded FDM (see `ksd_transform_fdm`)."""
    return ksd_transform_fdm(*load_ksd)
234
235
def test_ksd_transform(ksd_transform, ksd_transform_reference):
    """Reloaded/parallel KS-basis transform must match the serial one."""
    tolerance = 1e-18
    err = calculate_error(ksd_transform, ksd_transform_reference)
    assert err < tolerance
242
243
def test_ksd_transform_real_only(load_ksd, ksd_transform_reference):
    """Transforming real and imaginary parts separately gives the same result.

    Each density matrix is split into its real and imaginary parts, both
    are transformed with ``broadcast=True``, and the results are recombined
    and compared against the reference transform.
    """
    ksd, fdm = load_ksd
    nfreq = len(fdm.freq_w)
    shape = (2, nfreq, len(ksd.w_p))
    rho_iwp = np.full(shape, np.nan + 1j * np.nan, dtype=complex)

    def transform_part(rho_MM):
        return ksd.transform([rho_MM], broadcast=True)[0]

    sources = (fdm.FReDrho_wuMM, fdm.FImDrho_wuMM)
    for i, rho_wuMM in enumerate(sources):
        for w in range(nfreq):
            rho_MM = rho_wuMM[w][0]
            rho_iwp[i, w, :] = (transform_part(rho_MM.real)
                                + 1j * transform_part(rho_MM.imag))

    tolerance = 1e-18
    err = calculate_error(rho_iwp, ksd_transform_reference)
    assert err < tolerance
258
259
def test_dipole_moment_from_ksd(ksd_transform, load_ksd,
                                dipole_moment_reference):
    """Dipole moments from KS-basis densities match the dm.dat reference.

    Only the ``FReDrho`` part (index 0 of `ksd_transform`) is used.
    """
    ksd, fdm = load_ksd
    nfreq = len(fdm.freq_w)
    rho_wp = ksd_transform[0]
    dm_wv = np.full((nfreq, 3), np.nan + 1j * np.nan, dtype=complex)
    for w in range(nfreq):
        dm_wv[w] = ksd.get_dipole_moment([rho_wp[w]])

    tolerance = 1e-7
    err = calculate_error(dm_wv, dipole_moment_reference)
    assert err < tolerance
274
275
def get_density_fdm(ksd, fdm, kind):
    """Return the complex grid density of ``fdm.FReDrho_wuMM`` per frequency.

    Parameters
    ----------
    ksd
        Kohn-Sham decomposition; used only when ``kind == 'ksd'``
    fdm
        frequency-domain density matrix; only ``FReDrho_wuMM`` and only the
        first entry along its ``u`` axis (``rho_uMM[0]``) are used
    kind
        ``'dmat'`` to build the density directly from the LCAO density
        matrix via ``fdm.dmat.get_density``, or ``'ksd'`` to transform into
        the KS basis first and use ``ksd.get_density``

    The real and imaginary parts are converted separately and recombined
    as ``re + 1j * im``.

    Returns
    -------
    Array from ``fdm.dmat.density.finegd.empty(len(fdm.freq_w),
    dtype=complex)`` with one density per frequency.

    Raises
    ------
    ValueError
        if `kind` is not ``'dmat'`` or ``'ksd'``
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # after which an unknown kind would fail later with UnboundLocalError.
    if kind not in ('dmat', 'ksd'):
        raise ValueError(f"kind must be 'dmat' or 'ksd', got {kind!r}")
    nfreq = len(fdm.freq_w)
    rho_wg = fdm.dmat.density.finegd.empty(nfreq, dtype=complex)
    rho_wg[:] = np.nan + 1j * np.nan
    for w in range(nfreq):
        rho_uMM = fdm.FReDrho_wuMM[w]
        if kind == 'dmat':
            rho_g = (fdm.dmat.get_density([rho_uMM[0].real])
                     + 1j * fdm.dmat.get_density([rho_uMM[0].imag]))
        else:
            rho_up = ksd.transform(rho_uMM, broadcast=True)
            rho_g = (ksd.get_density(fdm.dmat.wfs, [rho_up[0].real])
                     + 1j * ksd.get_density(fdm.dmat.wfs,
                                            [rho_up[0].imag]))
        rho_wg[w, :] = rho_g
    return rho_wg
291
292
@pytest.fixture(scope='module')
@only_on_master(world, broadcast=broadcast)
def density_reference(ksd_reference):
    """Serial 'dmat' and 'ksd' densities, computed on rank 0 and broadcast.

    Returns a dict with keys ``'dmat'`` and ``'ksd'``.
    """
    ksd, fdm = ksd_reference
    return {kind: get_density_fdm(ksd, fdm, kind)
            for kind in ('dmat', 'ksd')}
300
301
def test_ksd_vs_dmat_density(density_reference):
    """The KS-basis route must reproduce the direct density-matrix density."""
    tolerance = 2e-10
    err = calculate_error(density_reference['ksd'],
                          density_reference['dmat'])
    assert err < tolerance
308
309
@pytest.fixture(scope='module')
def density(load_ksd):
    """'dmat' and 'ksd' densities of the reloaded FDM, keyed by kind.

    Marks the requesting test as xfail when the KSD uses BLACS
    (ScaLAPACK), which this density path does not support.
    """
    ksd, fdm = load_ksd
    if ksd.ksl.using_blacs:
        pytest.xfail('Scalapack is not supported')
    return {kind: get_density_fdm(ksd, fdm, kind)
            for kind in ('dmat', 'ksd')}
318
319
@pytest.mark.parametrize('kind', ['ksd', 'dmat'])
def test_density(kind, density, load_ksd, density_reference):
    """Reloaded/parallel density of `kind` matches the serial reference.

    The grid array is passed through ``finegd.collect`` first (presumably
    gathering a domain-distributed array onto rank 0 — confirm in GPAW).
    """
    _ksd, fdm = load_ksd
    finegd = fdm.dmat.density.finegd
    collected_wg = finegd.collect(density[kind])
    tolerance = 3e-19
    err = calculate_error(collected_wg, density_reference[kind])
    assert err < tolerance
328
329
@pytest.mark.parametrize('kind', ['ksd', 'dmat'])
def test_dipole_moment_from_density(kind, density, load_ksd,
                                    dipole_moment_reference):
    """Dipole moments integrated from grid densities match dm.dat.

    Uses ``ksd.density.finegd.calculate_dipole_moment`` per frequency.
    """
    ksd, fdm = load_ksd
    nfreq = len(fdm.freq_w)
    rho_wg = density[kind]
    finegd = ksd.density.finegd
    dm_wv = np.full((nfreq, 3), np.nan + 1j * np.nan, dtype=complex)
    for w in range(nfreq):
        dm_wv[w] = finegd.calculate_dipole_moment(rho_wg[w])

    tolerance = 5e-7
    err = calculate_error(dm_wv, dipole_moment_reference)
    assert err < tolerance
345