1import pytest
2import numpy as np
3from ase import Atoms
4
5from gpaw import GPAW
6from gpaw.tddft import TDDFT, DipoleMomentWriter
7from gpaw.inducedfield.inducedfield_base import BaseInducedField
8from gpaw.inducedfield.inducedfield_tddft import TDDFTInducedField
9from gpaw.poisson import PoissonSolver
10from gpaw.test import equal
11
12
# Flip to True to regenerate the reference numbers: every equal() call in
# the test below then prints a ready-to-paste line instead of checking.
do_print_values = False

if do_print_values:
    # Running index used to label the printed reference values.
    i = 1

    def equal(x, y, tol):  # noqa
        """Print an ``equal(val<i>, <x>, tol)`` line and bump the counter.

        Stand-in for ``gpaw.test.equal`` while regenerating reference
        values; the expected value ``y`` and ``tol`` are not used.
        """
        global i
        line = "equal(val%d, %20.12f, tol)" % (i, x)
        print(line)
        i = i + 1
22
23
@pytest.mark.ci
def test_inducedfield_td(in_tmp_dir):
    """Check the TDDFT induced field of Na2 across a propagation restart.

    Steps:

    1. Compute and write a Na2 ground state.
    2. Start a real-time propagation with a weak absorption kick, with a
       ``TDDFTInducedField`` attached.
    3. Half-way through, write the TDDFT state and the induced-field
       object to disk.
    4. Reload both and propagate the remaining steps.
    5. Compute the induced field and write it including field data.
    6. Read it back with ``BaseInducedField`` and compare grid integrals
       against reference values.
    """
    poisson_eps = 1e-12
    density_eps = 1e-6
    poissonsolver = PoissonSolver('fd', eps=poisson_eps)

    # Non-periodic Na2 dimer, centred with vacuum on all sides
    atoms = Atoms(symbols='Na2',
                  positions=[(0, 0, 0), (3.0, 0, 0)],
                  pbc=False)
    atoms.center(vacuum=3.0)

    # Ground state; the file written here is the TDDFT starting point
    gs_calc = GPAW(nbands=2,
                   h=0.6,
                   setups={'Na': '1'},
                   poissonsolver=poissonsolver,
                   convergence={'density': density_eps})
    atoms.calc = gs_calc
    atoms.get_potential_energy()
    gs_calc.write('na2_gs.gpw', mode='all')

    time_step = 10.0
    n_steps = 20
    half = n_steps // 2

    # First half of the propagation, with the induced-field recorder
    # attached before the kick
    tdcalc = TDDFT('na2_gs.gpw')
    DipoleMomentWriter(tdcalc, 'na2_td_dm.dat')
    field = TDDFTInducedField(paw=tdcalc,
                              frequencies=[1.0, 2.08],  # in eV
                              folding='Gauss',
                              width=0.1,  # folding line width in eV
                              restart_file='na2_td.ind')
    tdcalc.absorption_kick(kick_strength=[1.0e-3, 1.0e-3, 0.0])
    tdcalc.propagate(time_step, half)

    # Checkpoint both objects, then detach the recorder from the propagator
    tdcalc.write('na2_td.gpw', mode='all')
    field.write('na2_td.ind')
    field.paw = None

    # Second half: restart both objects from the checkpoint files
    tdcalc = TDDFT('na2_td.gpw')
    DipoleMomentWriter(tdcalc, 'na2_td_dm.dat')
    field = TDDFTInducedField(filename='na2_td.ind',
                              paw=tdcalc,
                              restart_file='na2_td.ind')
    tdcalc.propagate(time_step, half)

    field.calculate_induced_field(gridrefinement=2,
                                  from_density='comp',
                                  poissonsolver=poissonsolver,
                                  extend_N_cd=3 * np.ones((3, 2), int),
                                  deextend=True)
    field.write('na2_td_field.ind', 'all')

    # Release the in-memory objects. The checks below use only what is
    # read back from file, which also exercises the field-data I/O.
    field.paw = None
    tdcalc = None
    field = None
    loaded = BaseInducedField(filename='na2_td_field.ind',
                              readmode='field')
    gd = loaded.fieldgd

    # Worst-case error accumulation estimate: number of steps times grid
    # volume times the larger of the two solver tolerances.
    # Expected value: tol = 0.038905993684
    volume = gd.integrate(gd.zeros() + 1.0)
    tol = n_steps * volume * max(density_eps, np.sqrt(poisson_eps))
    if do_print_values:
        print('tol = %.12f' % tol)

    # Integrate everything first. For each frequency index w, take
    # Ffe_wg[w] followed by |Fef_wvg[w][v]| for v = 0, 1, 2.
    computed = []
    for w in range(2):
        computed.append(gd.integrate(loaded.Ffe_wg[w]))
        for v in range(3):
            computed.append(gd.integrate(np.abs(loaded.Fef_wvg[w][v])))

    references = [1926.232999117403, 0.427606450419, 0.565823985683,
                  0.372493489423, 1945.618902611449, 0.423899965987,
                  0.560882533828, 0.369203021329]
    for value, ref in zip(computed, references):
        equal(value, ref, tol)
128