1#!/usr/local/bin/python3.8
2"""Tests for kpoints.kpoints module."""
3import itertools
4import unittest
5import numpy as np
6import abipy.data as abidata
7
8from pymatgen.core.lattice import Lattice
9from abipy import abilab
10from abipy.core.kpoints import (wrap_to_ws, wrap_to_bz, issamek, Kpoint, KpointList, IrredZone, Kpath, KpointsReader,
11    has_timrev_from_kptopt, KSamplingInfo, as_kpoints, rc_list, kmesh_from_mpdivs, map_grid2ibz,
12    set_atol_kdiff, set_spglib_tols, kpath_from_bounds_and_ndivsm, build_segments)  #Ktables,
13from abipy.core.testing import AbipyTest
14
15
class TestWrapWS(AbipyTest):

    def test_wrap_to_ws(self):
        """Testing wrap_to_ws"""
        # (input, expected) pairs. The data shows values folded into (-0.5, 0.5],
        # with the -0.5 boundary mapped onto +0.5.
        cases = [
            (0.5, 0.5),
            (-0.5, 0.5),
            (0.2, 0.2),
            (-0.3, -0.3),
            (0.7, -0.3),
            (2.3, 0.3),
            (-1.2, -0.2),
        ]
        for value, expected in cases:
            self.assert_almost_equal(wrap_to_ws(value), expected)

        # The same folding must be applied element-wise to numpy arrays.
        values = np.array([0.5, 2.3, -1.2])
        self.assert_almost_equal(wrap_to_ws(values), np.array([0.5, 0.3, -0.2]))
28
29
class TestHelperFunctions(AbipyTest):
    """Tests for the module-level helper functions of abipy.core.kpoints."""

    def test_wrap_to_bz(self):
        """Testing wrap_to_bz"""
        # The expected values show reduced coordinates folded into [0, 1).
        self.assertAlmostEqual(wrap_to_bz(0.0), 0.0)
        self.assertAlmostEqual(wrap_to_bz(1.0), 0.0)
        self.assertAlmostEqual(wrap_to_bz(0.2), 0.2)
        self.assertAlmostEqual(wrap_to_bz(-0.2), 0.8)
        self.assertAlmostEqual(wrap_to_bz(3.2), 0.2)
        self.assertAlmostEqual(wrap_to_bz(-3.2), 0.8)

    def test_is_diagonal(self):
        """Testing is_diagonal"""
        from abipy.core.kpoints import is_diagonal
        assert is_diagonal(np.eye(3, dtype=int))
        assert is_diagonal(np.eye(3, dtype=float))

        # An off-diagonal element equal to atol is still accepted,
        # but it is rejected once the tolerance becomes smaller than it.
        a = np.eye(3, dtype=float)
        atol = 1e-12
        a[1, 2] = atol
        assert is_diagonal(a, atol=atol)
        assert not is_diagonal(a, atol=atol / 2)

    def test_has_timrev_from_kptopt(self):
        """Testing has_timrev_from_kptopt."""
        assert has_timrev_from_kptopt(1)
        assert not has_timrev_from_kptopt(4)
        assert has_timrev_from_kptopt(-7)

    def test_kptopt2str(self):
        """Testing kptopt2str."""
        from abipy.core.kpoints import kptopt2str
        for kptopt in [-5, 0, 1, 2, 3, 4]:
            assert kptopt2str(kptopt, verbose=1 if kptopt != 1 else 0)

    def test_kpath_from_bounds_and_ndivsm(self):
        """Testing kpath_from_bounds_and_ndivsm."""
        structure = abilab.Structure.as_structure(abidata.cif_file("si.cif"))

        # A single boundary point is rejected ...
        with self.assertRaises(ValueError):
            kpath_from_bounds_and_ndivsm([(0, 0, 0)], 5, structure)
        # ... and so is a pair of identical boundary points.
        with self.assertRaises(ValueError):
            kpath_from_bounds_and_ndivsm([(0, 0, 0), (0, 0, 0)], 5, structure)

        # The path coordinates are computed in floating point and values such as
        # 0.1 or 0.3 are not exactly representable, so compare with a tolerance
        # rather than exact equality.
        path = kpath_from_bounds_and_ndivsm([(0, 0, 0), (0.5, 0, 0)], 5, structure)
        self.assert_almost_equal(path, [[0.0, 0.0, 0.0],
                                        [0.1, 0.0, 0.0],
                                        [0.2, 0.0, 0.0],
                                        [0.3, 0.0, 0.0],
                                        [0.4, 0.0, 0.0],
                                        [0.5, 0.0, 0.0]])
79
class TestKpoint(AbipyTest):
    """Unit tests for Kpoint object."""

    def setUp(self):
        self.lattice = Lattice([0.5, 0.5, 0, 0, 0.5, 0, 0, 0, 0.4])

    def test_set_tolerances(self):
        """Test the API used to change the global k-point tolerances."""
        # The setters return the previous values. Restore them *before* asserting,
        # so that a failing check cannot leak modified global tolerances into
        # the other tests.
        atol_default = 1e-8
        old_atol = set_atol_kdiff(1e-6)
        set_atol_kdiff(old_atol)
        assert old_atol == atol_default

        symprec_default, angle_tolerance_default = 1e-5, -1.0
        old_symprec, old_angle_tolerance = set_spglib_tols(1e-4, -2.0)
        set_spglib_tols(old_symprec, old_angle_tolerance)
        assert old_symprec == symprec_default
        assert old_angle_tolerance == angle_tolerance_default

    def test_kpoint_algebra(self):
        """Test k-point algebra."""
        lattice = self.lattice
        gamma = Kpoint([0, 0, 0], lattice)
        pgamma = Kpoint([1, 0, 1], lattice)
        X = Kpoint([0.5, 0, 0], lattice)
        K = Kpoint([1/3, 1/3, 1/3], lattice)
        repr(X)
        str(X)
        assert X.to_string(verbose=2)
        assert X.tos(m="fract")
        assert X.tos(m="cart")
        assert X.tos(m="fracart")

        # pgamma is Gamma only up to a reciprocal lattice vector (umklapp).
        assert gamma.is_gamma()
        assert not pgamma.is_gamma()
        assert pgamma.is_gamma(allow_umklapp=True)
        assert not X.is_gamma()

        # TODO: support np.array(X), e.g. assert np.all(np.array(X) == X.frac_coords)

        self.serialize_with_pickle(X, protocols=[-1])
        self.assert_almost_equal(X.versor().norm, 1.0)

        other_gamma = Kpoint.gamma(lattice, weight=1)
        assert other_gamma == [0, 0, 0]
        assert other_gamma == gamma
        assert gamma.versor() == gamma

        # NOTE(review): Kpoint equality treats points that differ by a reciprocal
        # lattice vector as equal (see gamma == pgamma below), so these checks do
        # not pin the wrapped coordinates themselves; consider also comparing
        # frac_coords.
        X_outside = Kpoint([1.5, 0, 0], lattice)
        assert X_outside.wrap_to_ws() == X
        assert X_outside.wrap_to_ws() == [0.5, 0, 0]

        X_outside = Kpoint([0.7, 0, 0], lattice)
        assert X_outside.wrap_to_bz() == [-0.3, 0, 0]

        assert X[0] == 0.5
        self.assert_equal(pgamma[:2].tolist(), [1, 0])

        # Equality and addition work modulo reciprocal lattice vectors.
        assert gamma == pgamma
        assert gamma + pgamma == gamma
        assert pgamma + X == X
        assert gamma != X
        # TODO: comparisons with non-Kpoint objects, e.g. gamma != 0 and X != 0.5

        assert X.norm == (gamma + X).norm
        assert X.norm == np.sqrt(np.sum(X.cart_coords**2))

        # Points that compare equal must hash equally.
        # Different hashes must correspond to different points.
        assert hash(gamma) == hash(pgamma)
        if hash(K) != hash(X):
            assert K != X

        # test on_border
        assert not gamma.on_border
        assert X.on_border
        assert not K.on_border
160
161
class TestKpointList(AbipyTest):
    """Unit tests for KpointList."""

    def setUp(self):
        self.lattice = Lattice([0.5, 0.5, 0, 0, 0.5, 0, 0, 0, 0.4])

    def test_askpoints(self):
        """Test askpoints."""
        lat = self.lattice

        # Three reduced coordinates given as a flat sequence.
        single = as_kpoints([1, 2, 3], lat)
        self.serialize_with_pickle(single, protocols=[-1])

        # Converting an already-converted object must return the very same object.
        assert as_kpoints(single, lat) is single

        # Six numbers are split into two k-points.
        pair = as_kpoints([1, 2, 3, 4, 5, 6], lat)
        assert len(pair) == 2
        for idx, coords in enumerate(([1, 2, 3], [4, 5, 6])):
            assert pair[idx] == Kpoint(coords, lat)

    def test_kpointlist(self):
        """Test KpointList."""
        lat = self.lattice

        coords = [0, 0, 0, 1/2, 1/2, 1/2, 1/3, 1/3, 1/3]
        klist = KpointList(lat, coords, weights=[0.1, 0.2, 0.7])
        repr(klist)
        str(klist)

        self.serialize_with_pickle(klist, protocols=[-1])
        self.assertMSONable(klist, test_if_subclass=False)

        self.assert_equal(klist.frac_coords.flatten(), coords)
        cart_from_points = np.reshape([kpt.cart_coords for kpt in klist], (-1, 3))
        self.assert_equal(klist.get_cart_coords(), cart_from_points)
        assert klist.sum_weights() == 1
        assert len(klist) == 3

        # Membership and lookup helpers must agree with the iteration order.
        for pos, kpt in enumerate(klist):
            assert kpt in klist
            assert klist.count(kpt) == 1
            assert klist.index(kpt) == pos
            assert klist.find(kpt) == pos

        # Updating the weight of each Kpoint must be visible through klist.weights.
        for kpt in klist:
            kpt.set_weight(1.0)
        assert np.all(klist.weights == 1.0)

        # find_closest: exact hit first, then a nearby point.
        idx, _, dist = klist.find_closest([0, 0, 0])
        assert idx == 0 and dist == 0.

        idx, _, dist = klist.find_closest(Kpoint([0.001, 0.002, 0.003], klist.reciprocal_lattice))
        assert idx == 0
        self.assert_almost_equal(dist, 0.001984943324127921)

        # Mapping k_index --> (k + q)_index, g0. With q = 0 every point maps onto itself.
        k2kqg = klist.get_k2kqg_map((0, 0, 0))
        assert all(ik == ikq for ik, (ikq, _g0) in k2kqg.items())

        k2kqg = klist.get_k2kqg_map((1/2, 1/2, 1/2))
        assert len(k2kqg) == 2
        ikq, g0 = k2kqg[0]
        assert ikq == 1 and np.all(g0 == 0)
        ikq, g0 = k2kqg[1]
        assert ikq == 0 and np.all(g0 == 1)

        other = KpointList(lat, [0, 0, 0, 1/2, 1/3, 1/3])

        # __add__ keeps every point of both operands, duplicates included.
        merged = klist + other
        for kpt in itertools.chain(klist, other):
            assert kpt in merged
        assert merged.count([0, 0, 0]) == 2

        # remove_duplicated drops the repeated Gamma point and is idempotent.
        merged = merged.remove_duplicated()
        assert merged.count([0, 0, 0]) == 1
        assert len(merged) == 4
        assert merged == merged.remove_duplicated()

        # get_all_kindices reports every occurrence; index raises for missing points.
        dup = KpointList(lat, [1/2, 1/2, 1/2, 1/2, 1/2, 1/2], weights=None)
        assert np.all(dup.get_all_kindices([1/2, 1/2, 1/2]) == [0, 1])
        with self.assertRaises(ValueError):
            dup.index((0, 0, 0))
248
249
class TestIrredZone(AbipyTest):

    def test_irredzone_api(self):
        """Testing IrredZone API."""
        si = abilab.Structure.as_structure(abidata.cif_file("si.cif"))

        # Unshifted 4x4x4 mesh given explicitly through ngkpt.
        ibz_444 = IrredZone.from_ngkpt(si, ngkpt=[4, 4, 4], shiftk=[0.0, 0.0, 0.0], verbose=2)
        repr(ibz_444)
        str(ibz_444)
        assert ibz_444.to_string(verbose=2)
        assert ibz_444.is_ibz
        assert len(ibz_444) == 8
        assert ibz_444.ksampling.kptopt == 1
        self.assert_equal(ibz_444.ksampling.mpdivs, [4, 4, 4])

        # Shifted mesh derived from a k-point density (expected to give 8x8x8 for Si).
        ibz_kppa = IrredZone.from_kppa(si, kppa=1000, shiftk=[0.5, 0.5, 0.5], kptopt=1, verbose=1)
        assert ibz_kppa.is_ibz
        assert len(ibz_kppa) == 60
        self.assert_equal(ibz_kppa.ksampling.mpdivs, [8, 8, 8])
268
269
class TestKpath(AbipyTest):

    def test_kpath_api(self):
        """Testing Kpath API."""
        structure = abilab.Structure.as_structure(abidata.cif_file("si.cif"))

        knames = ["G", "X", "L", "G"]
        kpath = Kpath.from_names(structure, knames, line_density=5)
        repr(kpath)
        str(kpath)
        assert kpath.to_string(verbose=2, title="Kpath")
        assert not kpath.is_ibz and kpath.is_path
        # is_gamma is a method and must be called. A bare reference to the bound
        # method is always truthy, so the check could never fail.
        assert kpath[0].is_gamma() and kpath[-1].is_gamma()
        # TODO: add checks on kpath.ksampling (kptopt, mpdivs) for k-paths.

        assert Kpoint.from_name_and_structure("Gamma", structure) == kpath[0]

        # ds/versors have one entry per pair of consecutive points;
        # lines has one entry per pair of consecutive boundaries.
        assert len(kpath.ds) == len(kpath) - 1
        assert len(kpath.versors) == len(kpath) - 1
        assert len(kpath.lines) == len(knames) - 1
        self.assert_almost_equal(kpath.frac_bounds, structure.get_kcoords_from_names(knames))
        self.assert_almost_equal(kpath.cart_bounds, structure.get_kcoords_from_names(knames, cart_coords=True))

        # Every point is found along the path. The final Gamma point is
        # reported with the index of the first one.
        r = kpath.find_points_along_path(kpath.get_cart_coords())
        assert len(r.ikfound) == len(kpath)
        self.assert_equal(r.ikfound,
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0])

        # With npts=1 a single segment containing only k0 is produced.
        segments = build_segments(k0_list=(0, 0, 0), npts=1, step=0.01, red_dirs=(1, 0, 0),
                                  reciprocal_lattice=structure.reciprocal_lattice)
        assert len(segments) == 1
        assert np.all(segments[0] == (0, 0, 0))

        step, npts = 0.1, 5
        red_dir = np.array((1, 1, 0))
        segments = build_segments(k0_list=(0, 0, 0, 0.5, 0, 0), npts=npts, step=step, red_dirs=red_dir,
                                  reciprocal_lattice=structure.reciprocal_lattice)

        # Shape: (nk0_list, len(red_dirs) * npts, 3). The middle point of each
        # segment is the starting k0.
        assert segments.shape == (2, npts, 3)
        self.assert_almost_equal(segments[0, 2], (0, 0, 0))
        self.assert_almost_equal(segments[1, 2], (0.5, 0.0, 0))

        def r2c(vec):
            """Convert reduced coordinates to Cartesian coordinates."""
            return structure.reciprocal_lattice.get_cartesian_coords(vec)

        # Consecutive points are `step` apart in Cartesian space, along the
        # normalized direction.
        cart_vers = r2c(red_dir)
        cart_vers /= np.linalg.norm(cart_vers)
        self.assert_almost_equal(r2c(segments[1, 1] - segments[1, 0]), step * cart_vers)
        self.assert_almost_equal(r2c(segments[1, 3] - segments[1, 2]), step * cart_vers)
325
326
class TestKpointsReader(AbipyTest):

    def test_reading(self):
        """Test the reading of Kpoints from netcdf files."""
        for fname in ("si_scf_GSR.nc", "si_nscf_GSR.nc", "si_scf_WFK.nc"):
            path = abidata.ref_file(fname)
            print("About to read file: %s" % path)

            with KpointsReader(path) as reader:
                kpoints = reader.read_kpoints()
                repr(kpoints)
                str(kpoints)

                if "_scf" in fname:
                    self._check_homogeneous_sampling(kpoints)
                elif "_nscf" in fname:
                    self._check_kpath(kpoints)

            # Round-trip through pickle and MSON serialization.
            self.serialize_with_pickle(kpoints)
            self.assertMSONable(kpoints, test_if_subclass=False)

    @staticmethod
    def _check_homogeneous_sampling(kpoints):
        """SCF reference files: IBZ of an unshifted 8x8x8 mesh."""
        assert not kpoints.is_path
        assert kpoints.is_ibz
        assert kpoints.sum_weights() == 1.0
        assert kpoints.ksampling.kptopt == 1
        mpdivs, shifts = kpoints.mpdivs_shifts
        assert np.all(mpdivs == [8, 8, 8])
        assert len(shifts) == 1 and np.all(shifts[0] == [0, 0, 0])

    @staticmethod
    def _check_kpath(kpoints):
        """NSCF reference file: a k-path (kptopt = -2), so no mesh divisions."""
        assert kpoints.is_path
        assert not kpoints.is_ibz
        assert kpoints.ksampling.kptopt == -2
        mpdivs, _ = kpoints.mpdivs_shifts
        assert mpdivs is None
        # The number of lines equals |kptopt|.
        assert len(kpoints.lines) == abs(kpoints.ksampling.kptopt)
367
368
class KmeshTest(AbipyTest):
    """Tests for rc_list and kmesh_from_mpdivs."""

    def test_rc_list(self):
        """Testing rc_list."""
        # Special case mp=1
        rc = rc_list(mp=1, sh=0.0, pbc=False, order="unit_cell")
        self.assert_equal(rc, [0.0])

        rc = rc_list(mp=1, sh=0.0, pbc=True, order="unit_cell")
        self.assert_equal(rc, [0.0, 1.0])

        rc = rc_list(mp=1, sh=0.0, pbc=False, order="bz")
        self.assert_equal(rc, [0.0])

        rc = rc_list(mp=1, sh=0.0, pbc=True, order="bz")
        self.assert_equal(rc, [0.0, 1.0])

        # Even mp
        rc = rc_list(mp=2, sh=0, pbc=False, order="unit_cell")
        self.assert_equal(rc, [0., 0.5])

        rc = rc_list(mp=2, sh=0, pbc=True, order="unit_cell")
        self.assert_equal(rc, [0., 0.5, 1.])

        rc = rc_list(mp=2, sh=0, pbc=False, order="bz")
        self.assert_equal(rc, [-0.5, 0.0])

        rc = rc_list(mp=2, sh=0, pbc=True, order="bz")
        self.assert_equal(rc, [-0.5, 0., 0.5])

        rc = rc_list(mp=2, sh=0.5, pbc=False, order="unit_cell")
        self.assert_equal(rc, [0.25, 0.75])

        rc = rc_list(mp=2, sh=0.5, pbc=True, order="unit_cell")
        self.assert_equal(rc, [0.25, 0.75, 1.25])

        rc = rc_list(mp=2, sh=0.5, pbc=False, order="bz")
        self.assert_equal(rc, [-0.25, 0.25])

        rc = rc_list(mp=2, sh=0.5, pbc=True, order="bz")
        self.assert_equal(rc, [-0.25, 0.25, 0.75])

        # Odd mp
        rc = rc_list(mp=3, sh=0, pbc=False, order="unit_cell")
        self.assert_almost_equal(rc, [0., 0.33333333, 0.66666667])

        rc = rc_list(mp=3, sh=0, pbc=True, order="unit_cell")
        self.assert_almost_equal(rc, [0., 0.33333333, 0.66666667, 1.])

        rc = rc_list(mp=3, sh=0, pbc=False, order="bz")
        self.assert_almost_equal(rc, [-0.33333333, 0., 0.33333333])

        rc = rc_list(mp=3, sh=0, pbc=True, order="bz")
        self.assert_almost_equal(rc, [-0.33333333, 0., 0.33333333, 0.66666667])

        rc = rc_list(mp=3, sh=0.5, pbc=False, order="unit_cell")
        self.assert_almost_equal(rc, [0.16666667, 0.5, 0.83333333])

        rc = rc_list(mp=3, sh=0.5, pbc=True, order="unit_cell")
        self.assert_almost_equal(rc, [0.16666667, 0.5, 0.83333333, 1.16666667])

        rc = rc_list(mp=3, sh=0.5, pbc=False, order="bz")
        self.assert_almost_equal(rc, [-0.5, -0.16666667, 0.16666667])

        rc = rc_list(mp=3, sh=0.5, pbc=True, order="bz")
        self.assert_almost_equal(rc, [-0.5, -0.16666667, 0.16666667, 0.5])

    def test_unshifted_kmesh(self):
        """Testing the generation of unshifted kmeshes."""
        # Compare the meshes numerically. Comparing str(array) against a
        # hand-written printout depends on numpy's array print format, which
        # has changed between numpy releases.
        mpdivs, shifts = [1, 2, 3], [0, 0, 0]

        def ref_mesh(xs, ys, zs):
            """Expected (N, 3) mesh: every combination of xs, ys, zs, with the last axis running fastest."""
            return np.array(list(itertools.product(xs, ys, zs)), dtype=float)

        # No shift, no pbc.
        kmesh = kmesh_from_mpdivs(mpdivs, shifts, order="unit_cell")
        self.assert_almost_equal(kmesh, ref_mesh([0], [0, 0.5], [0, 1/3, 2/3]))

        # No shift, with pbc: the periodic image at 1 is added along each axis.
        pbc_kmesh = kmesh_from_mpdivs(mpdivs, shifts, pbc=True, order="unit_cell")
        self.assert_almost_equal(pbc_kmesh, ref_mesh([0, 1], [0, 0.5, 1], [0, 1/3, 2/3, 1]))

        # No shift, no pbc, bz order
        bz_kmesh = kmesh_from_mpdivs(mpdivs, shifts, pbc=False, order="bz")
        self.assert_almost_equal(bz_kmesh, ref_mesh([0], [-0.5, 0], [-1/3, 0, 1/3]))

        # No shift, pbc, bz order
        bz_kmesh = kmesh_from_mpdivs(mpdivs, shifts, pbc=True, order="bz")
        self.assert_almost_equal(bz_kmesh, ref_mesh([0, 1], [-0.5, 0, 0.5], [-1/3, 0, 1/3, 2/3]))
526
527
class TestKsamplingInfo(AbipyTest):

    def test_ksampling(self):
        """Test KsamplingInfo API."""
        # An invalid keyword argument raises ValueError.
        with self.assertRaises(ValueError):
            KSamplingInfo(foo=1)

        divisions, mesh_shift, opt = [2, 3, 4], [0.5, 0.5, 0.5], 1

        # Mesh built from the number of divisions.
        info = KSamplingInfo.from_mpdivs(divisions, mesh_shift, opt)
        repr(info)
        str(info)
        self.assert_equal(info.mpdivs, divisions)
        self.assert_equal(info.kptrlatt, np.diag(divisions))
        self.assert_equal(info.shifts.flatten(), mesh_shift)
        assert info.shifts.shape == (1, 3)
        assert info.kptopt == opt
        assert info.is_mesh
        assert info.has_diagonal_kptrlatt
        assert not info.is_path

        # The same mesh, built from a diagonal kptrlatt.
        info = KSamplingInfo.from_kptrlatt(np.diag(divisions), mesh_shift, opt)
        repr(info)
        str(info)
        assert info.kptrlatt.shape == (3, 3)
        self.assert_equal(info.kptrlatt, np.diag(divisions))
        self.assert_equal(info.mpdivs, np.diag(info.kptrlatt))
        self.assert_equal(info.shifts.flatten(), mesh_shift)
        assert info.kptopt == opt
        assert info.is_mesh
        assert info.has_diagonal_kptrlatt
        assert not info.is_path

        # kptrlatt with non-zero off-diagonal elements: mpdivs is not defined.
        info = KSamplingInfo.from_kptrlatt([1, 1, 1, 2, 2, 2, 3, 3, 3], [0.5, 0.5, 0.5], 1)
        repr(info)
        str(info)
        assert info.mpdivs is None
        assert not info.has_diagonal_kptrlatt
        assert not info.is_path

        # Path defined by its boundaries: none of the mesh attributes are set.
        info = KSamplingInfo.from_kbounds([0, 0, 0, 1, 1, 1])
        repr(info)
        str(info)
        mesh_attrs = (info.mpdivs, info.kptrlatt, info.kptrlatt_orig, info.shifts, info.shifts_orig)
        assert mesh_attrs == (None,) * 5
        assert info.kptopt == -1
        assert info.kptrlatt is None
        assert not info.is_mesh
        assert not info.has_diagonal_kptrlatt
        assert info.is_path

        # as_ksampling returns KSamplingInfo instances unchanged and also accepts a dict.
        assert KSamplingInfo.as_ksampling(info) is info
        from_dict = KSamplingInfo.as_ksampling(dict(info.items()))
        assert from_dict.kptopt == info.kptopt

        # None yields a sampling with kptopt 0 that is neither a mesh nor a path.
        empty = KSamplingInfo.as_ksampling(None)
        repr(empty)
        str(empty)
        assert empty.kptopt == 0
        assert not empty.is_mesh
        assert not empty.is_path
593
594
class TestKmappingTools(AbipyTest):
    """Tests for the BZ grid --> IBZ mapping helpers."""

    def setUp(self):
        with abilab.abiopen(abidata.ref_file("mgb2_kmesh181818_FATBANDS.nc")) as ncfile:
            self.mgb2 = ncfile.structure
            assert ncfile.ebands.kpoints.is_ibz
            self.kibz = [k.frac_coords for k in ncfile.ebands.kpoints]
            # TODO: use has_timrev_from_kptopt(kptopt) instead of hard-coding True.
            self.has_timrev = True
            self.ngkpt = [18, 18, 18]

    def test_map_grid2ibz(self):
        """Testing map_grid2ibz."""
        bz2ibz = map_grid2ibz(self.mgb2, self.kibz, self.ngkpt, self.has_timrev, pbc=False)

        # Reduced coordinates of the full grid, with the z index running fastest.
        # NOTE(review): assumes this matches the ordering used by map_grid2ibz.
        # Each index must be divided by the number of divisions along its own
        # axis (ix/nx, iy/ny, iz/nz). Using nz for x was only correct for cubic ngkpt.
        nx, ny, nz = self.ngkpt
        bz = np.reshape([[ix / nx, iy / ny, iz / nz]
                         for ix, iy, iz in itertools.product(range(nx), range(ny), range(nz))],
                        (-1, 3))

        abispg = self.mgb2.abi_spacegroup

        # Only the first nmax grid points are checked. Each one must be the image
        # of its IBZ point under some symmetry operation.
        nmax = 54
        errors = []
        for ik_bz, kbz in enumerate(bz[:nmax]):
            ki = self.kibz[bz2ibz[ik_bz]]
            for symmop in abispg.fm_symmops:
                if issamek(symmop.rotate_k(ki), kbz):
                    break
            else:
                errors.append((ik_bz, kbz))

        assert not errors

    # Disabled tests: Ktables is not imported by this module.
    #def test_with_from_structure_with_symrec(self):
    #    """Generate Ktables from a structure with Abinit symmetries."""
    #    self.mgb2 = self.get_abistructure.mgb2("mgb2_kpath_FATBANDS.nc")
    #    assert self.mgb2.abi_spacegroup is not None
    #    mesh = [4, 4, 4]
    #    k = Ktables(self.mgb2, mesh, is_shift=None, has_timrev=True)
    #    repr(k); str(k)
    #    k.print_bz2ibz()

    #def test_with_structure_without_symrec(self):
    #    """Generate Ktables from a structure without Abinit symmetries."""
    #    assert self.mgb2.abi_spacegroup is None
    #    k = Ktables(self.mgb2, mesh, is_shift, has_timrev)
    #    repr(k); str(k)
    #    k.print_bz2ibz()
647