1"""Tests for phonons"""
2import os
3import numpy as np
4import abipy.data as abidata
5import abipy.core.abinit_units as abu
6
7from abipy import abilab
8from abipy.core.testing import AbipyTest
9from abipy.dfpt.ddb import DdbFile, DielectricTensorGenerator
10from abipy.dfpt.anaddbnc import AnaddbNcFile
11from abipy.dfpt.phonons import PhononBands
12
13
# Fixture directory: "test_files", located two levels above the directory of this module.
test_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "test_files")
15
16
class DdbTest(AbipyTest):
    """Tests for DdbFile: header/block parsing, block editing and the anaddb interface."""

    def test_alas_ddb_1qpt_phonons(self):
        """Testing DDB with one q-point"""
        with DdbFile(os.path.join(test_dir, "AlAs_1qpt_DDB")) as ddb:
            repr(ddb); str(ddb)
            # Test qpoints.
            assert len(ddb.qpoints) == 1
            assert np.all(ddb.qpoints[0] == [0.25, 0, 0])
            assert ddb.natom == len(ddb.structure)
            # Round-trip: serialize the DDB to a string and parse it back.
            s = ddb.get_string()
            with DdbFile.from_string(s) as same_ddb:
                assert same_ddb.qpoints[0] == ddb.qpoints[0]
                assert same_ddb.structure == ddb.structure

            # Test header
            h = ddb.header
            assert h.version == 100401 and h.ecut == 3
            assert "ecut" in h and h["ecut"] == h.ecut
            assert "ixc" in ddb.params
            assert ddb.params["ixc"] == 7
            assert h.occ == 4 * [2]
            assert h.xred.shape == (h.natom, 3) and h.kpt.shape == (h.nkpt, 3)
            self.assert_equal(h.znucl, [13, 33])
            assert ddb.version == 100401
            # This DDB stores no ground-state energy, forces or stress.
            assert ddb.total_energy is None
            assert ddb.cart_forces is None
            assert ddb.cart_stress_tensor is None

            assert np.all(h.symrel[1].T.ravel() == [0, -1, 1, 0, -1, 0, 1, -1, 0])
            assert np.all(h.symrel[2].T.ravel() == [-1, 0, 0, -1, 0, 1, -1, 1, 0])

            # Test structure
            struct = ddb.structure
            assert struct.formula == "Al1 As1"

            # Test interface with Anaddb.
            str(ddb.qpoints[0])
            assert ddb.qindex(ddb.qpoints[0]) == 0

            phbands = ddb.anaget_phmodes_at_qpoint(qpoint=ddb.qpoints[0], verbose=1)
            assert phbands is not None and hasattr(phbands, "phfreqs")
            phbands = ddb.anaget_phmodes_at_qpoint(qpoint=ddb.qpoints[0], lo_to_splitting=False, verbose=1)

            # Wrong qpoint: Gamma is not stored in this DDB.
            with self.assertRaises(ValueError):
                ddb.anaget_phmodes_at_qpoint(qpoint=(0, 0, 0), verbose=1)

            with self.assertRaises(ValueError):
                ddb.anaget_phmodes_at_qpoints(qpoints=[[0.1, 0.2, 0.3]], ifcflag=0)

            # Wrong ngqpt
            with self.assertRaises(ddb.AnaddbError):
                try:
                    ddb.anaget_phbst_and_phdos_files(ngqpt=(4, 4, 4), verbose=1)
                except Exception as exc:
                    # Exercise AnaddbError.__str__ before re-raising for assertRaises.
                    str(exc)
                    raise

            # Cannot compute DOS since we need a mesh.
            with self.assertRaises(ddb.AnaddbError):
                ddb.anaget_phbst_and_phdos_files(verbose=1)

            # Test notebook
            if self.has_nbformat():
                assert ddb.write_notebook(nbpath=self.get_tmpname(text=True))

            # Test block parsing.
            blocks = ddb._read_blocks()
            assert len(blocks) == 1
            assert blocks[0]["qpt"] == [0.25, 0, 0]
            assert blocks[0]["dord"] == 2
            assert blocks[0]["qpt3"] is None

            lines = blocks[0]["data"]
            assert lines[0].rstrip() == " 2nd derivatives (non-stat.)  - # elements :      36"
            assert lines[2].rstrip() == "   1   1   1   1  0.80977066582497D+01 -0.46347282336361D-16"
            assert lines[-1].rstrip() == "   3   2   3   2  0.49482344898401D+01 -0.44885664256253D-17"

            for qpt in ddb.qpoints:
                assert ddb.get_block_for_qpoint(qpt)
                assert ddb.get_block_for_qpoint(qpt.frac_coords)

            assert ddb.replace_block_for_qpoint(ddb.qpoints[0], blocks[0]["data"])
            d_2ord = ddb.get_2nd_ord_dict()
            assert ddb.qpoints[0] in d_2ord
            new_qpt = [0.11, 0.22, 3.4]
            new_block = {"data":
                             [' 2nd derivatives (non-stat.)  - # elements :      1',
                              ' qpt  1.10000000E-01  2.20000000E-01  3.40000000E+00   1.0',
                              '   1   1   1   1  0.38964081001769D+01  0.51387831420710D-24'],
                         "dord": 2, "qpt": new_qpt, "qpt3": None}
            assert ddb.insert_block(new_block)
            assert ddb.insert_block(new_block, replace=True)
            # Once the block is present, inserting it with replace=False must be refused.
            assert not ddb.insert_block(new_block, replace=False)
            # Write new DDB file.
            tmp_file = self.get_tmpname(text=True)
            ddb.write(tmp_file)
            with DdbFile(tmp_file) as new_ddb:
                # check that the new qpoint has been written to the file
                assert new_qpt in new_ddb.qpoints

            # remove the added block and write again to check that it is equivalent to the original
            assert ddb.remove_block(dord=2, qpt=new_qpt)

            ddb.write(tmp_file)
            with DdbFile(tmp_file) as new_ddb:
                assert ddb.qpoints == new_ddb.qpoints
                assert DdbFile.as_ddb(new_ddb) is new_ddb
                # Call anaddb to check if we can read new DDB
                phbands = new_ddb.anaget_phmodes_at_qpoint(qpoint=new_ddb.qpoints[0], verbose=1)
                assert phbands is not None and hasattr(phbands, "phfreqs")

    def test_alas_ddb_444_nobecs(self):
        """Testing DDB for AlAs on a 4x4x4 q-mesh without Born effective charges."""
        # The context manager closes the DDB even if an assertion below fails.
        with DdbFile(os.path.join(test_dir, "AlAs_444_nobecs_DDB")) as ddb:
            repr(ddb); str(ddb)
            assert str(ddb.header)
            assert ddb.to_string(verbose=2)
            assert ddb.header["nkpt"] == 256
            assert ddb.header.nsym == 24 and ddb.header.ntypat == 2
            self.assert_equal(ddb.header.znucl, [13, 33])
            self.assert_equal(ddb.header.acell, [1, 1, 1])
            self.assert_equal(ddb.header.ngfft, [10, 10, 10])
            self.assert_equal(ddb.header.spinat, 0.0)
            # TODO: also check the shape of ddb.header.occ.

            assert not ddb.has_qpoint([0.345, 0.456, 0.567])
            assert ddb.has_qpoint([0, 0, 0])
            for qpoint in ddb.qpoints:
                assert ddb.has_qpoint(qpoint)
                assert ddb.has_qpoint(qpoint.frac_coords)
                assert qpoint in ddb.computed_dynmat
                assert len(ddb.computed_dynmat[qpoint].index[0]) == 4

            assert not ddb.has_bec_terms(select="at_least_one")
            assert not ddb.has_bec_terms(select="all")
            assert not ddb.has_epsinf_terms()
            assert not ddb.has_lo_to_data()
            assert not ddb.has_internalstrain_terms()
            assert not ddb.has_piezoelectric_terms()
            assert not ddb.has_strain_terms()
            assert ddb.has_at_least_one_atomic_perturbation()

            # The 8 q-points expected in the DDB, in file order.
            ref_qpoints = np.reshape([
                     0.00000000E+00,  0.00000000E+00,  0.00000000E+00,
                     2.50000000E-01,  0.00000000E+00,  0.00000000E+00,
                     5.00000000E-01,  0.00000000E+00,  0.00000000E+00,
                     2.50000000E-01,  2.50000000E-01,  0.00000000E+00,
                     5.00000000E-01,  2.50000000E-01,  0.00000000E+00,
                    -2.50000000E-01,  2.50000000E-01,  0.00000000E+00,
                     5.00000000E-01,  5.00000000E-01,  0.00000000E+00,
                    -2.50000000E-01,  5.00000000E-01,  2.50000000E-01,
            ], (-1, 3))

            assert len(ddb.qpoints) == 8
            for qpt, ref_qpt in zip(ddb.qpoints, ref_qpoints):
                assert qpt == ref_qpt

            for qpoint in ddb.qpoints:
                phbands = ddb.anaget_phmodes_at_qpoint(qpoint=qpoint, verbose=1)
                assert phbands is not None and hasattr(phbands, "phfreqs")

            phbands = ddb.anaget_phmodes_at_qpoints(verbose=1)
            assert phbands is not None and hasattr(phbands, "phfreqs")

            phbands = ddb.anaget_phmodes_at_qpoints(qpoints=[[0.1, 0.2, 0.3]], verbose=1, ifcflag=1)
            assert phbands is not None and hasattr(phbands, "phfreqs")

            assert ddb.anaget_interpolated_ddb(qpt_list=[[0.1, 0.2, 0.3]])

            assert np.all(ddb.guessed_ngqpt == [4, 4, 4])

            # Get bands and Dos. The files are closed in `finally` so a failing
            # assertion does not leak them.
            phbands_file, phdos_file = ddb.anaget_phbst_and_phdos_files(verbose=1)
            try:
                phbands, phdos = phbands_file.phbands, phdos_file.phdos
                # The integrated phonon DOS must give 3 * natom modes.
                self.assert_almost_equal(phdos.idos.values[-1], 3 * len(ddb.structure), decimal=1)

                if self.has_matplotlib():
                    assert phbands.plot_with_phdos(phdos, show=False,
                        title="Phonon bands and DOS of %s" % phbands.structure.formula)
                    assert phbands_file.plot_phbands(show=False)
            finally:
                phbands_file.close()
                phdos_file.close()

            assert ddb.view_phononwebsite(verbose=1, dryrun=True) == 0

            if self.has_panel():
                assert hasattr(ddb.get_panel(), "show")

            # Get epsinf and becs
            r = ddb.anaget_epsinf_and_becs(chneut=1, verbose=1)
            epsinf, becs = r.epsinf, r.becs
            # This DDB has no BEC terms, so all effective charges are expected to be zero.
            assert np.all(becs.values == 0)
            repr(becs); str(becs)
            assert becs.to_string(verbose=2)

            same_becs = self.decode_with_MSON(becs)
            self.assert_almost_equal(same_becs.values, becs.values)

            max_err = becs.check_site_symmetries(verbose=2)
            assert max_err == 0

            # Test DOS computation via anaddb.
            c = ddb.anacompare_phdos(nqsmalls=[2, 3, 4], dipdip=0, num_cpus=1, verbose=2)
            assert c.phdoses and c.plotter is not None
            if self.has_matplotlib():
                assert c.plotter.combiplot(show=False)

            # Use threads and gaussian DOS.
            c = ddb.anacompare_phdos(nqsmalls=[2, 3, 4], dos_method="gaussian", dipdip=0, asr=0,
                    num_cpus=2, verbose=2)
            assert c.phdoses and c.plotter is not None

            # Execute anaddb to compute the interatomic force constants.
            ifc = ddb.anaget_ifc()
            str(ifc); repr(ifc)
            assert ifc.to_string(verbose=2)
            assert ifc.structure == ddb.structure
            assert ifc.number_of_atoms == len(ddb.structure)

            if self.has_matplotlib():
                assert ifc.plot_longitudinal_ifc(show=False)
                assert ifc.plot_longitudinal_ifc_short_range(show=False)
                assert ifc.plot_longitudinal_ifc_ewald(show=False)

            # Test get_coarse.
            with ddb.get_coarse([2, 2, 2]) as coarse_ddb:
                # Check whether anaddb can read the coarse DDB.
                with coarse_ddb.anaget_phbst_and_phdos_files(nqsmall=4, ndivsm=1, verbose=1) as g:
                    coarse_phbands_file, coarse_phdos_file = g
                    assert coarse_phbands_file.filepath == g.files[0].filepath
                    assert coarse_phdos_file.filepath == g.files[1].filepath

    def test_zno_gamma_ddb_with_becs(self):
        """Testing DDB for ZnO: Gamma only, with Born effective charges and E_macro."""
        with DdbFile(os.path.join(test_dir, "ZnO_gamma_becs_DDB")) as ddb:
            repr(ddb); str(ddb)
            assert str(ddb.header)
            assert ddb.to_string(verbose=2)
            assert ddb.header["nkpt"] == 486
            assert ddb.header.nband == 22 and ddb.header.occopt == 1
            self.assert_equal(ddb.header.typat, [1, 1, 2, 2])
            assert len(ddb.header.wtk) == ddb.header.nkpt
            # TODO: also check the shape of ddb.header.occ.

            assert not ddb.has_qpoint([0.345, 0.456, 0.567])
            assert ddb.has_qpoint([0, 0, 0])
            assert len(ddb.qpoints) == 1
            for qpoint in ddb.qpoints:
                assert ddb.has_qpoint(qpoint)
                assert ddb.has_qpoint(qpoint.frac_coords)
                assert qpoint in ddb.computed_dynmat
                assert len(ddb.computed_dynmat[qpoint].index[0]) == 4

            # Each query is repeated to exercise the lru_cache path as well.
            assert ddb.has_bec_terms(select="at_least_one")
            assert ddb.has_bec_terms(select="at_least_one")
            assert ddb.has_bec_terms(select="all")
            assert ddb.has_bec_terms(select="all")
            assert ddb.has_epsinf_terms()
            assert ddb.has_lo_to_data()

            # Get epsinf and becs (attribute access, as in test_alas_ddb_444_nobecs).
            r = ddb.anaget_epsinf_and_becs(chneut=1, verbose=1)
            epsinf, becs = r.epsinf, r.becs

            ref_becs_values = [
                [[  2.15646571e+00,   0.00000000e+00,   3.26402110e-25],
                 [  0.00000000e+00,   2.15646571e+00,  -5.46500204e-24],
                 [ -5.66391495e-25,  -6.54012564e-25,   2.19362823e+00]],
                [[  2.15646571e+00,   0.00000000e+00,   1.19680774e-24],
                 [  0.00000000e+00,   2.15646571e+00,   8.10327888e-24],
                 [ -1.69917448e-24,  -1.30802513e-24,   2.19362823e+00]],
                [[ -2.15646571e+00,   6.66133815e-16,  -1.84961196e-24],
                 [  8.88178420e-16,  -2.15646571e+00,   2.82672519e-24],
                 [ -3.39834897e-24,  -3.27006282e-25,  -2.19362823e+00]],
                [[ -2.15646571e+00,  -6.66133815e-16,   3.26402110e-25],
                 [ -8.88178420e-16,  -2.15646571e+00,  -5.46500204e-24],
                 [  5.66391495e-24,   2.28904397e-24,  -2.19362823e+00]]
                ]

            ref_epsinf = [[ 5.42055574e+00,  8.88178420e-16, -1.30717901e-25],
                          [-8.88178420e-16,  5.42055574e+00, -2.26410045e-25],
                          [-1.30717901e-25,  2.26410045e-25,  4.98835236e+00]]

            self.assert_almost_equal(becs.values, ref_becs_values)
            self.assert_almost_equal(np.array(epsinf), ref_epsinf)
            repr(becs); str(becs)
            assert becs.to_string(verbose=2)
            for arr, z in zip(becs.values, becs.zstars):
                self.assert_equal(arr, z)
            df = becs.get_voigt_dataframe(view="all", select_symbols="O", verbose=1)
            assert len(df) == 2
            # Equivalent atoms should have same determinant.
            self.assert_almost_equal(df["determinant"].values, df["determinant"].values[0])

            # get the dielectric tensor generator from anaddb
            dtg = ddb.anaget_dielectric_tensor_generator(verbose=2)
            assert dtg is not None and hasattr(dtg, "phfreqs")
            assert dtg.to_string(verbose=2)

    def test_mgo_becs_epsinf(self):
        """
        Testing DDB for MgO with Born effective charges and E_macro.
        Large breaking of the ASR.
        """
        with abilab.abiopen(abidata.ref_file("mp-1009129-9x9x10q_ebecs_DDB")) as ddb:
            assert ddb.structure.formula == "Mg1 O1"
            assert len(ddb.qpoints) == 72
            assert ddb.has_epsinf_terms()
            assert ddb.has_epsinf_terms(select="at_least_one_diagoterm")
            assert ddb.has_bec_terms()

            if self.has_matplotlib():
                plotter = ddb.anacompare_asr(asr_list=(0, 2), chneut_list=(0, 1), dipdip=1,
                    nqsmall=2, ndivsm=5, dos_method="tetra", ngqpt=None, verbose=2)
                str(plotter)
                assert plotter.combiplot(show=False)

                # Test nqsmall == 0
                plotter = ddb.anacompare_asr(asr_list=(0, 2), chneut_list=(0, 1), dipdip=1,
                    nqsmall=0, ndivsm=5, dos_method="tetra", ngqpt=None, verbose=2)
                assert plotter.gridplot(show=False)

                plotter = ddb.anacompare_dipdip(chneut_list=(0, 1), asr=1,
                    nqsmall=0, ndivsm=5, dos_method="gaussian", ngqpt=None, verbose=2)
                assert plotter.gridplot(show=False)

    def test_mgb2_ddbs_ngkpt_tsmear(self):
        """Testing multiple DDB files and gridplot_with_hue."""
        paths = [
            #"mgb2_444k_0.01tsmear_DDB",
            #"mgb2_444k_0.02tsmear_DDB",
            #"mgb2_444k_0.04tsmear_DDB",
            "mgb2_888k_0.01tsmear_DDB",
            #"mgb2_888k_0.02tsmear_DDB",
            "mgb2_888k_0.04tsmear_DDB",
            "mgb2_121212k_0.01tsmear_DDB",
            #"mgb2_121212k_0.02tsmear_DDB",
            "mgb2_121212k_0.04tsmear_DDB",
        ]
        paths = [os.path.join(abidata.dirpath, "refs", "mgb2_phonons_nkpt_tsmear", f) for f in paths]

        # The context manager closes all the DDB files even if an assertion fails.
        with abilab.DdbRobot.from_files(paths) as robot:
            robot.remap_labels(lambda ddb: "nkpt: %s, tsmear: %.3f" % (ddb.header["nkpt"], ddb.header["tsmear"]))

            # Invoke anaddb to get bands and doses
            r = robot.anaget_phonon_plotters(nqsmall=2)

            data = robot.get_dataframe_at_qpoint(qpoint=(0, 0, 0), units="meV", with_geo=False)
            assert "tsmear" in data
            self.assert_equal(data["ixc"].values, 1)

            if self.has_matplotlib():
                assert r.phbands_plotter.gridplot_with_hue("nkpt", with_dos=True, show=False)
                assert r.phbands_plotter.gridplot_with_hue("nkpt", with_dos=False, show=False)

    def test_ddb_from_mprester(self):
        """Test creation methods for DdbFile and DdbRobot from MP REST API."""
        # Both the DDB and the robot are closed on exit; previously they were never closed.
        with abilab.DdbFile.from_mpid("mp-149") as ddb:
            assert ddb.structure.formula == "Si2"
            self.assert_equal(ddb.guessed_ngqpt, [9, 9, 9])
            assert ddb.header["version"] == 100401
            assert ddb.header["ixc"] == -116133

        mpid_list = ["mp-149", "mp-1138"]
        with abilab.DdbRobot.from_mpid_list(mpid_list) as robot:
            assert len(robot) == len(mpid_list)
            assert robot.abifiles[1].structure.formula == "Li1 F1"
            assert robot.abifiles[1].header["ixc"] == -116133

    def test_alas_with_third_order(self):
        """
        Testing DDB containing also third order derivatives.
        """
        with abilab.abiopen(abidata.ref_file("refs/alas_nl_dfpt/AlAs_nl_dte_DDB")) as ddb:
            repr(ddb); str(ddb)
            assert ddb.to_string(verbose=2)
            self.assert_almost_equal(ddb.total_energy.to("Ha"), -0.10085769246152e+02)
            assert ddb.cart_forces is not None
            stress = ddb.cart_stress_tensor
            # Ha/Bohr^3 from DDB
            ref_voigt = np.array([-0.31110177329142E-05, -0.31110177329142E-05, -0.31110177329146E-05,
                                  0.00000000000000E+00, 0.00000000000000E+00, 0.00000000000000E+00])
            # AbiPy stress is in GPa; compare each Voigt component with its tensor entry.
            self.assert_almost_equal(stress[0, 0], ref_voigt[0] * abu.HaBohr3_GPa)
            self.assert_almost_equal(stress[1, 1], ref_voigt[1] * abu.HaBohr3_GPa)
            self.assert_almost_equal(stress[2, 2], ref_voigt[2] * abu.HaBohr3_GPa)
            self.assert_almost_equal(stress[1, 2], ref_voigt[3] * abu.HaBohr3_GPa)
            self.assert_almost_equal(stress[0, 2], ref_voigt[4] * abu.HaBohr3_GPa)
            self.assert_almost_equal(stress[0, 1], ref_voigt[5] * abu.HaBohr3_GPa)

            for qpoint in ddb.qpoints:
                assert qpoint in ddb.computed_dynmat

            raman = ddb.anaget_raman()
            # take the mean to avoid potential changes in the order of degenerate modes.
            sus_mean = raman.susceptibility[3:, 0, 1].mean()
            self.assertAlmostEqual(sus_mean, -0.002829737, places=5)

            # Test block parsing.
            blocks = ddb._read_blocks()
            assert len(blocks) == 4
            assert blocks[3]["qpt"] is None
            assert blocks[3]["dord"] == 3
            assert blocks[3]["qpt3"] == [[0.,] * 3] * 3
432
433
class DielectricTensorGeneratorTest(AbipyTest):
    """Tests for DielectricTensorGenerator built from files and from objects."""

    def test_base(self):
        """Testing DielectricTensor"""
        anaddbnc_fname = abidata.ref_file("AlAs_nl_dte_anaddb.nc")
        phbstnc_fname = abidata.ref_file("AlAs_nl_dte_PHBST.nc")

        d = DielectricTensorGenerator.from_files(phbstnc_fname, anaddbnc_fname)
        assert d.eps0.shape == (3, 3)
        d.epsinf.get_dataframe()
        assert d.epsinf._repr_html_()

        repr(d); str(d)
        assert d.to_string(verbose=2)

        # Smoke-test every supported `reim` selection of the oscillator table.
        for reim in ("all", "im", "re"):
            d.get_oscillator_dataframe(reim=reim, tol=1e-8)

        self.assertAlmostEqual(d.tensor_at_frequency(0.001, units='Ha', gamma_ev=0.0)[0, 0], 11.917178540635028)

        # Build the same generator from already-opened objects.
        # The anaddb.nc handle was previously never closed. Keep it open for the rest of the test,
        # since it cannot be confirmed here whether `d` reads from it lazily, and let unittest
        # close it at teardown even if an assertion fails.
        phbands = PhononBands.from_file(phbstnc_fname)
        anaddbnc = AnaddbNcFile.from_file(anaddbnc_fname)
        self.addCleanup(anaddbnc.close)
        d = DielectricTensorGenerator.from_objects(phbands, anaddbnc)

        self.assertAlmostEqual(d.tensor_at_frequency(0.001, units='Ha', gamma_ev=0.0)[0, 0], 11.917178540635028)
        self.assertAlmostEqual(d.reflectivity([1, 0, 0], 0.045), 0.59389746, places=5)

        if self.has_matplotlib():
            assert d.plot_vs_w(w_min=0.0001, w_max=0.01, num=10, units="Ha", show=False)
            assert d.plot_vs_w(w_min=0, w_max=None, num=10, units="cm-1", show=False)
            for comp in ["diag", "all", "diag_av"]:
                assert d.plot_vs_w(num=10, component=comp, units="cm-1", show=False)
            assert d.plot_all(units="mev", show=False)
            assert d.plot_e0w_qdirs(show=False)
            assert d.plot_reflectivity(show=False)
469
470
class DdbRobotTest(AbipyTest):
    """Tests for DdbRobot: dataframes, plotters and anaddb comparison helpers."""

    def test_ddb_robot(self):
        """Testing DDB robots."""
        assert not abilab.DdbRobot.class_handles_filename("foo_DDB.nc")
        assert abilab.DdbRobot.class_handles_filename("foo_DDB")

        path = abidata.ref_file("refs/znse_phonons/ZnSe_hex_qpt_DDB")
        # The context manager closes the robot's files even if an assertion fails,
        # consistent with the other robot tests in this class.
        with abilab.DdbRobot.from_files(path) as robot:
            # Register the same file a second time under another label.
            robot.add_file("same_ddb", path)
            repr(robot); str(robot)
            assert robot.to_string(verbose=2)
            assert len(robot) == 2
            assert robot.EXT == "DDB"

            data = robot.get_dataframe_at_qpoint(qpoint=[0, 0, 0], asr=2, chneut=1,
                    dipdip=0, with_geo=True, abspath=True)
            assert "mode1" in data and "alpha" in data

            r = robot.anaget_phonon_plotters(nqsmall=2, ndivsm=2, dipdip=0, verbose=2)
            if self.has_matplotlib():
                assert r.phbands_plotter.gridplot(show=False)
                assert r.phdos_plotter.gridplot(show=False)

            if self.has_nbformat():
                assert robot.write_notebook(nbpath=self.get_tmpname(text=True))

    def test_robot_elastic(self):
        """Test DdbRobot with anacompare_elastic method."""
        self.skip_if_abinit_not_ge("8.9.3")
        filepaths = [abidata.ref_file("refs/alas_elastic_dfpt/AlAs_elastic_DDB")]

        with abilab.DdbRobot.from_files(filepaths) as robot:
            robot.add_file("samefile", filepaths[0])
            assert len(robot) == 2

            # Test anacompare_elastic
            ddb_header_keys = ["nkpt", "tsmear"]
            r = robot.anacompare_elastic(ddb_header_keys=ddb_header_keys, with_path=True,
                with_structure=True, with_spglib=False, relaxed_ion="automatic", piezo="automatic", verbose=1)
            df, edata_list = r.df, r.elastdata_list
            assert "tensor_name" in df
            assert "ddb_path" in df
            for k in ddb_header_keys:
                assert k in df
            # One entry per file in the robot.
            assert len(edata_list) == 2

    def test_robot_becs_eps(self):
        """Test DdbRobot with anacompare_becs and eps methods."""
        paths = ["out_ngkpt222_DDB", "out_ngkpt444_DDB", "out_ngkpt888_DDB"]
        paths = [os.path.join(abidata.dirpath, "refs", "alas_eps_and_becs_vs_ngkpt", f) for f in paths]

        with abilab.DdbRobot.from_files(paths) as robot:
            # Test anacompare_epsinf
            rinf = robot.anacompare_epsinf(ddb_header_keys="nkpt", chneut=0, with_path=True, verbose=2)
            assert "nkpt" in rinf.df
            assert "ddb_path" in rinf.df
            assert len(rinf.epsinf_list) == len(robot)

            # Test anacompare_eps0
            r0 = robot.anacompare_eps0(ddb_header_keys=["nkpt", "tsmear"], asr=0, tol=1e-5, with_path=True, verbose=2)
            assert len(r0.eps0_list) == len(robot)
            assert len(r0.dgen_list) == len(robot)
            assert "ddb_path" in r0.df

            # Test anacompare_becs
            rb = robot.anacompare_becs(ddb_header_keys=["nkpt", "tsmear"], chneut=0, tol=1e-5, with_path=True, verbose=2)
            assert len(rb.becs_list) == len(robot)
            for k in ["nkpt", "tsmear", "ddb_path"]:
                assert k in rb.df
543
544
class PhononComputationTest(AbipyTest):
    """Consistency checks on phonon bands and DOSes produced by anaddb."""

    def test_phonon_computation(self):
        """
        Testing if pjdoses computed by anaddb integrate to 3*natom,
        and if amu values and eigenvector orthonormality are correct.
        """
        path = os.path.join(abidata.dirpath, "refs", "mgb2_phonons_nkpt_tsmear", "mgb2_121212k_0.04tsmear_DDB")

        # The context manager closes the DDB even if an assertion fails.
        with abilab.abiopen(path) as ddb:
            for dos_method in ("tetra", "gaussian"):
                # Get phonon bands and Dos with anaddb.
                phbands_file, phdos_file = ddb.anaget_phbst_and_phdos_files(nqsmall=4, ndivsm=2,
                    dipdip=0, chneut=0, dos_method=dos_method, lo_to_splitting=False, verbose=1)

                # Close both files at the end of each iteration, also on failure.
                try:
                    phbands, phdos = phbands_file.phbands, phdos_file.phdos
                    natom3 = len(phbands.structure) * 3

                    # Test that amu is present with correct values.
                    assert phbands.amu is not None
                    self.assert_almost_equal(phbands.amu[12.0], 0.24305e+02)
                    self.assert_almost_equal(phbands.amu[5.0], 0.10811e+02)
                    self.assert_almost_equal(phbands.amu_symbol["Mg"], phbands.amu[12.0])
                    self.assert_almost_equal(phbands.amu_symbol["B"], phbands.amu[5.0])

                    # Total PHDOS should integrate to 3 * natom.
                    # anaddb does not renormalize the DOS (an integral of 8.9825 instead of 9 has
                    # been observed), hence only one decimal is checked.
                    self.assert_almost_equal(phdos.integral_value, natom3, decimal=1)

                    # Test conversion to eigenvectors. Verify that they are orthonormal,
                    # i.e. E^H E is the identity at every q-point.
                    cidentity = np.eye(natom3, dtype=complex)
                    eig = phbands.dyn_mat_eigenvect
                    for iq in range(phbands.nqpt):
                        self.assert_almost_equal(np.dot(eig[iq].conjugate().T, eig[iq]), cidentity)

                    # Summing projected DOSes over types should give the total DOS.
                    pj_sum = sum(pjdos.integral_value for pjdos in phdos_file.pjdos_symbol.values())
                    self.assert_almost_equal(phdos.integral_value, pj_sum)

                    # Summing projected DOSes over types and directions should give the total DOS.
                    # phdos_rc_type[ntypat, 3, nomega]
                    values = phdos_file.reader.read_value("pjdos_rc_type").sum(axis=(0, 1))
                    tot_dos = abilab.Function1D(phdos.mesh, values)
                    self.assert_almost_equal(phdos.integral_value, tot_dos.integral_value)
                finally:
                    phbands_file.close()
                    phdos_file.close()
596