1"""Tests for structure module"""
2import numpy as np
3import sys
4import abipy.data as abidata
5
6from pymatgen.core.lattice import Lattice
7from pymatgen.core.units import bohr_to_ang
8from abipy.core.structure import *
9from abipy.core.testing import AbipyTest
10
11
class TestStructure(AbipyTest):
    """Unit tests for Structure."""

    def test_structure_from_ncfiles(self):
        """Initialize Structure from Netcdf data files"""

        for filename in abidata.WFK_NCFILES + abidata.GSR_NCFILES:
            structure = Structure.from_file(filename)
            # Smoke-test the string representations.
            str(structure)
            structure.to_string(verbose=2)
            # from_file must return the abipy class itself.
            assert type(structure) is Structure

            # All nc files produced by ABINIT should have info on the spacegroup.
            assert structure.has_abi_spacegroup

            # Call pymatgen machinery to get the high-symmetry stars.
            str(structure.hsym_stars)

            geodict = structure.get_dict4pandas()
            assert geodict["abispg_num"] is not None

            # Export data in Xcrysden format.
            #structure.export(self.get_tmpname(text=True, suffix=".xsf"))
            #visu = structure.visualize(appname="vesta")
            #assert callable(visu)

            if self.has_ase():
                # Round trip through ASE atoms must give back an equal structure.
                assert structure == Structure.from_ase_atoms(structure.to_ase_atoms())
                if self.has_matplotlib():
                    assert structure.plot_atoms(show=False)

    def test_utils(self):
        """Test utilities for the generation of Abinit inputs."""
        # Test as_structure and from/to abivars
        si = Structure.as_structure(abidata.cif_file("si.cif"))
        assert si.formula == "Si2"
        assert si.latex_formula == "Si$_{2}$"
        # A structure read from CIF has no ABINIT spacegroup until one is set.
        assert si.abi_spacegroup is None and not si.has_abi_spacegroup
        assert "ntypat" in si.to(fmt="abivars")

        spgroup = si.spgset_abi_spacegroup(has_timerev=True)
        assert spgroup is not None
        assert si.has_abi_spacegroup
        assert si.abi_spacegroup.spgid == 227
        kfrac_coords = si.get_kcoords_from_names(["G", "X", "L", "Gamma"])
        self.assert_equal(kfrac_coords,
            ([[0., 0., 0.], [0.5, 0., 0.5], [0.5, 0.5, 0.5], [0., 0., 0.]]))

        si_wfk = Structure.as_structure(abidata.ref_file("si_scf_WFK.nc"))
        assert si_wfk.formula == "Si2"
        si_wfk.print_neighbors(radius=2.5)

        assert si_wfk.has_abi_spacegroup
        # Cannot change spacegroup
        with self.assertRaises(ValueError):
            si_wfk.spgset_abi_spacegroup(has_timerev=True)

        # K and U are equivalent. [5/8, 1/4, 5/8] should return U
        assert si_wfk.findname_in_hsym_stars([3/8, 3/8, 3/4]) == "K"
        assert si_wfk.findname_in_hsym_stars([5/8, 1/4, 5/8]) == "U"

        # TODO: Fix order of atoms in supercells.
        # Test __mul__, __rmul__ (should return Abipy structures)
        assert si_wfk == 1 * si_wfk
        supcell = si_wfk * [2, 2, 2]
        assert len(supcell) == 8 * len(si_wfk) and hasattr(supcell, "abi_string")

        si_abi = Structure.from_file(abidata.ref_file("refs/si_ebands/run.abi"))
        assert si_abi.formula == "Si2"
        self.assert_equal(si_abi.frac_coords, [[0, 0, 0], [0.25, 0.25, 0.25]])

        # The structure parsed from the output file must match the one in the input file.
        si_abo = Structure.from_file(abidata.ref_file("refs/si_ebands/run.abo"))
        assert si_abo == si_abi
        assert "ntypat" in si_abi.to(fmt="abivars")

        znse = Structure.from_file(abidata.ref_file("refs/znse_phonons/ZnSe_hex_qpt_DDB"))
        assert len(znse) == 4
        assert znse.formula == "Zn2 Se2"
        self.assert_almost_equal(znse.frac_coords.flat, [
            0.33333333333333,  0.66666666666667, 0.99962203020000,
            0.66666666666667,  0.33333333333333, 0.49962203020000,
            0.33333333333333,  0.66666666666667, 0.62537796980000,
            0.66666666666667,  0.33333333333333, 0.12537796980000])

        from abipy.core.structure import diff_structures
        diff_structures([si_abi, znse], headers=["si_abi", "znse"], fmt="abivars", mode="table")
        diff_structures([si_abi, znse], headers=["si_abi", "znse"], fmt="abivars", mode="diff")

        # From pickle file.
        import pickle
        tmp_path = self.get_tmpname(suffix=".pickle")
        with open(tmp_path, "wb") as fh:
            pickle.dump(znse, fh)
        same_znse = Structure.from_file(tmp_path)
        assert same_znse == znse
        same_znse = Structure.as_structure(tmp_path)
        assert same_znse == znse

        for fmt in ["abivars", "cif", "POSCAR", "json", "xsf", "qe", "siesta", "wannier90"]:
            assert len(znse.convert(fmt=fmt)) > 0

        for fmt in ["abinit", "w90", "siesta"]:
            assert len(znse.get_kpath_input_string(fmt=fmt)) > 0

        oxi_znse = znse.get_oxi_state_decorated()
        assert len(oxi_znse.abi_string)
        from pymatgen.core.periodic_table import Specie
        assert Specie("Zn", 2) in oxi_znse.composition.elements
        assert Specie("Se", -2) in oxi_znse.composition.elements

        system = si.spget_lattice_type()
        assert system == "cubic"

        # In diamond Si the two atoms are symmetry-equivalent: one irreducible position.
        e = si.spget_equivalent_atoms(printout=True)
        assert len(e.irred_pos) == 1
        self.assert_equal(e.eqmap[0], [0, 1])
        for irr_pos in e.irred_pos:
            assert len(e.eqmap[irr_pos]) > 0
        assert "equivalent_atoms" in e.spgdata

        if self.has_matplotlib():
            assert si.plot_bz(show=False)
            assert si.plot_bz(pmg_path=False, show=False)
            assert si.plot(show=False)
            # Compare the interpreter version numerically (not as a string slice).
            if sys.version_info >= (3, 0):
                # pmg broke py2 compatibility
                assert si.plot_xrd(show=False)

        if self.has_mayavi():
            #assert si.plot_vtk(show=False)  # Disabled due to (core dumped) on travis
            assert si.plot_mayaview(show=False)

        if self.has_panel():
            assert hasattr(si.get_panel(), "show")

        # as_structure is a no-op on Structure instances and accepts abivars dicts.
        assert si is Structure.as_structure(si)
        assert si == Structure.as_structure(si.to_abivars())
        assert si == Structure.from_abivars(si.to_abivars())
        assert len(si.abi_string)
        assert si.reciprocal_lattice == si.lattice.reciprocal_lattice

        kptbounds = si.calc_kptbounds()
        assert len(kptbounds) > 0
        ksamp = si.calc_ksampling(nksmall=10)

        shiftk = [[0.5, 0.5, 0.5], [0.5, 0., 0.], [0., 0.5, 0.], [0., 0., 0.5]]
        self.assert_equal(si.calc_ngkpt(nksmall=2), [2, 2, 2])
        self.assert_equal(si.calc_shiftk(), shiftk)
        self.assert_equal(ksamp.ngkpt, [10, 10, 10])
        self.assert_equal(ksamp.shiftk, shiftk)

        lif = Structure.from_abistring("""
acell      7.7030079150    7.7030079150    7.7030079150 Angstrom
rprim      0.0000000000    0.5000000000    0.5000000000
           0.5000000000    0.0000000000    0.5000000000
           0.5000000000    0.5000000000    0.0000000000
natom      2
ntypat     2
typat      1 2
znucl      3 9
xred       0.0000000000    0.0000000000    0.0000000000
           0.5000000000    0.5000000000    0.5000000000
""")
        assert lif.formula == "Li1 F1"
        same = Structure.rocksalt(7.7030079150, ["Li", "F"], units="ang")
        self.assert_almost_equal(lif.lattice.a, same.lattice.a)

        # NOTE(review): presumably needs network access to the Materials Project.
        si_mp = Structure.from_mpid("mp-149")
        assert si_mp.formula == "Si2"

        # Test abiget_spginfo
        d = si_mp.abiget_spginfo(tolsym=None, pre="abi_")
        assert d["abi_spg_symbol"] == "Fd-3m"
        assert d["abi_spg_number"] == 227
        assert d["abi_bravais"] == "Bravais cF (face-center cubic)"

        llzo = Structure.from_file(abidata.cif_file("LLZO_oxi.cif"))
        assert llzo.is_ordered
        d = llzo.abiget_spginfo(tolsym=0.001)
        assert d["spg_number"] == 142

        mgb2_cod = Structure.from_cod_id(1526507, primitive=True)
        assert mgb2_cod.formula == "Mg1 B2"
        assert mgb2_cod.spget_lattice_type() == "hexagonal"

        mgb2 = abidata.structure_from_ucell("MgB2")
        if self.has_ase():
            mgb2.abi_primitive()

        assert [site.species_string for site in mgb2.get_sorted_structure_z()] == ["B", "B", "Mg"]

        s2inds = mgb2.get_symbol2indices()
        self.assert_equal(s2inds["Mg"], [0])
        self.assert_equal(s2inds["B"], [1, 2])

        s2coords = mgb2.get_symbol2coords()
        self.assert_equal(s2coords["Mg"], [[0, 0, 0]])
        self.assert_equal(s2coords["B"], [[1/3, 2/3, 0.5], [2/3, 1/3, 0.5]])

        # Rescaling the volume must preserve the hexagonal lattice.
        new_mgb2 = mgb2.scale_lattice(mgb2.volume * 1.1)
        self.assert_almost_equal(new_mgb2.volume, mgb2.volume * 1.1)
        assert new_mgb2.lattice.is_hexagonal

        # TODO: This part should be tested more carefully
        mgb2.abi_sanitize()
        mgb2.abi_sanitize(primitive_standard=True)
        mgb2.get_conventional_standard_structure()
        assert len(mgb2.abi_string)
        assert len(mgb2.spget_summary(site_symmetry=True, verbose=10))

        self.serialize_with_pickle(mgb2)

        pseudos = abidata.pseudos("12mg.pspnc", "5b.pspnc")
        nv = mgb2.num_valence_electrons(pseudos)
        assert nv == 8 and isinstance(nv, int)
        assert mgb2.valence_electrons_per_atom(pseudos) == [2, 3, 3]
        self.assert_equal(mgb2.calc_shiftk(), [[0.0, 0.0, 0.5]])

        # acell is in bohr here, hence the conversion when checking the volume in Ang^3.
        bmol = Structure.boxed_molecule(pseudos, cart_coords=[[0, 0, 0], [5, 5, 5]], acell=[10, 10, 10])
        self.assert_almost_equal(bmol.volume, (10 * bohr_to_ang) ** 3)

        # FIXME This is buggy
        #acell = np.array([10, 20, 30])
        #batom = Structure.boxed_atom(abidata.pseudo("12mg.pspnc"), cart_coords=[1, 2, 3], acell=acell)
        #assert isinstance(batom, Structure)
        #assert len(batom.cart_coords) == 1
        #self.assert_equal(batom.cart_coords[0], [1, 2, 3])

        # Function to compute cubic a0 from primitive v0 (depends on struct_type)
        vol2a = {"fcc": lambda vol: (4 * vol) ** (1/3.),
                 "bcc": lambda vol: (2 * vol) ** (1/3.),
                 "zincblende": lambda vol: (4 * vol) ** (1/3.),
                 "rocksalt": lambda vol: (4 * vol) ** (1/3.),
                 "ABO3": lambda vol: vol ** (1/3.),
                 "hH": lambda vol: (4 * vol) ** (1/3.),
                 }

        a = 10
        bcc_prim = Structure.bcc(a, ["Si"], primitive=True)
        assert len(bcc_prim) == 1
        self.assert_almost_equal(a, vol2a["bcc"](bcc_prim.volume))
        bcc_conv = Structure.bcc(a, ["Si"], primitive=False)
        assert len(bcc_conv) == 2
        self.assert_almost_equal(a**3, bcc_conv.volume)
        fcc_prim = Structure.fcc(a, ["Si"], primitive=True)
        assert len(fcc_prim) == 1
        self.assert_almost_equal(a, vol2a["fcc"](fcc_prim.volume))
        fcc_conv = Structure.fcc(a, ["Si"], primitive=False)
        assert len(fcc_conv) == 4
        self.assert_almost_equal(a**3, fcc_conv.volume)
        zns = Structure.zincblende(a / bohr_to_ang, ["Zn", "S"], units="bohr")
        self.assert_almost_equal(a, vol2a["zincblende"](zns.volume))
        rock = Structure.rocksalt(a, ["Na", "Cl"])
        assert len(rock) == 2
        self.assert_almost_equal(a, vol2a["rocksalt"](rock.volume))
        perov = Structure.ABO3(a, ["Ca", "Ti", "O", "O", "O"])
        assert len(perov) == 5
        self.assert_almost_equal(a**3, perov.volume)

        # Test notebook generation.
        if self.has_nbformat():
            assert mgb2.write_notebook(nbpath=self.get_tmpname(text=True))

    def test_znucl_typat(self):
        """Test the order of typat and znucl in the Abinit input and enforce_typat, enforce_znucl."""

        # Ga  Ga1  1  0.33333333333333  0.666666666666667  0.500880  1.0
        # Ga  Ga2  1  0.66666666666667  0.333333333333333  0.000880  1.0
        # N  N3  1  0.333333333333333  0.666666666666667  0.124120  1.0
        # N  N4  1  0.666666666666667  0.333333333333333  0.624120  1.0
        gan2 = Structure.from_file(abidata.cif_file("gan2.cif"))

        # By default, znucl is filled using the first new type found in sites.
        def_vars = gan2.to_abivars()
        def_znucl = def_vars["znucl"]
        self.assert_equal(def_znucl, [31, 7])
        def_typat = def_vars["typat"]
        self.assert_equal(def_typat, [1, 1, 2, 2])

        # But it's possible to enforce a particular value of typat and znucl.
        enforce_znucl = [7, 31]
        enforce_typat = [2, 2, 1, 1]
        enf_vars = gan2.to_abivars(enforce_znucl=enforce_znucl, enforce_typat=enforce_typat)
        self.assert_equal(enf_vars["znucl"], enforce_znucl)
        self.assert_equal(enf_vars["typat"], enforce_typat)
        # Enforcing the type ordering must not move the atoms.
        self.assert_equal(def_vars["xred"], enf_vars["xred"])

        assert [s.symbol for s in gan2.species_by_znucl] == ["Ga", "N"]

        # Site by site, both encodings must map to the same nuclear charge.
        for itype1, itype2 in zip(def_typat, enforce_typat):
            assert def_znucl[itype1 - 1] == enforce_znucl[itype2 - 1]

        # enforce_znucl without enforce_typat is rejected.
        with self.assertRaises(Exception):
            gan2.to_abivars(enforce_znucl=enforce_znucl, enforce_typat=None)

    def test_dataframes_from_structures(self):
        """Testing dataframes from structures."""
        mgb2 = abidata.structure_from_ucell("MgB2")
        sic = abidata.structure_from_ucell("SiC")
        alas = abidata.structure_from_ucell("AlAs")

        # Cartesian coordinates: check the result instead of discarding it.
        dfs_cart = dataframes_from_structures([mgb2, sic, alas], index=None, with_spglib=True, cart_coords=True)
        assert dfs_cart.lattice is not None
        assert dfs_cart.coords is not None

        # Fractional coordinates.
        dfs = dataframes_from_structures([mgb2, sic, alas], index=None, with_spglib=True, cart_coords=False)
        assert dfs.lattice is not None
        assert dfs.coords is not None
        assert dfs.structures is not None
        formulas = [struct.composition.reduced_formula for struct in dfs.structures]
        assert formulas == ["MgB2", "SiC", "AlAs"]

    def test_frozen_phonon_methods(self):
        """Testing frozen phonon methods (This is not a real test, just to show how to use it!)"""
        # FCC primitive vectors scaled by 10.60 bohr (0.529 Ang/bohr).
        rprimd = np.array([[0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]])
        rprimd = rprimd * 10.60 * 0.529
        lattice = Lattice(rprimd)
        structure = Structure(lattice, ["Ga", "As"], [[0, 0, 0], [0.25, 0.25, 0.25]])

        # TODO: Check all this stuff more carefully
        #qpoint = [0, 0, 0]
        qpoint = [1/2, 1/2, 1/2]
        mx_sc = [2, 2, 2]
        # Exercise get_smallest_supercell. An explicit 2x2x2 scale matrix is used below
        # so the indices of the displaced atoms checked at the end are known.
        structure.get_smallest_supercell(qpoint, max_supercell=mx_sc)
        scale_matrix = 2 * np.eye(3)
        # Number of sites in the supercell: 2 atoms in the cell times det(scale_matrix).
        natoms = int(np.round(2 * np.linalg.det(scale_matrix)))

        structure.write_vib_file(sys.stdout, qpoint, 0.1*np.array([[1, 1, 1], [1, 1, 1]]),
                                 do_real=True, frac_coords=False, max_supercell=mx_sc, scale_matrix=scale_matrix)

        displ = np.array([[1, 1, 1], [-1, -1, -1]])
        structure.write_vib_file(sys.stdout, qpoint, 0.1 * displ,
                                 do_real=True, frac_coords=False, max_supercell=mx_sc, scale_matrix=scale_matrix)

        structure.write_vib_file(sys.stdout, qpoint, 0.1 * displ,
                                 do_real=True, frac_coords=False, max_supercell=mx_sc, scale_matrix=None)

        fp_data = structure.frozen_phonon(qpoint, 0.1 * displ, eta=0.5, frac_coords=False,
                                          max_supercell=mx_sc, scale_matrix=scale_matrix)
        assert len(fp_data.structure) == natoms

        # eta rescales the displacements so the largest one has norm eta.
        # Index 8 is the image of the second (As) atom in the origin cell.
        max_displ = np.linalg.norm(displ, axis=1).max()
        self.assertArrayAlmostEqual(fp_data.structure[0].coords,
                                    structure[0].coords + 0.5*displ[0]/max_displ)
        self.assertArrayAlmostEqual(fp_data.structure[8].coords,
                                    structure[1].coords + 0.5*displ[1]/max_displ)

        displ2 = np.array([[1, 0, 0], [0, 1, 1]])

        f2p_data = structure.frozen_2phonon(qpoint, 0.05 * displ, 0.02*displ2, eta=0.5, frac_coords=False,
                                            max_supercell=mx_sc, scale_matrix=scale_matrix)
        assert len(f2p_data.structure) == natoms

        # The two displacement patterns are summed before the eta normalization.
        d_tot = 0.05*displ + 0.02*displ2
        max_displ = np.linalg.norm(d_tot, axis=1).max()
        self.assertArrayAlmostEqual(f2p_data.structure[0].coords,
                                    structure[0].coords + 0.5*d_tot[0]/max_displ)
        self.assertArrayAlmostEqual(f2p_data.structure[8].coords,
                                    structure[1].coords + 0.5*d_tot[1]/max_displ)