1"""Tests for electrons.gw module"""
2import os
3import collections
4import numpy as np
5import abipy.data as abidata
6
7from abipy import abilab
8from abipy.electrons.gw import *
9from abipy.core.testing import AbipyTest
10
11
class TestQPList(AbipyTest):
    """Tests for the QPList and QPState objects extracted from a SIGRES file."""

    def setUp(self):
        """Open the reference SIGRES file and build the QP list for the first GW k-point."""
        self.sigres = sigres = abilab.abiopen(abidata.ref_file("tgw1_9o_DS4_SIGRES.nc"))
        # Register the close right away: unittest does not run tearDown when setUp
        # fails, so a failing check below would otherwise leak the open file.
        self.addCleanup(sigres.close)
        repr(sigres); str(sigres)
        assert sigres.to_string(verbose=2)
        self.qplist = sigres.get_qplist(spin=0, kpoint=sigres.gwkpoints[0])

    def test_qplist(self):
        """Test QPList object."""
        # Explicit submodule import: a bare `import collections` does not formally
        # guarantee that `collections.abc` is bound as an attribute.
        from collections.abc import Iterable

        qplist = self.qplist
        assert isinstance(qplist, Iterable)
        self.serialize_with_pickle(qplist, protocols=[-1])

        repr(qplist); str(qplist)
        qplist_copy = qplist.copy()
        assert qplist_copy == qplist

        qpl_e0sort = qplist.sort_by_e0()
        assert qpl_e0sort.is_e0sorted
        e0mesh = qpl_e0sort.get_e0mesh()
        assert e0mesh[-1] > e0mesh[0]
        values = qpl_e0sort.get_field("qpeme0")
        assert len(values) == len(qpl_e0sort)

        qp = qpl_e0sort[2]
        value = qpl_e0sort.get_skb_field(qp.skb, "qpeme0")
        assert qp.qpeme0 == value

        # The unsorted list is expected to reject get_e0mesh, and merging lists that
        # contain the same states (including the list with itself) must fail.
        with self.assertRaises(ValueError):
            qplist.get_e0mesh()
        with self.assertRaises(ValueError):
            qplist.merge(qpl_e0sort)
        with self.assertRaises(ValueError):
            qplist.merge(qplist)

        # Merging with the states of a different k-point keeps everything from both lists.
        other_qplist = self.sigres.get_qplist(spin=0, kpoint=self.sigres.gwkpoints[1])
        qpl_merge = qplist.merge(other_qplist)

        assert all(qp in qpl_merge for qp in qplist)
        assert all(qp in qpl_merge for qp in other_qplist)

        # Test QPState object.
        qp = qplist[0]
        repr(qp); str(qp)
        assert str(qp.tips)
        assert qp.spin == 0
        assert qp.kpoint == self.sigres.gwkpoints[0]
        # The state must reference the very same k-point object stored in the file.
        assert qp.kpoint is self.sigres.gwkpoints[0]

        self.assert_equal(qp.re_qpe + 1j * qp.imag_qpe, qp.qpe)
        self.assert_almost_equal(qp.e0, -5.04619941555265, decimal=5)
        self.assert_almost_equal(qp.qpe.real, -4.76022137474714)
        self.assert_almost_equal(qp.qpe.imag, -0.011501666037697)
        self.assert_almost_equal(qp.sigxme, -16.549383605401)
71
72
class TestSigresFile(AbipyTest):
    """Tests for SIGRES files: reading, dataframes, plotting and QP interpolation."""

    def test_readall(self):
        """Open every reference SIGRES file and exercise its string representations."""
        for path in abidata.SIGRES_NCFILES:
            with abilab.abiopen(path) as sigres:
                repr(sigres); str(sigres)
                assert sigres.to_string(verbose=2)
                assert len(sigres.structure)

    def test_base(self):
        """Test SIGRES File."""
        # Context manager: the file is closed even if an assertion fails
        # (a trailing close() call would be skipped in that case).
        with abilab.abiopen(abidata.ref_file("tgw1_9o_DS4_SIGRES.nc")) as sigres:
            assert sigres.nsppol == 1
            sigres.print_qps(precision=5, ignore_imag=False)
            assert sigres.params["nsppol"] == sigres.nsppol
            assert not sigres.has_spectral_function

            # In this run IBZ = kptgw
            assert len(sigres.ibz) == 6
            assert sigres.gwkpoints == sigres.ibz
            # No spectral function: reading Sigma(omega) for a state must fail.
            assert not sigres.reader.has_spfunc
            with self.assertRaises(ValueError):
                sigres.read_sigee_skb(0, 0, 0)

            kptgw_coords = np.reshape([
                -0.25, -0.25, 0,
                -0.25, 0.25, 0,
                0.5, 0.5, 0,
                -0.25, 0.5, 0.25,
                0.5, 0, 0,
                0, 0, 0
            ], (-1, 3))

            self.assert_almost_equal(sigres.ibz.frac_coords, kptgw_coords)

            # Reference QP gaps; shape (1, 6) presumably = (nsppol, number of GW k-points).
            qpgaps = [3.53719151871085, 4.35685250045637, 4.11717896881632,
                      8.71122659251508, 3.29693118466282, 3.125545059031]
            self.assert_almost_equal(sigres.qpgaps, np.reshape(qpgaps, (1, 6)))

            # The k-point can be given either as an index or as a k-point object.
            ik = 2
            df = sigres.get_dataframe_sk(spin=0, kpoint=ik)
            same_df = sigres.get_dataframe_sk(spin=0, kpoint=sigres.gwkpoints[ik])

            assert np.all(df["qpe"] == same_df["qpe"])

            # Ignore imaginary part.
            df_real = sigres.get_dataframe_sk(spin=0, kpoint=ik, ignore_imag=True)
            assert np.all(df["qpe"].to_numpy().real == df_real["qpe"])

            # Smoke test: the full dataframe is built but not inspected.
            sigres.to_dataframe()

            marker = sigres.get_marker("qpeme0")
            assert marker and len(marker.x)

            if self.has_matplotlib():
                assert sigres.plot_qps_vs_e0(fontsize=8, e0=1.0, xlims=(-10, 10), show=False)
                # "qqeme0" is deliberately misspelled: unknown fields must be rejected.
                with self.assertRaises(ValueError):
                    sigres.plot_qps_vs_e0(with_fields="qqeme0", show=False)
                assert sigres.plot_qps_vs_e0(with_fields="qpeme0", show=False)
                assert sigres.plot_qps_vs_e0(exclude_fields=["vUme"], show=False)
                assert sigres.plot_ksbands_with_qpmarkers(qpattr="sigxme", e0=None, fact=1000, show=False)
                assert sigres.plot_qpbands_ibz(show=False)

                assert sigres.plot_eigvec_qp(spin=0, kpoint=0, show=False)
                assert sigres.plot_eigvec_qp(spin=0, kpoint=None, show=False)

                assert sigres.plot_qpgaps(plot_qpmks=True, show=False)
                assert sigres.plot_qpgaps(plot_qpmks=False, show=False)

            if self.has_nbformat():
                sigres.write_notebook(nbpath=self.get_tmpname(text=True))

    def test_sigres_with_spectral_function(self):
        """Test methods to plot spectral function from SIGRES."""
        filepath = abidata.ref_file("al_g0w0_sigmaw_SIGRES.nc")
        with abilab.abiopen(filepath) as sigres:
            assert sigres.reader.has_spfunc
            sigma = sigres.read_sigee_skb(spin=0, kpoint=(0, 0, 0), band=0)
            repr(sigma); str(sigma)
            assert sigma.to_string(verbose=2)

            if self.has_matplotlib():
                assert sigma.plot(what_list="spfunc", xlims=(-10, 10), fontsize=12, show=False)
                assert sigres.plot_spectral_functions(show=False)
                assert sigres.plot_spectral_functions(include_bands=range(0, 4), show=False)

            # The same file is added twice under different labels.
            with abilab.SigresRobot() as robot:
                robot.add_file("foo", filepath)
                robot.add_file("same", filepath)
                if self.has_matplotlib():
                    assert robot.plot_selfenergy_conv(0, (0.5, 0, 0), band=3, show=False)
                    assert robot.plot_selfenergy_conv(0, (0.5, 0, 0), band=3, sortby="nkpt", hue="nband", show=False)
                    # "foonkpt" is not a valid sort key.
                    with self.assertRaises(AttributeError):
                        robot.plot_selfenergy_conv(0, (0.5, 0, 0), band=3, sortby="foonkpt", hue="nband", show=False)

    def test_interpolator(self):
        """Test QP interpolation."""
        # Get quasiparticle results from the SIGRES.nc database.
        # Context manager: the file is closed even if an assertion fails.
        with abilab.abiopen(abidata.ref_file("si_g0w0ppm_nband30_SIGRES.nc")) as sigres:

            # Interpolate QP corrections and apply them on top of the KS band structures.
            # QP band energies are returned in r.qp_ebands_kpath and r.qp_ebands_kmesh.

            # Just to test call without ks_ebands (result not inspected).
            sigres.interpolate(lpratio=5,
                               ks_ebands_kpath=None,
                               ks_ebands_kmesh=None,
                               verbose=0, filter_params=[1.0, 1.0], line_density=10)

            r = sigres.interpolate(lpratio=5,
                                   ks_ebands_kpath=abidata.ref_file("si_nscf_GSR.nc"),
                                   ks_ebands_kmesh=abidata.ref_file("si_scf_GSR.nc"),
                                   verbose=0, filter_params=[1.0, 1.0], line_density=10)

            # A k-path has no Monkhorst-Pack divisions/shifts.
            assert r.qp_ebands_kpath is not None
            assert r.qp_ebands_kpath.kpoints.is_path
            assert r.qp_ebands_kpath.kpoints.mpdivs_shifts == (None, None)

            # The interpolated IBZ mesh must reuse the MP sampling of the KS mesh.
            assert r.qp_ebands_kmesh is not None
            assert r.qp_ebands_kmesh.kpoints.is_ibz
            assert r.qp_ebands_kmesh.kpoints.ksampling is not None
            assert r.qp_ebands_kmesh.kpoints.is_mpmesh
            qp_mpdivs, qp_shifts = r.qp_ebands_kmesh.kpoints.mpdivs_shifts
            assert qp_mpdivs is not None
            assert qp_shifts is not None
            ks_mpdivs, ks_shifts = r.ks_ebands_kmesh.kpoints.mpdivs_shifts
            self.assert_equal(qp_mpdivs, ks_mpdivs)
            self.assert_equal(qp_shifts, ks_shifts)

            # Get DOS from interpolated energies.
            ks_edos = r.ks_ebands_kmesh.get_edos()
            qp_edos = r.qp_ebands_kmesh.get_edos()

            r.qp_ebands_kmesh.to_bxsf(self.get_tmpname(text=True))

            points = sigres.get_points_from_ebands(r.qp_ebands_kpath, size=24, verbose=2)

            # Plot the LDA and the QPState band structure with matplotlib.
            plotter = abilab.ElectronBandsPlotter()
            plotter.add_ebands("LDA", r.ks_ebands_kpath, edos=ks_edos)
            plotter.add_ebands("GW (interpolated)", r.qp_ebands_kpath, edos=qp_edos)

            if self.has_matplotlib():
                assert plotter.combiplot(title="Silicon band structure", show=False)
                assert plotter.gridplot(title="Silicon band structure", show=False)
                assert r.qp_ebands_kpath.plot(points=points, show=False)
225
226
class SigresRobotTest(AbipyTest):
    """Tests for SigresRobot built from three silicon SIGRES files with increasing nband."""

    def test_sigres_robot(self):
        """Testing SIGRES robot."""
        paths = abidata.ref_files(
            "si_g0w0ppm_nband10_SIGRES.nc",
            "si_g0w0ppm_nband20_SIGRES.nc",
            "si_g0w0ppm_nband30_SIGRES.nc",
        )
        assert abilab.SigresRobot.class_handles_filename(paths[0])
        assert len(paths) == 3

        with abilab.SigresRobot.from_files(paths, abspath=True) as robot:
            # No start directory until trim_paths is called; afterwards every label
            # must already be expressed relative to the returned directory.
            assert robot.start is None
            root = robot.trim_paths(start=None)
            assert robot.start == root
            assert all(label == os.path.relpath(label, start=root) for label, _ in robot.items())

            # Generic robot API and string representations.
            assert robot.EXT == "SIGRES"
            repr(robot)
            str(robot)
            assert robot.to_string(verbose=2)
            assert robot._repr_html_()

            params = robot.get_params_dataframe()
            self.assert_equal(params["nsppol"].values, 1)

            # The third entry of each sortby item is the value used as sort key.
            by_name = robot.sortby("nband")
            assert [item[2] for item in by_name] == [10, 20, 30]
            by_func = robot.sortby(lambda abifile: abifile.ebands.nband, reverse=True)
            assert [item[2] for item in by_func] == [30, 20, 10]

            # Smoke tests: the returned dataframes are not inspected.
            robot.merge_dataframes_sk(spin=0, kpoint=[0, 0, 0])
            robot.get_qpgaps_dataframe(with_geo=True)

            # Test plotting methods.
            if self.has_matplotlib():
                gap_options = ((False, None, None), (True, "nband", "ecuteps"))
                for with_qpmks, sort_key, hue_key in gap_options:
                    assert robot.plot_qpgaps_convergence(plot_qpmks=with_qpmks, sortby=sort_key,
                                                         hue=hue_key, show=False)

                gamma_state = dict(spin=0, kpoint=(0, 0, 0))
                assert robot.plot_qpdata_conv_skb(band=3, show=False, **gamma_state)
                assert robot.plot_qpdata_conv_skb(band=5, sortby="sigma_nband", hue="ecuteps",
                                                  show=False, **gamma_state)
                # "fooecueps" is not a valid hue key.
                with self.assertRaises(TypeError):
                    robot.plot_qpdata_conv_skb(band=5, sortby="sigma_nband", hue="fooecueps",
                                               show=False, **gamma_state)

                # Test plot_qpfield_vs_e0
                assert robot.plot_qpfield_vs_e0("qpeme0", sortby=None, hue=None, e0="fermie",
                                                colormap="viridis", show=False)
                assert robot.plot_qpfield_vs_e0("ze0", itemp=1, sortby="ebands.nkpt", hue="scr_nband",
                                                colormap="viridis", show=False)

            if self.has_nbformat():
                robot.write_notebook(nbpath=self.get_tmpname(text=True))

            # Label manipulation: drop the first file, then preview and apply relabelings.
            first_label = os.path.relpath(paths[0], start=root)
            robot.pop_label(first_label)
            assert len(robot) == 2
            # Popping an unknown label is expected not to raise.
            robot.pop_label("foobar")

            preview = robot.change_labels(["hello", "world"], dryrun=True)
            assert len(preview) == 2 and "hello" in preview

            remapped = robot.remap_labels(lambda abifile: abifile.filepath, dryrun=False)
            assert len(remapped) == 2
            assert all(label == abifile.filepath for label, abifile in robot.items())
292