1import re
2import warnings
3from typing import Dict
4
5import numpy as np
6
7import ase  # Annotations
8from ase.utils import jsonable
9from ase.cell import Cell
10
11
def monkhorst_pack(size):
    """Return the k-points of a uniform (Monkhorst-Pack) grid of shape *size*.

    The grid has ``prod(size)`` points in units of the reciprocal cell,
    centred around the origin, ordered with the first axis varying slowest.

    Raises ValueError if any entry of *size* is not positive.
    """
    if np.less_equal(size, 0).any():
        raise ValueError('Illegal size: %s' % list(size))
    # np.indices gives shape (3, n1, n2, n3); move the component axis last
    # and flatten to a (n1*n2*n3, 3) list of integer grid indices.
    grid_indices = np.moveaxis(np.indices(size), 0, -1).reshape(-1, 3)
    shifted = grid_indices + 0.5
    return shifted / size - 0.5
18
19
def get_monkhorst_pack_size_and_offset(kpts):
    """Find Monkhorst-Pack size and offset.

    Returns (size, offset), where::

        kpts = monkhorst_pack(size) + offset.

    The set of k-points must not have been symmetry reduced, and must be
    in the same order as produced by monkhorst_pack().

    Parameters:

    kpts: (nkpts, 3) array-like
        Scaled k-points.

    Raises ValueError if *kpts* is empty, not of shape (nkpts, 3), or not
    an ASE-style Monkhorst-Pack grid.
    """
    kpts = np.asarray(kpts, dtype=float)
    if kpts.ndim != 2 or kpts.shape[1] != 3 or len(kpts) == 0:
        raise ValueError('Not an ASE-style Monkhorst-Pack grid!')

    if len(kpts) == 1:
        return np.ones(3, int), kpts[0].copy()

    size = np.zeros(3, int)
    for c in range(3):
        # Determine increment between k-points along current axis
        delta = np.diff(np.sort(kpts[:, c])).max()

        # Determine number of k-points as inverse of distance between kpoints
        if delta > 1e-8:
            size[c] = int(round(1.0 / delta))
        else:
            size[c] = 1

    if size.prod() == len(kpts):
        offsets = kpts - monkhorst_pack(size)

        # All offsets must be identical.  Use the np.ptp function: the
        # ndarray.ptp method was removed in NumPy 2.0.
        if (np.ptp(offsets, axis=0) < 1e-9).all():
            return size, offsets[0].copy()

    raise ValueError('Not an ASE-style Monkhorst-Pack grid!')
52
53
def get_monkhorst_shape(kpts):
    """Return the Monkhorst-Pack grid size of *kpts* (deprecated).

    Use ``get_monkhorst_pack_size_and_offset(kpts)[0]`` instead.
    """
    # stacklevel=2 makes the warning point at the caller, not this module.
    warnings.warn('Use get_monkhorst_pack_size_and_offset()[0] instead.',
                  stacklevel=2)
    return get_monkhorst_pack_size_and_offset(kpts)[0]
57
58
def kpoint_convert(cell_cv, skpts_kc=None, ckpts_kv=None):
    """Convert k-points between scaled and cartesian coordinates.

    Given the atomic unit cell, and either the scaled or cartesian k-point
    coordinates, the other is determined.

    The k-point arrays can be either a single point, or a list of points,
    i.e. the dimension k can be empty or multidimensional.

    Cartesian coordinates include the factor 2*pi, i.e. the reciprocal
    basis used is ``2 * pi * pinv(cell).T``.

    Raises KeyError unless exactly one of skpts_kc and ckpts_kv is given.
    """
    if (skpts_kc is None) == (ckpts_kv is None):
        # Either both or neither were given.
        raise KeyError('Either scaled or cartesian coordinates must be given.')

    # Accept nested lists (and array-likes) as well as ndarrays.
    cell_cv = np.asarray(cell_cv, dtype=float)
    if ckpts_kv is None:
        icell_cv = 2 * np.pi * np.linalg.pinv(cell_cv).T
        return np.dot(skpts_kc, icell_cv)
    return np.dot(ckpts_kv, cell_cv.T) / (2 * np.pi)
75
76
def parse_path_string(s):
    """Split a compact band-path string into sections of labels.

    Commas separate disconnected sections of the path.  Each label is an
    upper-case letter optionally followed by lower-case letters or digits.
    The result is a list with one list of labels per section.

    Examples:

    >>> parse_path_string('GX')
    [['G', 'X']]
    >>> parse_path_string('GX,M1A')
    [['G', 'X'], ['M1', 'A']]
    """
    # The capturing group makes re.split keep the labels themselves;
    # empty fragments between adjacent labels are dropped.
    label_pattern = re.compile(r'([A-Z][a-z0-9]*)')
    return [[token for token in label_pattern.split(section) if token]
            for section in s.split(',')]
98
99
def resolve_kpt_path_string(path, special_points):
    """Resolve a path string into labels and coordinates.

    Returns ``(labels, coords)``: the list of label sections from
    parse_path_string(), and for each section an (n, 3) array of the
    corresponding special-point coordinates looked up in *special_points*.
    Raises KeyError for a label missing from *special_points*.
    """
    sections = parse_path_string(path)
    coords = []
    for section in sections:
        points = [special_points[label] for label in section]
        coords.append(np.array(points).reshape(-1, 3))
    return sections, coords
105
106
def resolve_custom_points(pathspec, special_points, eps):
    """Resolve a path specification into a path string.

    The specification may be

      1) a path string understood by parse_path_string(),
      2) a flat list of kpoints, each a label or three coordinates, or
      3) a list of such lists, each inner list a contiguous subpath.

    Coordinates whose max-norm distance to an existing special point is
    below *eps* receive that point's label; other coordinates get generated
    labels ``Kpt0``, ``Kpt1``, ... that do not clash with existing names.

    Returns ``(path_string, special_points)``, where the dictionary is a
    copy of the input extended with any generated labels.

    Raises KeyError for an unknown label and ValueError for a coordinate
    entry that is not three numbers.
    """
    # Ideally the comparison would use Cartesian coordinates, but callers
    # typically pass scaled ones.

    # Work on a copy: generated labels are added to it below.
    known = dict(special_points)

    if len(pathspec) == 0:
        return '', known

    if isinstance(pathspec, str):
        pathspec = parse_path_string(pathspec)

    def is_single_kpoint(candidate):
        if isinstance(candidate, str):
            return True
        try:
            as_array = np.asarray(candidate, float)
        except ValueError:
            return False
        return as_array.shape == (3,)

    # Form 2 is detected by any entry being a single kpoint; wrap it so
    # that everything below can assume form 3.
    if any(is_single_kpoint(entry) for entry in pathspec):
        pathspec = [pathspec]

    def generic_labels():
        number = 0
        while True:
            yield 'Kpt{}'.format(number)
            number += 1

    fresh = generic_labels()

    def label_for(kpt):
        # Return the label for one kpoint, registering a new one if needed.
        if isinstance(kpt, str):
            if kpt not in known:
                raise KeyError('No kpoint "{}" among "{}"'
                               .format(kpt, ''.join(known)))
            return kpt

        coords = np.asarray(kpt, float)
        if coords.shape != (3,):
            raise ValueError(f'Not a valid kpoint: {coords}')

        for existing, value in known.items():
            if np.abs(coords - value).max() < eps:
                return existing  # Already present

        # New special point - skip generated names that are already taken.
        candidate = next(fresh)
        while candidate in known:
            candidate = next(fresh)
        known[candidate] = coords
        return candidate

    sections = [''.join([label_for(kpt) for kpt in subpath])
                for subpath in pathspec]
    return ','.join(sections), known
189
190
def normalize_special_points(special_points):
    """Return a validated copy of *special_points* with float array values.

    Raises TypeError for a non-string label and ValueError for a value
    whose shape is not (3,).
    """
    normalized = {}
    for label, coords in special_points.items():
        if not isinstance(label, str):
            raise TypeError('Expected name to be a string')
        if np.shape(coords) != (3,):
            raise ValueError('Expected 3 kpoint coordinates')
        normalized[label] = np.asarray(coords, float)
    return normalized
200
201
@jsonable('bandpath')
class BandPath:
    """Represents a Brillouin zone path or bandpath.

    A band path has a unit cell, a path specification, special points,
    and interpolated k-points.  Band paths are typically created
    indirectly using the :class:`~ase.geometry.Cell` or
    :class:`~ase.lattice.BravaisLattice` classes:

    >>> from ase.lattice import CUB
    >>> path = CUB(3).bandpath()
    >>> path
    BandPath(path='GXMGRX,MR', cell=[3x3], special_points={GMRX}, kpts=[40x3])

    Band paths support JSON I/O:

    >>> from ase.io.jsonio import read_json
    >>> path.write('mybandpath.json')
    >>> read_json('mybandpath.json')
    BandPath(path='GXMGRX,MR', cell=[3x3], special_points={GMRX}, kpts=[40x3])

    """
    def __init__(self, cell, kpts=None,
                 special_points=None, path=None):
        """Create a band path.

        Parameters:

        cell: 3x3 array-like or Cell
            Unit cell; wrapped in :class:`~ase.cell.Cell`.
        kpts: (nkpts, 3) float array or None
            Interpolated k-points in units of the reciprocal cell.
            Defaults to an empty (0, 3) array.
        special_points: dict or None
            Maps labels to 3-element coordinates; validated and copied by
            normalize_special_points().
        path: str or None
            Path specification such as ``'GXM,GR'``.  Defaults to ``''``.

        Raises TypeError if *path* is not a string and ValueError if
        *kpts* is not a float array of shape (nkpts, 3).
        """
        if kpts is None:
            kpts = np.empty((0, 3))

        if special_points is None:
            special_points = {}
        else:
            special_points = normalize_special_points(special_points)

        if path is None:
            path = ''
        if not isinstance(path, str):
            raise TypeError(f'path must be a string; was {path!r}')

        kpts = np.asarray(kpts)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not (kpts.ndim == 2 and kpts.shape[1] == 3
                and kpts.dtype == float):
            raise ValueError('kpts must be a float array of shape (nkpts, 3);'
                             f' got shape {kpts.shape} and dtype {kpts.dtype}')

        self._cell = Cell(cell)
        self._icell = self.cell.reciprocal()
        self._kpts = kpts
        self._special_points = special_points
        self._path = path

    @property
    def cell(self) -> Cell:
        """The :class:`~ase.cell.Cell` of this BandPath."""
        return self._cell

    @property
    def icell(self) -> Cell:
        """Reciprocal cell of this BandPath as a :class:`~ase.cell.Cell`."""
        return self._icell

    @property
    def kpts(self) -> np.ndarray:
        """The kpoints of this BandPath as an array of shape (nkpts, 3).

        The kpoints are given in units of the reciprocal cell."""
        return self._kpts

    @property
    def special_points(self) -> Dict[str, np.ndarray]:
        """Special points of this BandPath as a dictionary.

        The dictionary maps names (such as `'G'`) to kpoint coordinates
        in units of the reciprocal cell as a 3-element numpy array.

        It's unwise to edit this dictionary directly.  If you need that,
        consider deepcopying it."""
        return self._special_points

    @property
    def path(self) -> str:
        """The string specification of this band path.

        This is a specification of the form `'GXWKGLUWLK,UX'`.

        Comma marks a discontinuous jump: K is not connected to U."""
        return self._path

    def transform(self, op: np.ndarray) -> 'BandPath':
        """Apply 3x3 matrix to this BandPath and return new BandPath.

        This is useful for converting the band path to another cell.
        The operation will typically be a permutation/flipping
        established by a function such as Niggli reduction."""
        # XXX acceptable operations are probably only those
        # who come from Niggli reductions (permutations etc.)
        #
        # We should insert a check.
        # I wonder which operations are valid?  They won't be valid
        # if they change lengths, volume etc.
        special_points = {}
        for name, value in self.special_points.items():
            special_points[name] = value @ op

        return BandPath(op.T @ self.cell, kpts=self.kpts @ op,
                        special_points=special_points,
                        path=self.path)

    def todict(self):
        """Return a dictionary representation used for JSON I/O."""
        return {'kpts': self.kpts,
                'special_points': self.special_points,
                'labelseq': self.path,
                'cell': self.cell}

    def interpolate(
            self,
            path: str = None,
            npoints: int = None,
            special_points: Dict[str, np.ndarray] = None,
            density: float = None,
    ) -> 'BandPath':
        """Create new bandpath, (re-)interpolating kpoints from this one.

        Arguments left as None default to this band path's own path and
        special points; npoints/density are passed on to paths2kpts()."""
        if path is None:
            path = self.path

        if special_points is None:
            special_points = self.special_points

        _, pathcoords = resolve_kpt_path_string(path, special_points)
        kpts, _, _ = paths2kpts(pathcoords, self.cell, npoints, density)
        return BandPath(self.cell, kpts, path=path,
                        special_points=special_points)

    def _scale(self, coords):
        # Convert scaled coordinates to Cartesian via the reciprocal cell.
        return np.dot(coords, self.icell)

    def __repr__(self):
        return ('{}(path={}, cell=[3x3], special_points={{{}}}, kpts=[{}x3])'
                .format(self.__class__.__name__,
                        repr(self.path),
                        ''.join(sorted(self.special_points)),
                        len(self.kpts)))

    def cartesian_kpts(self) -> np.ndarray:
        """Get Cartesian kpoints from this bandpath."""
        return self._scale(self.kpts)

    def __iter__(self):
        """XXX Compatibility hack for bandpath() function.

        bandpath() now returns a BandPath object, which is a Good
        Thing.  However it used to return a tuple of (kpts, x_axis,
        special_x_coords), and people would use tuple unpacking for
        those.

        This function makes tuple unpacking work in the same way.
        It will be removed in the future.

        """
        warnings.warn('Please do not use (kpts, x, X) = bandpath(...).  '
                      'Use path = bandpath(...) and then kpts = path.kpts and '
                      '(x, X, labels) = path.get_linear_kpoint_axis().',
                      stacklevel=2)
        yield self.kpts

        x, xspecial, _ = self.get_linear_kpoint_axis()
        yield x
        yield xspecial

    def __getitem__(self, index):
        # Temp compatibility stuff, see __iter__
        return tuple(self)[index]

    def get_linear_kpoint_axis(self, eps=1e-5):
        """Define x axis suitable for plotting a band structure.

        See :func:`ase.dft.kpoints.labels_from_kpts`."""

        index2name = self._find_special_point_indices(eps)
        indices = sorted(index2name)
        labels = [index2name[index] for index in indices]
        xcoords, special_xcoords = indices_to_axis_coords(
            indices, self.kpts, self.cell)
        return xcoords, special_xcoords, labels

    def _find_special_point_indices(self, eps):
        """Find indices of kpoints which are close to special points.

        A kpoint within *eps* (Euclidean distance in scaled coordinates) of
        a special point is labelled with it, unless its previous or next
        kpoint is strictly closer to that special point.

        The result is returned as a dictionary mapping indices to labels."""
        # XXX must work in Cartesian coordinates for comparison to eps
        # to fully make sense!
        index2name = {}
        nkpts = len(self.kpts)

        for name, kpt in self.special_points.items():
            displacements = self.kpts - kpt[np.newaxis, :]
            distances = np.linalg.norm(displacements, axis=1)
            args = np.argwhere(distances < eps)
            for arg in args.flat:
                dist = distances[arg]
                # Check if an adjacent point (previous OR next) exists and is
                # even closer.  The slice end is exclusive, hence arg + 2.
                neighbours = distances[max(arg - 1, 0):min(arg + 2, nkpts)]
                if not any(neighbours < dist):
                    index2name[arg] = name

        return index2name

    def plot(self, **plotkwargs):
        """Visualize this bandpath.

        Plots the irreducible Brillouin zone and this bandpath."""
        import ase.dft.bz as bz

        # We previously had a "dimension=3" argument which is now unused.
        plotkwargs.pop('dimension', None)

        special_points = self.special_points
        labelseq, coords = resolve_kpt_path_string(self.path,
                                                   special_points)

        paths = []
        points_already_plotted = set()
        for subpath_labels, subpath_coords in zip(labelseq, coords):
            subpath_coords = np.array(subpath_coords)
            points_already_plotted.update(subpath_labels)
            paths.append((subpath_labels, self._scale(subpath_coords)))

        # Add each special point as a single-point subpath if they were
        # not plotted already:
        for label, point in special_points.items():
            if label not in points_already_plotted:
                paths.append(([label], [self._scale(point)]))

        kw = {'vectors': True,
              'pointstyle': {'marker': '.'}}

        kw.update(plotkwargs)
        return bz.bz_plot(self.cell, paths=paths,
                          points=self.cartesian_kpts(),
                          **kw)

    def free_electron_band_structure(
            self, **kwargs
    ) -> 'ase.spectrum.band_structure.BandStructure':
        """Return band structure of free electrons for this bandpath.

        Keyword arguments are passed to
        :class:`~ase.calculators.test.FreeElectrons`.

        This is for mostly testing and visualization."""
        from ase import Atoms
        from ase.calculators.test import FreeElectrons
        from ase.spectrum.band_structure import calculate_band_structure
        atoms = Atoms(cell=self.cell, pbc=True)
        atoms.calc = FreeElectrons(**kwargs)
        bs = calculate_band_structure(atoms, path=self)
        return bs
454
455
def bandpath(path, cell, npoints=None, density=None, special_points=None,
             eps=2e-4):
    """Build a band path of kpoints between the given points.

    path: list or str
        One of:

        * a path string as understood by parse_path_string(), e.g. 'GXL';
        * a list of BZ points, e.g. [(0, 0, 0), (0.5, 0, 0)];
        * several such lists when the path is not continuous.
    cell: 3x3
        Unit cell of the atoms.
    npoints: int
        Total number of k-points on the output path.  If too small, at
        least the start and end point of each path segment is kept.  If
        None (default), it is derived from ``density``.
    density: float
        Number of k-points per Å⁻¹ of path length, used when npoints is
        None; this gives about ``density * path_length`` points, with the
        path length measured in reciprocal space (Å⁻¹).  If None
        (default), 5 k-points per Å⁻¹ are used.
    special_points: dict or None
        Mapping from labels to special points.  If None, the special
        points are derived from the cell.
    eps: float
        Tolerance used when identifying the Bravais lattice to deduce the
        special points.

    Only one of npoints and density may be given.

    Return a :class:`~ase.dft.kpoints.BandPath` object."""
    # NOTE(review): earlier documentation said a minimum of 50 points is used
    # when the computed npoints is smaller.  paths2kpts does not enforce
    # this -- confirm whether Cell.bandpath does before documenting it.
    return Cell.ascell(cell).bandpath(path, npoints=npoints, density=density,
                                      special_points=special_points, eps=eps)
492
493
# Density used by paths2kpts() to choose npoints when neither npoints nor
# density is given.
DEFAULT_KPTS_DENSITY = 5    # points per 1/Angstrom
495
496
def paths2kpts(paths, cell, npoints=None, density=None):
    """Interpolate k-points along a (possibly discontinuous) band path.

    Parameters:

    paths: sequence of (n_i, 3) arrays
        Corner points of each contiguous subpath in units of the reciprocal
        cell.  Consecutive subpaths are joined by a zero-length jump.
    cell: 3x3
        Unit cell, used via kpoint_convert() to measure segment lengths.
    npoints: int or None
        Approximate total number of k-points (per-segment counts are
        rounded).
    density: float or None
        k-points per unit reciprocal length, used when npoints is None;
        defaults to DEFAULT_KPTS_DENSITY.

    Returns ``(kpts, x, X)``: the interpolated scaled k-points, their
    cumulative distance along the path, and the cumulative distance at
    every corner point.

    Raises ValueError if both npoints and density are given.
    """
    if npoints is not None and density is not None:
        raise ValueError('You may define npoints or density, but not both.')

    if len(paths) == 0:
        # np.concatenate() rejects an empty list; return the same empty
        # result as for a path consisting of empty subpaths.
        return np.empty((0, 3)), np.array([]), np.array([0])

    points = np.concatenate(paths)
    dists = points[1:] - points[:-1]
    lengths = [np.linalg.norm(d) for d in kpoint_convert(cell, skpts_kc=dists)]

    # The segment from the last point of one subpath to the first point of
    # the next is a jump, not part of the path: give it zero length.  Empty
    # subpaths at the start or end introduce no jump (and would otherwise
    # index the wrong segment or run past the end of ``lengths``).
    npts = 0
    for path in paths[:-1]:
        npts += len(path)
        if 0 < npts < len(points):
            lengths[npts - 1] = 0

    length = sum(lengths)

    if npoints is None:
        if density is None:
            density = DEFAULT_KPTS_DENSITY
        # Set npoints using the length of the path
        npoints = int(round(length * density))

    kpts = []
    x0 = 0
    x = []
    X = [0]
    for P, d, L in zip(points[:-1], dists, lengths):
        # Spread the not-yet-used points over the remaining length in
        # proportion to this segment's length.  While length remains, at
        # least 2 samples are used, so the segment start is always emitted.
        diff = length - x0
        if abs(diff) < 1e-6:
            n = 0
        else:
            n = max(2, int(round(L * (npoints - len(x)) / diff)))

        # The segment end is excluded; it is the next segment's start.
        for t in np.linspace(0, 1, n)[:-1]:
            kpts.append(P + t * d)
            x.append(x0 + t * L)
        x0 += L
        X.append(x0)
    if len(points):
        kpts.append(points[-1])
        x.append(x0)

    if len(kpts) == 0:
        kpts = np.empty((0, 3))

    return np.array(kpts), np.array(x), np.array(X)
541
542
# Backwards-compatible alias kept for callers of the former function name.
get_bandpath = bandpath  # old name
544
545
def find_bandpath_kinks(cell, kpts, eps=1e-5):
    """Find indices of those kpoints that are not interior to a line segment.

    The first and last kpoints are always included.  An interior kpoint is
    a kink when the step leaving it differs from the step arriving at it by
    more than *eps* (sum of absolute component differences).  *cell* is
    currently unused.

    Returns a list of indices; ``[]`` for no kpoints and ``[0]`` for one.
    """
    # XXX Should use the Cartesian kpoints.
    # Else comparison to eps will be anisotropic and hence arbitrary.
    kpts = np.asarray(kpts, dtype=float)
    nkpts = len(kpts)
    if nkpts == 0:
        return []
    if nkpts == 1:
        # The single point is both first and last; report it only once.
        return [0]
    diffs = kpts[1:] - kpts[:-1]
    kinks = np.abs(diffs[1:] - diffs[:-1]).sum(1) > eps
    return [0, *np.arange(1, nkpts - 1)[kinks], nkpts - 1]
559
560
def labels_from_kpts(kpts, cell, eps=1e-5, special_points=None):
    """Get an x-axis to be used when plotting a band structure.

    The first of the returned lists can be used as a x-axis when plotting
    the band structure. The second list can be used as xticks, and the third
    as xticklabels.

    Parameters:

    kpts: list
        List of scaled k-points.

    cell: list
        Unit cell of the atomic structure.

    eps: float
        Tolerance (sum of absolute coordinate differences) for matching a
        kpoint to a special point, directly or up to an integer shift.

    special_points: dict or None
        Mapping from labels to scaled coordinates.  If None, it is obtained
        from get_special_points(cell).

    Returns:

    Three arrays; the first is a list of cumulative distances between k-points,
    the second is x coordinates of the special points,
    the third is the special points as strings ('?' where no special
    point matches).
    """

    if special_points is None:
        special_points = get_special_points(cell)
    points = np.asarray(kpts)
    # XXX Due to this mechanism, we are blind to special points
    # that lie on straight segments such as [K, G, -K].
    # NOTE(review): kink detection uses a fixed tolerance of 1e-5 rather
    # than ``eps``; unclear whether that is intentional.
    indices = find_bandpath_kinks(cell, points, eps=1e-5)

    labels = []
    for kpt in points[indices]:
        for label, k in special_points.items():
            if abs(kpt - k).sum() < eps:
                break
        else:
            # No exact match.  Try matching up to an integer shift
            # (a reciprocal lattice vector).  Compare with the nearest
            # integer rather than using ``% 1``, which maps a difference of
            # e.g. -1e-9 to ~1 and so missed near-integer offsets.
            for label, k in special_points.items():
                delta = np.asarray(kpt - k, dtype=float)
                if abs(delta - np.round(delta)).sum() < eps:
                    break
            else:
                label = '?'
        labels.append(label)

    xcoords, ixcoords = indices_to_axis_coords(indices, points, cell)
    return xcoords, ixcoords, labels
606
607
def indices_to_axis_coords(indices, points, cell):
    """Build a linear plotting axis through the kpoints at *indices*.

    Consecutive selected indices that are adjacent (i, i + 1) are treated
    as a discontinuity and get zero length, but never twice in a row.
    Other spans get the Cartesian length of ``points[i2] - points[i1]``,
    distributed evenly over the kpoints between them.

    Returns ``(xcoords, xcoords[indices])``.
    """
    axis = [0]
    previous_was_jump = False
    for start, stop in zip(indices[:-1], indices[1:]):
        # Two jumps in a row are not allowed.
        is_jump = (not previous_was_jump) and stop == start + 1
        if is_jump:
            span = 0
        else:
            delta = points[stop] - points[start]
            span = np.linalg.norm(kpoint_convert(cell, skpts_kc=delta))
        previous_was_jump = is_jump
        offset = axis[-1]
        axis.extend(np.linspace(0, span, stop - start + 1)[1:] + offset)

    axis = np.array(axis)
    return axis, axis[indices]
623
624
# Band paths keyed by lattice name, written in the compact notation read by
# parse_path_string(): one label per special point, commas separating
# disconnected sections.
special_paths = {
    'cubic': 'GXMGRX,MR',
    'fcc': 'GXWKGLUWLK,UX',
    'bcc': 'GHNGPH,PN',
    'tetragonal': 'GXMGZRAZXR,MA',
    'orthorhombic': 'GXSYGZURTZ,YT,UX,SR',
    'hexagonal': 'GMKGALHA,LM,KH',
    'monoclinic': 'GYHCEM1AXH1,MDZ,YD',
    'rhombohedral type 1': 'GLB1,BZGX,QFP1Z,LP',
    'rhombohedral type 2': 'GPZQGFP1Q1LZ'}
635
636
def get_special_points(cell, lattice=None, eps=2e-4):
    """Return dict of special points.

    The definitions are from a paper by Wahyu Setyawan and Stefano
    Curtarolo::

        https://doi.org/10.1016/j.commatsci.2010.05.010

    cell: 3x3 ndarray
        Unit cell.
    lattice: str
        Accepted for backwards compatibility; no lattice check is
        currently performed.  For the old calling convention
        ``get_special_points(lattice, cell)`` the two arguments are
        swapped with a warning.
    eps: float
        Tolerance for cell-check, passed on to ``Cell.bandpath``.
    """

    if isinstance(cell, str):
        warnings.warn('Please call this function with cell as the first '
                      'argument', stacklevel=2)
        lattice, cell = cell, lattice

    cell = Cell.ascell(cell)
    # We create the bandpath because we want to transform the kpoints too,
    # from the canonical cell to the given one.  npoints=0 because only the
    # special points are needed.
    path = cell.bandpath(npoints=0, eps=eps)
    return path.special_points
666
667
def monkhorst_pack_interpolate(path, values, icell, bz2ibz,
                               size, offset=(0, 0, 0), pad_width=2):
    """Interpolate values from Monkhorst-Pack sampling.

    `monkhorst_pack_interpolate` takes a set of `values`, for example
    eigenvalues, that are resolved on a Monkhorst Pack k-point grid given by
    `size` and `offset` and interpolates the values onto the k-points
    in `path`.

    Note
    ----
    For the interpolation to work, path has to lie inside the domain
    that is spanned by the MP kpoint grid given by size and offset.

    To try to ensure this we expand the domain slightly by adding additional
    entries along the edges and sides of the domain with values determined by
    wrapping the values to the opposite side of the domain. In this way we
    assume that the function to be interpolated is a periodic function in
    k-space. The padding width is determined by the `pad_width` parameter.

    Parameters
    ----------
    path: (nk, 3) array-like
        Desired path in units of reciprocal lattice vectors.
    values: (nibz, ...) array-like
        Values on Monkhorst-Pack grid.
    icell: (3, 3) array-like
        Reciprocal lattice vectors.
    bz2ibz: (nbz,) array-like of int
        Map from nbz points in BZ to nibz reduced points in IBZ.
    size: (3,) array-like of int
        Size of Monkhorst-Pack grid.
    offset: (3,) array-like
        Offset of Monkhorst-Pack grid.
    pad_width: int
        Padding width to aid interpolation

    Returns
    -------
    (nk, ...) ndarray
        *values* interpolated to the nk points of *path*; trailing
        dimensions are those of *values* beyond the first.

    Raises
    ------
    ValueError
        If a point of *path* lies outside the (padded) interpolation domain.
    """
    from scipy.interpolate import LinearNDInterpolator

    # Wrap the path into [-0.5, 0.5) and convert to Cartesian coordinates.
    path = (np.asarray(path) + 0.5) % 1 - 0.5
    path = np.dot(path, icell)

    # Fold out values from IBZ to BZ:
    v = np.asarray(values)[bz2ibz]
    v = v.reshape(tuple(size) + v.shape[1:])

    # Create padded Monkhorst-Pack grid:
    size = np.asarray(size)
    i = (np.indices(size + 2 * pad_width)
         .transpose((1, 2, 3, 0)).reshape((-1, 3)))
    k = (i - pad_width + 0.5) / size - 0.5 + offset
    k = np.dot(k, icell)

    # Fill in boundary values by periodic wrapping:
    V = np.pad(v, [(pad_width, pad_width)] * 3 +
               [(0, 0)] * (v.ndim - 3), mode="wrap")

    interpolate = LinearNDInterpolator(k, V.reshape((-1,) + V.shape[3:]))
    interpolated_points = interpolate(path)

    # NaN values indicate points outside the interpolation domain.  Use an
    # explicit check (not ``assert``) so it is not stripped under -O.
    if np.isnan(interpolated_points).any():
        raise ValueError(
            "Points outside interpolation domain. Try increasing pad_width.")

    return interpolated_points
739
740
# Chadi-Cohen k-point grids.  The k-points are given in units of the
# reciprocal unit cell.  The variables are named according to the convention
# cc<Nkpoints>_<shape>; for example, the 18-k-point sqrt(3) x sqrt(3) grid is
# named 'cc18_sq3xsq3'.
745
# 6 k-points, 1x1 shape; integer numerators scaled by 1/3.
cc6_1x1 = np.array([
    1, 1, 0, 1, 0, 0, 0, -1, 0, -1, -1, 0, -1, 0, 0,
    0, 1, 0]).reshape((6, 3)) / 3.0
749
# 12 k-points, 2x3 shape; integer numerators scaled by 1/18.
cc12_2x3 = np.array([
    3, 4, 0, 3, 10, 0, 6, 8, 0, 3, -2, 0, 6, -4, 0,
    6, 2, 0, -3, 8, 0, -3, 2, 0, -3, -4, 0, -6, 4, 0, -6, -2, 0, -6,
    -8, 0]).reshape((12, 3)) / 18.0
754
# 18 k-points, sqrt(3) x sqrt(3) shape; integer numerators scaled by 1/18.
cc18_sq3xsq3 = np.array([
    2, 2, 0, 4, 4, 0, 8, 2, 0, 4, -2, 0, 8, -4,
    0, 10, -2, 0, 10, -8, 0, 8, -10, 0, 2, -10, 0, 4, -8, 0, -2, -8,
    0, 2, -4, 0, -4, -4, 0, -2, -2, 0, -4, 2, 0, -2, 4, 0, -8, 4, 0,
    -4, 8, 0]).reshape((18, 3)) / 18.0
760
# 18 k-points, 1x1 shape; integer numerators scaled by 1/18.
cc18_1x1 = np.array([
    2, 4, 0, 2, 10, 0, 4, 8, 0, 8, 4, 0, 8, 10, 0,
    10, 8, 0, 2, -2, 0, 4, -4, 0, 4, 2, 0, -2, 8, 0, -2, 2, 0, -2, -4,
    0, -4, 4, 0, -4, -2, 0, -4, -8, 0, -8, 2, 0, -8, -4, 0, -10, -2,
    0]).reshape((18, 3)) / 18.0
766
# 54 k-points, sqrt(3) x sqrt(3) shape; integer numerators scaled by 1/18.
cc54_sq3xsq3 = np.array([
    4, -10, 0, 6, -10, 0, 0, -8, 0, 2, -8, 0, 6,
    -8, 0, 8, -8, 0, -4, -6, 0, -2, -6, 0, 2, -6, 0, 4, -6, 0, 8, -6,
    0, 10, -6, 0, -6, -4, 0, -2, -4, 0, 0, -4, 0, 4, -4, 0, 6, -4, 0,
    10, -4, 0, -6, -2, 0, -4, -2, 0, 0, -2, 0, 2, -2, 0, 6, -2, 0, 8,
    -2, 0, -8, 0, 0, -4, 0, 0, -2, 0, 0, 2, 0, 0, 4, 0, 0, 8, 0, 0,
    -8, 2, 0, -6, 2, 0, -2, 2, 0, 0, 2, 0, 4, 2, 0, 6, 2, 0, -10, 4,
    0, -6, 4, 0, -4, 4, 0, 0, 4, 0, 2, 4, 0, 6, 4, 0, -10, 6, 0, -8,
    6, 0, -4, 6, 0, -2, 6, 0, 2, 6, 0, 4, 6, 0, -8, 8, 0, -6, 8, 0,
    -2, 8, 0, 0, 8, 0, -6, 10, 0, -4, 10, 0]).reshape((54, 3)) / 18.0
777
# 54 k-points, 1x1 shape; integer numerators scaled by 1/18.
cc54_1x1 = np.array([
    2, 2, 0, 4, 4, 0, 8, 8, 0, 6, 8, 0, 4, 6, 0, 6,
    10, 0, 4, 10, 0, 2, 6, 0, 2, 8, 0, 0, 2, 0, 0, 4, 0, 0, 8, 0, -2,
    6, 0, -2, 4, 0, -4, 6, 0, -6, 4, 0, -4, 2, 0, -6, 2, 0, -2, 0, 0,
    -4, 0, 0, -8, 0, 0, -8, -2, 0, -6, -2, 0, -10, -4, 0, -10, -6, 0,
    -6, -4, 0, -8, -6, 0, -2, -2, 0, -4, -4, 0, -8, -8, 0, 4, -2, 0,
    6, -2, 0, 6, -4, 0, 2, 0, 0, 4, 0, 0, 6, 2, 0, 6, 4, 0, 8, 6, 0,
    8, 0, 0, 8, 2, 0, 10, 4, 0, 10, 6, 0, 2, -4, 0, 2, -6, 0, 4, -6,
    0, 0, -2, 0, 0, -4, 0, -2, -6, 0, -4, -6, 0, -6, -8, 0, 0, -8, 0,
    -2, -8, 0, -4, -10, 0, -6, -10, 0]).reshape((54, 3)) / 18.0
788
# 162 k-points, sqrt(3) x sqrt(3) shape; integer numerators scaled by 1/27.
cc162_sq3xsq3 = np.array([
    -8, 16, 0, -10, 14, 0, -7, 14, 0, -4, 14,
    0, -11, 13, 0, -8, 13, 0, -5, 13, 0, -2, 13, 0, -13, 11, 0, -10,
    11, 0, -7, 11, 0, -4, 11, 0, -1, 11, 0, 2, 11, 0, -14, 10, 0, -11,
    10, 0, -8, 10, 0, -5, 10, 0, -2, 10, 0, 1, 10, 0, 4, 10, 0, -16,
    8, 0, -13, 8, 0, -10, 8, 0, -7, 8, 0, -4, 8, 0, -1, 8, 0, 2, 8, 0,
    5, 8, 0, 8, 8, 0, -14, 7, 0, -11, 7, 0, -8, 7, 0, -5, 7, 0, -2, 7,
    0, 1, 7, 0, 4, 7, 0, 7, 7, 0, 10, 7, 0, -13, 5, 0, -10, 5, 0, -7,
    5, 0, -4, 5, 0, -1, 5, 0, 2, 5, 0, 5, 5, 0, 8, 5, 0, 11, 5, 0,
    -14, 4, 0, -11, 4, 0, -8, 4, 0, -5, 4, 0, -2, 4, 0, 1, 4, 0, 4, 4,
    0, 7, 4, 0, 10, 4, 0, -13, 2, 0, -10, 2, 0, -7, 2, 0, -4, 2, 0,
    -1, 2, 0, 2, 2, 0, 5, 2, 0, 8, 2, 0, 11, 2, 0, -11, 1, 0, -8, 1,
    0, -5, 1, 0, -2, 1, 0, 1, 1, 0, 4, 1, 0, 7, 1, 0, 10, 1, 0, 13, 1,
    0, -10, -1, 0, -7, -1, 0, -4, -1, 0, -1, -1, 0, 2, -1, 0, 5, -1,
    0, 8, -1, 0, 11, -1, 0, 14, -1, 0, -11, -2, 0, -8, -2, 0, -5, -2,
    0, -2, -2, 0, 1, -2, 0, 4, -2, 0, 7, -2, 0, 10, -2, 0, 13, -2, 0,
    -10, -4, 0, -7, -4, 0, -4, -4, 0, -1, -4, 0, 2, -4, 0, 5, -4, 0,
    8, -4, 0, 11, -4, 0, 14, -4, 0, -8, -5, 0, -5, -5, 0, -2, -5, 0,
    1, -5, 0, 4, -5, 0, 7, -5, 0, 10, -5, 0, 13, -5, 0, 16, -5, 0, -7,
    -7, 0, -4, -7, 0, -1, -7, 0, 2, -7, 0, 5, -7, 0, 8, -7, 0, 11, -7,
    0, 14, -7, 0, 17, -7, 0, -8, -8, 0, -5, -8, 0, -2, -8, 0, 1, -8,
    0, 4, -8, 0, 7, -8, 0, 10, -8, 0, 13, -8, 0, 16, -8, 0, -7, -10,
    0, -4, -10, 0, -1, -10, 0, 2, -10, 0, 5, -10, 0, 8, -10, 0, 11,
    -10, 0, 14, -10, 0, 17, -10, 0, -5, -11, 0, -2, -11, 0, 1, -11, 0,
    4, -11, 0, 7, -11, 0, 10, -11, 0, 13, -11, 0, 16, -11, 0, -1, -13,
    0, 2, -13, 0, 5, -13, 0, 8, -13, 0, 11, -13, 0, 14, -13, 0, 1,
    -14, 0, 4, -14, 0, 7, -14, 0, 10, -14, 0, 13, -14, 0, 5, -16, 0,
    8, -16, 0, 11, -16, 0, 7, -17, 0, 10, -17, 0]).reshape((162, 3)) / 27.0
817
# 162 k-points, 1x1 shape; integer numerators scaled by 1/27.
cc162_1x1 = np.array([
    -8, -16, 0, -10, -14, 0, -7, -14, 0, -4, -14,
    0, -11, -13, 0, -8, -13, 0, -5, -13, 0, -2, -13, 0, -13, -11, 0,
    -10, -11, 0, -7, -11, 0, -4, -11, 0, -1, -11, 0, 2, -11, 0, -14,
    -10, 0, -11, -10, 0, -8, -10, 0, -5, -10, 0, -2, -10, 0, 1, -10,
    0, 4, -10, 0, -16, -8, 0, -13, -8, 0, -10, -8, 0, -7, -8, 0, -4,
    -8, 0, -1, -8, 0, 2, -8, 0, 5, -8, 0, 8, -8, 0, -14, -7, 0, -11,
    -7, 0, -8, -7, 0, -5, -7, 0, -2, -7, 0, 1, -7, 0, 4, -7, 0, 7, -7,
    0, 10, -7, 0, -13, -5, 0, -10, -5, 0, -7, -5, 0, -4, -5, 0, -1,
    -5, 0, 2, -5, 0, 5, -5, 0, 8, -5, 0, 11, -5, 0, -14, -4, 0, -11,
    -4, 0, -8, -4, 0, -5, -4, 0, -2, -4, 0, 1, -4, 0, 4, -4, 0, 7, -4,
    0, 10, -4, 0, -13, -2, 0, -10, -2, 0, -7, -2, 0, -4, -2, 0, -1,
    -2, 0, 2, -2, 0, 5, -2, 0, 8, -2, 0, 11, -2, 0, -11, -1, 0, -8,
    -1, 0, -5, -1, 0, -2, -1, 0, 1, -1, 0, 4, -1, 0, 7, -1, 0, 10, -1,
    0, 13, -1, 0, -10, 1, 0, -7, 1, 0, -4, 1, 0, -1, 1, 0, 2, 1, 0, 5,
    1, 0, 8, 1, 0, 11, 1, 0, 14, 1, 0, -11, 2, 0, -8, 2, 0, -5, 2, 0,
    -2, 2, 0, 1, 2, 0, 4, 2, 0, 7, 2, 0, 10, 2, 0, 13, 2, 0, -10, 4,
    0, -7, 4, 0, -4, 4, 0, -1, 4, 0, 2, 4, 0, 5, 4, 0, 8, 4, 0, 11, 4,
    0, 14, 4, 0, -8, 5, 0, -5, 5, 0, -2, 5, 0, 1, 5, 0, 4, 5, 0, 7, 5,
    0, 10, 5, 0, 13, 5, 0, 16, 5, 0, -7, 7, 0, -4, 7, 0, -1, 7, 0, 2,
    7, 0, 5, 7, 0, 8, 7, 0, 11, 7, 0, 14, 7, 0, 17, 7, 0, -8, 8, 0,
    -5, 8, 0, -2, 8, 0, 1, 8, 0, 4, 8, 0, 7, 8, 0, 10, 8, 0, 13, 8, 0,
    16, 8, 0, -7, 10, 0, -4, 10, 0, -1, 10, 0, 2, 10, 0, 5, 10, 0, 8,
    10, 0, 11, 10, 0, 14, 10, 0, 17, 10, 0, -5, 11, 0, -2, 11, 0, 1,
    11, 0, 4, 11, 0, 7, 11, 0, 10, 11, 0, 13, 11, 0, 16, 11, 0, -1,
    13, 0, 2, 13, 0, 5, 13, 0, 8, 13, 0, 11, 13, 0, 14, 13, 0, 1, 14,
    0, 4, 14, 0, 7, 14, 0, 10, 14, 0, 13, 14, 0, 5, 16, 0, 8, 16, 0,
    11, 16, 0, 7, 17, 0, 10, 17, 0]).reshape((162, 3)) / 27.0
846
847
# The following dictionary maps lattice names to the special (high-symmetry)
# points in the first Brillouin zone of some typical crystal structures,
# following the conventions of Setyawan and Curtarolo
# [https://doi.org/10.1016/j.commatsci.2010.05.010].
#
# Coordinates are given in units of the reciprocal basis vectors.
#
# See https://en.wikipedia.org/wiki/Brillouin_zone
sc_special_points = dict(
    # Lattice name -> {point label -> three fractional coordinates}.
    # Keyword order fixes the iteration order of both levels. Each point is
    # its own list object, so editing one entry cannot leak into another.
    cubic=dict(G=[0, 0, 0],
               M=[1 / 2, 1 / 2, 0],
               R=[1 / 2, 1 / 2, 1 / 2],
               X=[0, 1 / 2, 0]),
    fcc=dict(G=[0, 0, 0],
             K=[3 / 8, 3 / 8, 3 / 4],
             L=[1 / 2, 1 / 2, 1 / 2],
             U=[5 / 8, 1 / 4, 5 / 8],
             W=[1 / 2, 1 / 4, 3 / 4],
             X=[1 / 2, 0, 1 / 2]),
    bcc=dict(G=[0, 0, 0],
             H=[1 / 2, -1 / 2, 1 / 2],
             P=[1 / 4, 1 / 4, 1 / 4],
             N=[0, 0, 1 / 2]),
    tetragonal=dict(G=[0, 0, 0],
                    A=[1 / 2, 1 / 2, 1 / 2],
                    M=[1 / 2, 1 / 2, 0],
                    R=[0, 1 / 2, 1 / 2],
                    X=[0, 1 / 2, 0],
                    Z=[0, 0, 1 / 2]),
    orthorhombic=dict(G=[0, 0, 0],
                      R=[1 / 2, 1 / 2, 1 / 2],
                      S=[1 / 2, 1 / 2, 0],
                      T=[0, 1 / 2, 1 / 2],
                      U=[1 / 2, 0, 1 / 2],
                      X=[1 / 2, 0, 0],
                      Y=[0, 1 / 2, 0],
                      Z=[0, 0, 1 / 2]),
    hexagonal=dict(G=[0, 0, 0],
                   A=[0, 0, 1 / 2],
                   H=[1 / 3, 1 / 3, 1 / 2],
                   K=[1 / 3, 1 / 3, 0],
                   L=[1 / 2, 0, 1 / 2],
                   M=[1 / 2, 0, 0]),
)
890
891
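# Old version of the special-points dictionary, kept for backwards
# compatibility. It labels the zone centre 'Gamma' instead of 'G', and for
# some lattices (e.g. cubic, hexagonal, tetragonal) it uses different
# coordinates than sc_special_points. Not for ordinary use.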
ibz_points = dict(
    # Legacy lattice name -> {point label -> three fractional coordinates}.
    # The order of both levels and the exact element types (int vs. float)
    # are kept as they historically were.
    cubic=dict(Gamma=[0, 0, 0],
               # The 0.0 entries here are floats on purpose: the historical
               # values carried float zeros in these positions.
               X=[0, 0.0, 1 / 2],
               R=[1 / 2, 1 / 2, 1 / 2],
               M=[0.0, 1 / 2, 1 / 2]),
    fcc=dict(Gamma=[0, 0, 0],
             X=[1 / 2, 0, 1 / 2],
             W=[1 / 2, 1 / 4, 3 / 4],
             K=[3 / 8, 3 / 8, 3 / 4],
             U=[5 / 8, 1 / 4, 5 / 8],
             L=[1 / 2, 1 / 2, 1 / 2]),
    bcc=dict(Gamma=[0, 0, 0],
             H=[1 / 2, -1 / 2, 1 / 2],
             N=[0, 0, 1 / 2],
             P=[1 / 4, 1 / 4, 1 / 4]),
    hexagonal=dict(Gamma=[0, 0, 0],
                   M=[0, 1 / 2, 0],
                   K=[-1 / 3, 1 / 3, 0],
                   A=[0, 0, 1 / 2],
                   L=[0, 1 / 2, 1 / 2],
                   H=[-1 / 3, 1 / 3, 1 / 2]),
    tetragonal=dict(Gamma=[0, 0, 0],
                    X=[1 / 2, 0, 0],
                    M=[1 / 2, 1 / 2, 0],
                    Z=[0, 0, 1 / 2],
                    R=[1 / 2, 0, 1 / 2],
                    A=[1 / 2, 1 / 2, 1 / 2]),
    orthorhombic=dict(Gamma=[0, 0, 0],
                      R=[1 / 2, 1 / 2, 1 / 2],
                      S=[1 / 2, 1 / 2, 0],
                      T=[0, 1 / 2, 1 / 2],
                      U=[1 / 2, 0, 1 / 2],
                      X=[1 / 2, 0, 0],
                      Y=[0, 1 / 2, 0],
                      Z=[0, 0, 1 / 2]),
)
928