1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4
5"""
6This module provides classes that operate on points or vectors in 3D space.
7"""
8
9import re
10import string
11import warnings
12from math import cos, pi, sin, sqrt
13
14import numpy as np
15from monty.json import MSONable
16
17from pymatgen.electronic_structure.core import Magmom
18from pymatgen.util.string import transformation_to_string
19from pymatgen.util.typing import ArrayLike
20
21__author__ = "Shyue Ping Ong, Shyam Dwaraknath, Matthew Horton"
22
23
class SymmOp(MSONable):
    """
    A symmetry operation in cartesian space. Consists of a rotation plus a
    translation. Implementation is as an affine transformation matrix of rank 4
    for efficiency. Read: http://en.wikipedia.org/wiki/Affine_transformation.

    .. attribute:: affine_matrix

        A 4x4 numpy.array representing the symmetry operation.

    .. attribute:: tol

        Absolute tolerance used by ``__eq__`` when comparing affine matrices.
    """

    def __init__(self, affine_transformation_matrix: ArrayLike, tol=0.01):
        """
        Initializes the SymmOp from a 4x4 affine transformation matrix.
        In general, this constructor should not be used unless you are
        transferring rotations.  Use the static constructors instead to
        generate a SymmOp from proper rotations and translation.

        Args:
            affine_transformation_matrix (4x4 array): Representing an
                affine transformation.
            tol (float): Tolerance for determining if matrices are equal.

        Raises:
            ValueError: If the matrix is not 4x4.
        """
        affine_transformation_matrix = np.array(affine_transformation_matrix)
        if affine_transformation_matrix.shape != (4, 4):
            raise ValueError("Affine Matrix must be a 4x4 numpy array!")
        self.affine_matrix = affine_transformation_matrix
        self.tol = tol

    @staticmethod
    def from_rotation_and_translation(
        rotation_matrix: ArrayLike = ((1, 0, 0), (0, 1, 0), (0, 0, 1)),
        translation_vec: ArrayLike = (0, 0, 0),
        tol=0.1,
    ):
        """
        Creates a symmetry operation from a rotation matrix and a translation
        vector.

        Args:
            rotation_matrix (3x3 array): Rotation matrix.
            translation_vec (3x1 array): Translation vector.
            tol (float): Tolerance stored on the resulting SymmOp and used
                when comparing it for equality. The rotation matrix itself
                is not validated against it.

        Returns:
            SymmOp object

        Raises:
            ValueError: If the rotation matrix is not 3x3 or the translation
                vector does not have shape (3,).
        """
        rotation_matrix = np.array(rotation_matrix)
        translation_vec = np.array(translation_vec)
        if rotation_matrix.shape != (3, 3):
            raise ValueError("Rotation Matrix must be a 3x3 numpy array.")
        if translation_vec.shape != (3,):
            raise ValueError("Translation vector must be a rank 1 numpy array " "with 3 elements.")
        affine_matrix = np.eye(4)
        affine_matrix[0:3][:, 0:3] = rotation_matrix
        affine_matrix[0:3][:, 3] = translation_vec
        return SymmOp(affine_matrix, tol)

    def __eq__(self, other):
        """
        Two SymmOps are equal when their affine matrices agree element-wise
        within ``self.tol`` (absolute tolerance). For objects that are not
        SymmOps, NotImplemented is returned so Python can fall back to its
        default comparison. Previously, such comparisons raised AttributeError.
        """
        if not isinstance(other, SymmOp):
            return NotImplemented
        return np.allclose(self.affine_matrix, other.affine_matrix, atol=self.tol)

    def __hash__(self):
        # Equality is tolerance-based (see __eq__), so two ops that compare
        # equal may hold slightly different floats. A constant hash keeps the
        # hash/eq contract valid, at the cost of set/dict lookup performance.
        return 7

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        output = [
            "Rot:",
            str(self.affine_matrix[0:3][:, 0:3]),
            "tau",
            str(self.affine_matrix[0:3][:, 3]),
        ]
        return "\n".join(output)

    def operate(self, point):
        """
        Apply the operation on a point.

        Args:
            point: Cartesian coordinate.

        Returns:
            Coordinates of point after operation.
        """
        # Homogeneous coordinate: the trailing 1 picks up the translation column.
        affine_point = np.array([point[0], point[1], point[2], 1])
        return np.dot(self.affine_matrix, affine_point)[0:3]

    def operate_multi(self, points):
        """
        Apply the operation on a list of points.

        Args:
            points: List of Cartesian coordinates

        Returns:
            Numpy array of coordinates after operation
        """
        points = np.array(points)
        # Append a 1 along the last axis so the translation is applied.
        affine_points = np.concatenate([points, np.ones(points.shape[:-1] + (1,))], axis=-1)
        return np.inner(affine_points, self.affine_matrix)[..., :-1]

    def apply_rotation_only(self, vector: ArrayLike):
        """
        Vectors should only be operated by the rotation matrix and not the
        translation vector.

        Args:
            vector (3x1 array): A vector.

        Returns:
            The rotated vector (rotation part of the operation only).
        """
        return np.dot(self.rotation_matrix, vector)

    def transform_tensor(self, tensor: np.ndarray):
        """
        Applies rotation portion to a tensor. Note that tensor has to be in
        full form, not the Voigt form.

        Args:
            tensor (numpy array): a rank n tensor; every dimension must be 3.

        Returns:
            Transformed tensor.
        """
        dim = tensor.shape
        rank = len(dim)
        assert all(i == 3 for i in dim)
        # Build einstein sum string, e.g. rank 2: "ad,be,de->ab"
        lc = string.ascii_lowercase
        indices = lc[:rank], lc[rank : 2 * rank]
        einsum_string = ",".join([a + i for a, i in zip(*indices)])
        einsum_string += ",{}->{}".format(*indices[::-1])
        einsum_args = [self.rotation_matrix] * rank + [tensor]

        return np.einsum(einsum_string, *einsum_args)

    def are_symmetrically_related(self, point_a: ArrayLike, point_b: ArrayLike, tol: float = 0.001) -> bool:
        """
        Checks if two points are symmetrically related.

        Args:
            point_a (3x1 array): First point.
            point_b (3x1 array): Second point.
            tol (float): Absolute tolerance for checking distance.

        Returns:
            True if self.operate(point_a) == point_b or vice versa.
        """
        if np.allclose(self.operate(point_a), point_b, atol=tol):
            return True
        if np.allclose(self.operate(point_b), point_a, atol=tol):
            return True
        return False

    @property
    def rotation_matrix(self) -> np.ndarray:
        """
        A 3x3 numpy.array representing the rotation matrix.
        """
        return self.affine_matrix[0:3][:, 0:3]

    @property
    def translation_vector(self) -> np.ndarray:
        """
        A rank 1 numpy.array of dim 3 representing the translation vector.
        """
        return self.affine_matrix[0:3][:, 3]

    def __mul__(self, other):
        """
        Returns a new SymmOp which is equivalent to apply the "other" SymmOp
        followed by this one.

        Note: the product is constructed with the default tolerance of
        SymmOp.__init__, not with ``self.tol``.
        """
        new_matrix = np.dot(self.affine_matrix, other.affine_matrix)
        return SymmOp(new_matrix)

    @property
    def inverse(self) -> "SymmOp":
        """
        Returns inverse of transformation (constructed with the default
        tolerance of SymmOp.__init__).
        """
        invr = np.linalg.inv(self.affine_matrix)
        return SymmOp(invr)

    @staticmethod
    def from_axis_angle_and_translation(
        axis: ArrayLike, angle: float, angle_in_radians: bool = False, translation_vec: ArrayLike = (0, 0, 0)
    ) -> "SymmOp":
        """
        Generates a SymmOp for a rotation about a given axis plus translation.

        Args:
            axis: The axis of rotation in cartesian space. For example,
                [1, 0, 0] indicates rotation about x-axis. Need not be
                normalized.
            angle (float): Angle of rotation.
            angle_in_radians (bool): Set to True if angles are given in
                radians. Or else, units of degrees are assumed.
            translation_vec: A translation vector. Defaults to zero.

        Returns:
            SymmOp for a rotation about given axis and translation.
        """
        if isinstance(axis, (tuple, list)):
            axis = np.array(axis)

        vec = np.array(translation_vec)

        a = angle if angle_in_radians else angle * pi / 180
        cosa = cos(a)
        sina = sin(a)
        u = axis / np.linalg.norm(axis)
        # Rodrigues' rotation formula, written out element by element.
        r = np.zeros((3, 3))
        r[0, 0] = cosa + u[0] ** 2 * (1 - cosa)
        r[0, 1] = u[0] * u[1] * (1 - cosa) - u[2] * sina
        r[0, 2] = u[0] * u[2] * (1 - cosa) + u[1] * sina
        r[1, 0] = u[0] * u[1] * (1 - cosa) + u[2] * sina
        r[1, 1] = cosa + u[1] ** 2 * (1 - cosa)
        r[1, 2] = u[1] * u[2] * (1 - cosa) - u[0] * sina
        r[2, 0] = u[0] * u[2] * (1 - cosa) - u[1] * sina
        r[2, 1] = u[1] * u[2] * (1 - cosa) + u[0] * sina
        r[2, 2] = cosa + u[2] ** 2 * (1 - cosa)

        return SymmOp.from_rotation_and_translation(r, vec)

    @staticmethod
    def from_origin_axis_angle(
        origin: ArrayLike, axis: ArrayLike, angle: float, angle_in_radians: bool = False
    ) -> "SymmOp":
        """
        Generates a SymmOp for a rotation about a given axis through an
        origin.

        Args:
            origin (3x1 array): The origin which the axis passes through.
            axis (3x1 array): The axis of rotation in cartesian space. For
                example, [1, 0, 0] indicates rotation about x-axis. Need not
                be normalized.
            angle (float): Angle of rotation.
            angle_in_radians (bool): Set to True if angles are given in
                radians. Or else, units of degrees are assumed.

        Returns:
            SymmOp.
        """
        theta = angle * pi / 180 if not angle_in_radians else angle
        a = origin[0]  # type: ignore
        b = origin[1]  # type: ignore
        c = origin[2]  # type: ignore
        u = axis[0]  # type: ignore
        v = axis[1]  # type: ignore
        w = axis[2]  # type: ignore
        # Set some intermediate values.
        u2 = u * u  # type: ignore
        v2 = v * v  # type: ignore
        w2 = w * w  # type: ignore
        cos_t = cos(theta)
        sin_t = sin(theta)
        # Squared length and length of the (unnormalized) axis.
        l2 = u2 + v2 + w2  # type: ignore
        l = sqrt(l2)  # type: ignore

        # Build the matrix entries element by element.
        m11 = (u2 + (v2 + w2) * cos_t) / l2  # type: ignore
        m12 = (u * v * (1 - cos_t) - w * l * sin_t) / l2  # type: ignore
        m13 = (u * w * (1 - cos_t) + v * l * sin_t) / l2  # type: ignore
        m14 = (  # type: ignore
            a * (v2 + w2)  # type: ignore
            - u * (b * v + c * w)  # type: ignore
            + (u * (b * v + c * w) - a * (v2 + w2)) * cos_t  # type: ignore
            + (b * w - c * v) * l * sin_t  # type: ignore
        ) / l2  # type: ignore

        m21 = (u * v * (1 - cos_t) + w * l * sin_t) / l2  # type: ignore
        m22 = (v2 + (u2 + w2) * cos_t) / l2  # type: ignore
        m23 = (v * w * (1 - cos_t) - u * l * sin_t) / l2  # type: ignore
        m24 = (  # type: ignore
            b * (u2 + w2)  # type: ignore
            - v * (a * u + c * w)  # type: ignore
            + (v * (a * u + c * w) - b * (u2 + w2)) * cos_t  # type: ignore
            + (c * u - a * w) * l * sin_t  # type: ignore
        ) / l2  # type: ignore

        m31 = (u * w * (1 - cos_t) - v * l * sin_t) / l2  # type: ignore
        m32 = (v * w * (1 - cos_t) + u * l * sin_t) / l2  # type: ignore
        m33 = (w2 + (u2 + v2) * cos_t) / l2  # type: ignore
        m34 = (  # type: ignore
            c * (u2 + v2)  # type: ignore
            - w * (a * u + b * v)  # type: ignore
            + (w * (a * u + b * v) - c * (u2 + v2)) * cos_t  # type: ignore
            + (a * v - b * u) * l * sin_t  # type: ignore
        ) / l2

        return SymmOp(
            [  # type: ignore
                [m11, m12, m13, m14],
                [m21, m22, m23, m24],
                [m31, m32, m33, m34],
                [0, 0, 0, 1],
            ]
        )

    @staticmethod
    def reflection(normal: ArrayLike, origin: ArrayLike = (0, 0, 0)) -> "SymmOp":
        """
        Returns reflection symmetry operation.

        Args:
            normal (3x1 array): Vector of the normal to the plane of
                reflection. Need not be normalized.
            origin (3x1 array): A point in which the mirror plane passes
                through.

        Returns:
            SymmOp for the reflection about the plane
        """
        # Normalize the normal vector first.
        n = np.array(normal, dtype=float) / np.linalg.norm(normal)

        u, v, w = n

        # Affine matrix that shifts `origin` to (0, 0, 0).
        translation = np.eye(4)
        translation[0:3, 3] = -np.array(origin)

        # Householder reflection I - 2 n n^T through a plane containing (0, 0, 0).
        xx = 1 - 2 * u ** 2
        yy = 1 - 2 * v ** 2
        zz = 1 - 2 * w ** 2
        xy = -2 * u * v
        xz = -2 * u * w
        yz = -2 * v * w
        mirror_mat = [[xx, xy, xz, 0], [xy, yy, yz, 0], [xz, yz, zz, 0], [0, 0, 0, 1]]

        if np.linalg.norm(origin) > 1e-6:
            # Conjugate: shift origin to zero, reflect, shift back.
            mirror_mat = np.dot(np.linalg.inv(translation), np.dot(mirror_mat, translation))
        return SymmOp(mirror_mat)

    @staticmethod
    def inversion(origin: ArrayLike = (0, 0, 0)) -> "SymmOp":
        """
        Inversion symmetry operation through a point (x -> 2 * origin - x).

        Args:
            origin (3x1 array): Origin of the inversion operation. Defaults
                to [0, 0, 0].

        Returns:
            SymmOp representing an inversion operation about the origin.
        """
        mat = -np.eye(4)
        mat[3, 3] = 1
        mat[0:3, 3] = 2 * np.array(origin)
        return SymmOp(mat)

    @staticmethod
    def rotoreflection(axis: ArrayLike, angle: float, origin: ArrayLike = (0, 0, 0)) -> "SymmOp":
        """
        Returns a roto-reflection symmetry operation

        Args:
            axis (3x1 array): Axis of rotation / mirror normal
            angle (float): Angle in degrees
            origin (3x1 array): Point left invariant by roto-reflection.
                Defaults to (0, 0, 0).

        Return:
            Roto-reflection operation (the reflection is applied first,
            then the rotation).
        """
        rot = SymmOp.from_origin_axis_angle(origin, axis, angle)
        refl = SymmOp.reflection(axis, origin)
        m = np.dot(rot.affine_matrix, refl.affine_matrix)
        return SymmOp(m)

    def as_dict(self) -> dict:
        """
        :return: MSONAble dict.
        """
        return {
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "matrix": self.affine_matrix.tolist(),
            "tolerance": self.tol,
        }

    def as_xyz_string(self) -> str:
        """
        Returns a string of the form 'x, y, z', '-x, -y, z',
        '-y+1/2, x+1/2, z+1/2', etc. Only works for integer rotation matrices;
        a warning is emitted if the rotation matrix is not (close to) integer.
        """
        # test for invalid rotation matrix
        if not np.all(np.isclose(self.rotation_matrix, np.round(self.rotation_matrix))):
            warnings.warn("Rotation matrix should be integer")

        return transformation_to_string(self.rotation_matrix, translation_vec=self.translation_vector, delim=", ")

    @staticmethod
    def from_xyz_string(xyz_string: str) -> "SymmOp":
        """
        Parses a SymmOp from its xyz (Jones faithful) notation.

        Args:
            xyz_string: string of the form 'x, y, z', '-x, -y, z',
                '-2y+1/2, 3x+1/2, z-y+1/2', etc. Coefficients of x/y/z may be
                integers, decimals or fractions (e.g. '2.5x', '1/2y').

        Returns:
            SymmOp
        """
        rot_matrix = np.zeros((3, 3))
        trans = np.zeros(3)
        toks = xyz_string.strip().replace(" ", "").lower().split(",")
        # Optional sign, optional (fractional) coefficient, then a coordinate letter.
        re_rot = re.compile(r"([+-]?)([\d\.]*)/?([\d\.]*)([x-z])")
        # A constant term (number or fraction). The negative lookahead rejects
        # any candidate whose remaining digits/dots/slash run into x, y or z,
        # so a coefficient such as '2.5x', '10x' or '1/2y' can no longer be
        # backtracked into a spurious translation like '2.', '1' or '1/'.
        re_trans = re.compile(r"([+-]?)([\d\.]+)/?([\d\.]*)(?![\d\./]*[x-z])")
        for i, tok in enumerate(toks):
            # build the rotation matrix
            for m in re_rot.finditer(tok):
                factor = -1.0 if m.group(1) == "-" else 1.0
                if m.group(2) != "":
                    factor *= float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2))
                # ord("x") == 120, so x/y/z map to columns 0/1/2.
                j = ord(m.group(4)) - 120
                rot_matrix[i, j] = factor
            # build the translation vector
            for m in re_trans.finditer(tok):
                factor = -1 if m.group(1) == "-" else 1
                num = float(m.group(2)) / float(m.group(3)) if m.group(3) != "" else float(m.group(2))
                trans[i] = num * factor
        return SymmOp.from_rotation_and_translation(rot_matrix, trans)

    @classmethod
    def from_dict(cls, d) -> "SymmOp":
        """
        :param d: dict with "matrix" and "tolerance" keys (as produced by as_dict)
        :return: SymmOp from dict representation.
        """
        return cls(d["matrix"], d["tolerance"])
451
452
class MagSymmOp(SymmOp):
    """
    Thin wrapper around SymmOp to extend it to support magnetic symmetry
    by including a time reversal operator. Magnetic symmetry is similar
    to conventional crystal symmetry, except symmetry is reduced by the
    addition of a time reversal operator which acts on an atom's magnetic
    moment.

    .. attribute:: time_reversal

        Time reversal operator, +1 or -1.
    """

    def __init__(self, affine_transformation_matrix: ArrayLike, time_reversal: int, tol: float = 0.01):
        """
        Initializes the MagSymmOp from a 4x4 affine transformation matrix
        and time reversal operator.
        In general, this constructor should not be used unless you are
        transferring rotations.  Use the static constructors instead to
        generate a SymmOp from proper rotations and translation.

        Args:
            affine_transformation_matrix (4x4 array): Representing an
                affine transformation.
            time_reversal (int): 1 or -1
            tol (float): Tolerance for determining if matrices are equal.

        Raises:
            ValueError: If the matrix is not 4x4 or time_reversal is not
                1 or -1 (ValueError is a subclass of Exception, so existing
                ``except Exception`` handlers still apply).
        """
        SymmOp.__init__(self, affine_transformation_matrix, tol=tol)
        if time_reversal not in (-1, 1):
            raise ValueError(
                "Time reversal operator not well defined: {0}, {1}".format(time_reversal, type(time_reversal))
            )
        self.time_reversal = time_reversal

    def __eq__(self, other):
        """
        Equal when the affine matrices agree within ``self.tol`` and the
        time reversal operators match. A plain SymmOp (no time reversal)
        compares unequal. Non-SymmOp objects yield NotImplemented instead of
        raising AttributeError.
        """
        if not isinstance(other, SymmOp):
            return NotImplemented
        return np.allclose(self.affine_matrix, other.affine_matrix, atol=self.tol) and (
            self.time_reversal == getattr(other, "time_reversal", None)
        )

    def __str__(self):
        return self.as_xyzt_string()

    def __repr__(self):
        output = [
            "Rot:",
            str(self.affine_matrix[0:3][:, 0:3]),
            "tau",
            str(self.affine_matrix[0:3][:, 3]),
            "Time reversal:",
            str(self.time_reversal),
        ]
        return "\n".join(output)

    def __hash__(self):
        # useful for obtaining a set of unique MagSymmOps
        # NOTE(review): this hashes the exact float values while __eq__ is
        # tolerance-based, so ops equal within tol may still hash differently.
        hashable_value = tuple(self.affine_matrix.flatten()) + (self.time_reversal,)
        return hashable_value.__hash__()

    def operate_magmom(self, magmom):
        """
        Apply time reversal operator on the magnetic moment. Note that
        magnetic moments transform as axial vectors, not polar vectors.

        See 'Symmetry and magnetic structures', Rodríguez-Carvajal and
        Bourée for a good discussion. DOI: 10.1051/epjconf/20122200010

        Args:
            magmom: Magnetic moment as electronic_structure.core.Magmom
            class or as list or np array-like

        Returns:
            Magnetic moment after operator applied as Magmom class
        """

        magmom = Magmom(magmom)  # type casting to handle lists as input

        # Axial vector: rotate, multiply by det(R) (undoes the sign flip of
        # improper operations), then apply time reversal.
        transformed_moment = (
            self.apply_rotation_only(magmom.global_moment) * np.linalg.det(self.rotation_matrix) * self.time_reversal
        )

        # retains input spin axis if different from default
        return Magmom.from_global_moment_and_saxis(transformed_moment, magmom.saxis)

    @classmethod
    def from_symmop(cls, symmop, time_reversal) -> "MagSymmOp":
        """
        Initialize a MagSymmOp from a SymmOp and time reversal operator.

        Args:
            symmop (SymmOp): SymmOp; its matrix and tolerance are reused.
            time_reversal (int): Time reversal operator, +1 or -1.

        Returns:
            MagSymmOp object
        """
        magsymmop = cls(symmop.affine_matrix, time_reversal, symmop.tol)
        return magsymmop

    @staticmethod
    def from_rotation_and_translation_and_time_reversal(
        rotation_matrix: ArrayLike = ((1, 0, 0), (0, 1, 0), (0, 0, 1)),
        translation_vec: ArrayLike = (0, 0, 0),
        time_reversal: int = 1,
        tol: float = 0.1,
    ) -> "MagSymmOp":
        """
        Creates a symmetry operation from a rotation matrix, translation
        vector and time reversal operator.

        Args:
            rotation_matrix (3x3 array): Rotation matrix.
            translation_vec (3x1 array): Translation vector.
            time_reversal (int): Time reversal operator, +1 or -1.
            tol (float): Tolerance stored on the resulting operation and used
                for equality comparisons.

        Returns:
            MagSymmOp object
        """
        symmop = SymmOp.from_rotation_and_translation(
            rotation_matrix=rotation_matrix, translation_vec=translation_vec, tol=tol
        )
        return MagSymmOp.from_symmop(symmop, time_reversal)

    @staticmethod
    def from_xyzt_string(xyzt_string: str) -> "MagSymmOp":
        """
        Args:
            xyzt_string: string of the form 'x, y, z, +1', '-x, -y, z, -1',
                '-2y+1/2, 3x+1/2, z-y+1/2, +1', etc. The last comma-separated
                field is the time reversal operator.

        Returns:
            MagSymmOp object

        Raises:
            ValueError: If the time reversal field is missing or not an
                integer (subclass of Exception, so existing handlers still
                apply).
        """
        parts = xyzt_string.rsplit(",", 1)
        symmop = SymmOp.from_xyz_string(parts[0])
        try:
            time_reversal = int(parts[1])
        except (IndexError, ValueError) as exc:
            raise ValueError("Time reversal operator could not be parsed.") from exc
        return MagSymmOp.from_symmop(symmop, time_reversal)

    def as_xyzt_string(self) -> str:
        """
        Returns a string of the form 'x, y, z, +1', '-x, -y, z, -1',
        '-y+1/2, x+1/2, z+1/2, +1', etc. Only works for integer rotation matrices
        """
        xyzt_string = SymmOp.as_xyz_string(self)
        return xyzt_string + ", {:+}".format(self.time_reversal)

    def as_dict(self) -> dict:
        """
        :return: MSONABle dict
        """
        return {
            "@module": self.__class__.__module__,
            "@class": self.__class__.__name__,
            "matrix": self.affine_matrix.tolist(),
            "tolerance": self.tol,
            "time_reversal": self.time_reversal,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "MagSymmOp":
        """
        :param d: dict with "matrix", "tolerance" and "time_reversal" keys
        :return: MagneticSymmOp from dict representation.
        """
        return cls(d["matrix"], tol=d["tolerance"], time_reversal=d["time_reversal"])
615