1import ase
2from typing import Mapping, Sequence, Union
3import numpy as np
4from ase.utils.arraywrapper import arraylike
5from ase.utils import pbc2pbc
6
7
# Public names exported by this module on star-import.
__all__ = ['Cell']
9
10
# NOTE(review): @arraylike presumably forwards ndarray attributes and
# methods (any, T, tolist, indexing, ...) to ``self.array`` -- confirm
# in ase.utils.arraywrapper.
@arraylike
class Cell:
    """Parallelepipedal unit cell of up to three dimensions.

    This object resembles a 3x3 array whose [i, j]-th element is the jth
    Cartesian coordinate of the ith unit vector.

    Cells of fewer than three dimensions are represented by placeholder
    unit vectors that are zero."""

    ase_objtype = 'cell'  # For JSON'ing
22
23    def __init__(self, array):
24        """Create cell.
25
26        Parameters:
27
28        array: 3x3 arraylike object
29          The three cell vectors: cell[0], cell[1], and cell[2].
30        """
31        array = np.asarray(array, dtype=float)
32        assert array.shape == (3, 3)
33        self.array = array
34
35    def cellpar(self, radians=False):
36        """Get unit cell parameters. Sequence of 6 numbers.
37
38        First three are unit cell vector lengths and second three
39        are angles between them::
40
41            [len(a), len(b), len(c), angle(b,c), angle(a,c), angle(a,b)]
42
43        in degrees.
44
45        See also :func:`ase.geometry.cell.cell_to_cellpar`."""
46        from ase.geometry.cell import cell_to_cellpar
47        return cell_to_cellpar(self.array, radians)
48
49    def todict(self):
50        return dict(array=self.array)
51
52    @classmethod
53    def ascell(cls, cell):
54        """Return argument as a Cell object.  See :meth:`ase.cell.Cell.new`.
55
56        A new Cell object is created if necessary."""
57        if isinstance(cell, cls):
58            return cell
59        return cls.new(cell)
60
61    @classmethod
62    def new(cls, cell=None):
63        """Create new cell from any parameters.
64
65        If cell is three numbers, assume three lengths with right angles.
66
67        If cell is six numbers, assume three lengths, then three angles.
68
69        If cell is 3x3, assume three cell vectors."""
70
71        if cell is None:
72            cell = np.zeros((3, 3))
73
74        cell = np.array(cell, float)
75
76        if cell.shape == (3,):
77            cell = np.diag(cell)
78        elif cell.shape == (6,):
79            from ase.geometry.cell import cellpar_to_cell
80            cell = cellpar_to_cell(cell)
81        elif cell.shape != (3, 3):
82            raise ValueError('Cell must be length 3 sequence, length 6 '
83                             'sequence or 3x3 matrix!')
84
85        cellobj = cls(cell)
86        return cellobj
87
88    @classmethod
89    def fromcellpar(cls, cellpar, ab_normal=(0, 0, 1), a_direction=None):
90        """Return new Cell from cell lengths and angles.
91
92        See also :func:`~ase.geometry.cell.cellpar_to_cell()`."""
93        from ase.geometry.cell import cellpar_to_cell
94        cell = cellpar_to_cell(cellpar, ab_normal, a_direction)
95        return cls(cell)
96
97    def get_bravais_lattice(self, eps=2e-4, *, pbc=True):
98        """Return :class:`~ase.lattice.BravaisLattice` for this cell:
99
100        >>> cell = Cell.fromcellpar([4, 4, 4, 60, 60, 60])
101        >>> print(cell.get_bravais_lattice())
102        FCC(a=5.65685)
103
104        .. note:: The Bravais lattice object follows the AFlow
105           conventions.  ``cell.get_bravais_lattice().tocell()`` may
106           differ from the original cell by a permutation or other
107           operation which maps it to the AFlow convention.  For
108           example, the orthorhombic lattice enforces a < b < c.
109
110           To build a bandpath for a particular cell, use
111           :meth:`ase.cell.Cell.bandpath` instead of this method.
112           This maps the kpoints back to the original input cell.
113
114        """
115        from ase.lattice import identify_lattice
116        pbc = self.any(1) & pbc2pbc(pbc)
117        lat, op = identify_lattice(self, eps=eps, pbc=pbc)
118        return lat
119
120    def bandpath(
121            self,
122            path: str = None,
123            npoints: int = None,
124            *,
125            density: float = None,
126            special_points: Mapping[str, Sequence[float]] = None,
127            eps: float = 2e-4,
128            pbc: Union[bool, Sequence[bool]] = True
129    ) -> "ase.dft.kpoints.BandPath":
130        """Build a :class:`~ase.dft.kpoints.BandPath` for this cell.
131
132        If special points are None, determine the Bravais lattice of
133        this cell and return a suitable Brillouin zone path with
134        standard special points.
135
136        If special special points are given, interpolate the path
137        directly from the available data.
138
139        Parameters:
140
141        path: string
142            String of special point names defining the path, e.g. 'GXL'.
143        npoints: int
144            Number of points in total.  Note that at least one point
145            is added for each special point in the path.
146        density: float
147            density of kpoints along the path in Å⁻¹.
148        special_points: dict
149            Dictionary mapping special points to scaled kpoint coordinates.
150            For example ``{'G': [0, 0, 0], 'X': [1, 0, 0]}``.
151        eps: float
152            Tolerance for determining Bravais lattice.
153        pbc: three bools
154            Whether cell is periodic in each direction.  Normally not
155            necessary.  If cell has three nonzero cell vectors, use
156            e.g. pbc=[1, 1, 0] to request a 2D bandpath nevertheless.
157
158        Example
159        -------
160        >>> cell = Cell.fromcellpar([4, 4, 4, 60, 60, 60])
161        >>> cell.bandpath('GXW', npoints=20)
162        BandPath(path='GXW', cell=[3x3], special_points={GKLUWX}, kpts=[20x3])
163
164        """
165        # TODO: Combine with the rotation transformation from bandpath()
166
167        cell = self.uncomplete(pbc)
168
169        if special_points is None:
170            from ase.lattice import identify_lattice
171            lat, op = identify_lattice(cell, eps=eps)
172            bandpath = lat.bandpath(path, npoints=npoints, density=density)
173            return bandpath.transform(op)
174        else:
175            from ase.dft.kpoints import BandPath, resolve_custom_points
176            path, special_points = resolve_custom_points(
177                path, special_points, eps=eps)
178            bandpath = BandPath(cell, path=path, special_points=special_points)
179            return bandpath.interpolate(npoints=npoints, density=density)
180
181    def uncomplete(self, pbc):
182        """Return new cell, zeroing cell vectors where not periodic."""
183        _pbc = np.empty(3, bool)
184        _pbc[:] = pbc
185        cell = self.copy()
186        cell[~_pbc] = 0
187        return cell
188
189    def complete(self):
190        """Convert missing cell vectors into orthogonal unit vectors."""
191        from ase.geometry.cell import complete_cell
192        cell = Cell(complete_cell(self.array))
193        return cell
194
195    def copy(self):
196        """Return a copy of this cell."""
197        cell = Cell(self.array.copy())
198        return cell
199
200    @property
201    def rank(self) -> int:
202        """"Return the dimension of the cell.
203
204        Equal to the number of nonzero lattice vectors."""
205        # The name ndim clashes with ndarray.ndim
206        return self.any(1).sum()  # type: ignore
207
208    @property
209    def orthorhombic(self) -> bool:
210        """Return whether this cell is represented by a diagonal matrix."""
211        from ase.geometry.cell import is_orthorhombic
212        return is_orthorhombic(self)
213
214    def lengths(self):
215        """Return the length of each lattice vector as an array."""
216        return np.linalg.norm(self, axis=1)
217
218    def angles(self):
219        """Return an array with the three angles alpha, beta, and gamma."""
220        return self.cellpar()[3:].copy()
221
222    def __array__(self, dtype=float):
223        if dtype != float:
224            raise ValueError('Cannot convert cell to array of type {}'
225                             .format(dtype))
226        return self.array
227
228    def __bool__(self):
229        return bool(self.any())  # need to convert from np.bool_
230
231    __nonzero__ = __bool__
232
233    @property
234    def volume(self) -> float:
235        """Get the volume of this cell.
236
237        If there are less than 3 lattice vectors, return 0."""
238        # Fail or 0 for <3D cells?
239        # Definitely 0 since this is currently a property.
240        # I think normally it is more convenient just to get zero
241        return np.abs(np.linalg.det(self))
242
243    @property
244    def handedness(self) -> int:
245        """Sign of the determinant of the matrix of cell vectors.
246
247        1 for right-handed cells, -1 for left, and 0 for cells that
248        do not span three dimensions."""
249        return int(np.sign(np.linalg.det(self)))
250
251    def scaled_positions(self, positions) -> np.ndarray:
252        """Calculate scaled positions from Cartesian positions.
253
254        The scaled positions are the positions given in the basis
255        of the cell vectors.  For the purpose of defining the basis, cell
256        vectors that are zero will be replaced by unit vectors as per
257        :meth:`~ase.cell.Cell.complete`."""
258        return np.linalg.solve(self.complete().T, np.transpose(positions)).T
259
260    def cartesian_positions(self, scaled_positions) -> np.ndarray:
261        """Calculate Cartesian positions from scaled positions."""
262        return scaled_positions @ self.complete()
263
264    def reciprocal(self) -> 'Cell':
265        """Get reciprocal lattice as a Cell object.
266
267        Does not include factor of 2 pi."""
268        return Cell(np.linalg.pinv(self).transpose())
269
270    def __repr__(self):
271        if self.orthorhombic:
272            numbers = self.lengths().tolist()
273        else:
274            numbers = self.tolist()
275
276        return 'Cell({})'.format(numbers)
277
278    def niggli_reduce(self, eps=1e-5):
279        """Niggli reduce this cell, returning a new cell and mapping.
280
281        See also :func:`ase.build.tools.niggli_reduce_cell`."""
282        from ase.build.tools import niggli_reduce_cell
283        cell, op = niggli_reduce_cell(self, epsfactor=eps)
284        result = Cell(cell)
285        return result, op
286
287    def minkowski_reduce(self):
288        """Minkowski-reduce this cell, returning new cell and mapping.
289
290        See also :func:`ase.geometry.minkowski_reduction.minkowski_reduce`."""
291        from ase.geometry.minkowski_reduction import minkowski_reduce
292        cell, op = minkowski_reduce(self, self.any(1))
293        result = Cell(cell)
294        return result, op
295
296    def permute_axes(self, permutation):
297        """Permute axes of cell."""
298        assert (np.sort(permutation) == np.arange(3)).all()
299        permuted = Cell(self[permutation][:, permutation])
300        return permuted
301
302    def standard_form(self):
303        """Rotate axes such that unit cell is lower triangular. The cell
304        handedness is preserved.
305
306        A lower-triangular cell with positive diagonal entries is a canonical
307        (i.e. unique) description. For a left-handed cell the diagonal entries
308        are negative.
309
310        Returns:
311
312        rcell: the standardized cell object
313
314        Q: ndarray
315            The orthogonal transformation.  Here, rcell @ Q = cell, where cell
316            is the input cell and rcell is the lower triangular (output) cell.
317        """
318
319        # get cell handedness (right or left)
320        sign = np.sign(np.linalg.det(self))
321        if sign == 0:
322            sign = 1
323
324        # LQ decomposition provides an axis-aligned description of the cell.
325        # Q is an orthogonal matrix and L is a lower triangular matrix. The
326        # decomposition is a unique description if the diagonal elements are
327        # all positive (negative for a left-handed cell).
328        Q, L = np.linalg.qr(self.T)
329        Q = Q.T
330        L = L.T
331
332        # correct the signs of the diagonal elements
333        signs = np.sign(np.diag(L))
334        indices = np.where(signs == 0)[0]
335        signs[indices] = 1
336        indices = np.where(signs != sign)[0]
337        L[:, indices] *= -1
338        Q[indices] *= -1
339        return Cell(L), Q
340
341    # XXX We want a reduction function that brings the cell into
342    # standard form as defined by Setyawan and Curtarolo.
343