1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4"""
Utilities used for constructing and inspecting rotation matrices.
6"""
7from functools import reduce
8import numpy as np
9
10from astropy import units as u
11from .angles import Angle
12
13
def matrix_product(*matrices):
    """Chain-multiply any number of matrices, left to right.

    Every argument should be at least two-dimensional.  Inputs with more
    than two dimensions are treated as stacks of matrices that live in the
    final two axes, following the rules of `~numpy.matmul`.

    The function is a readability helper: it spells
    ``matmul(matmul(m1, m2), m3)`` as ``matrix_product(m1, m2, m3)``.  For
    plain arrays the result matches ``m1 @ m2 @ m3``.

    Parameters
    ----------
    *matrices : array-like
        The factors, in multiplication order.  At least one is required;
        a single argument is returned unchanged.

    Returns
    -------
    product : array-like
        The accumulated product of all arguments.
    """
    def multiply_pair(accumulated, following):
        # One step of the left fold: (m1 ... mk) times m(k+1).
        return np.matmul(accumulated, following)

    return reduce(multiply_pair, matrices)
28
29
def matrix_transpose(matrix):
    """Return the transpose of a matrix, or of every matrix in a stack.

    Only the final two axes are exchanged, so leading "stack" axes are left
    untouched.  `~numpy.ndarray.T` is unsuitable here because it reverses
    *all* axes.  The helper exists so that call sites read as a transpose
    rather than as a bare ``.swapaxes(-2, -1)``.

    Parameters
    ----------
    matrix : array-like
        Object providing a ``.swapaxes()`` method, with the matrices in its
        last two dimensions.

    Returns
    -------
    transposed : array-like
        Whatever ``matrix.swapaxes(-2, -1)`` returns.
    """
    row_axis, column_axis = -2, -1
    transposed = matrix.swapaxes(row_axis, column_axis)
    return transposed
39
40
def rotation_matrix(angle, axis='z', unit=None):
    """
    Generate matrices for rotation by some angle around some axis.

    Parameters
    ----------
    angle : angle-like
        The amount of rotation the matrices should represent.  Can be an array.
    axis : str or array-like
        Either ``'x'``, ``'y'``, ``'z'``, or a (x,y,z) specifying the axis to
        rotate about. If ``'x'``, ``'y'``, or ``'z'``, the rotation sense is
        counterclockwise looking down the + axis (e.g. positive rotations obey
        left-hand-rule); for instance ``axis='z'`` yields
        ``[[c, s, 0], [-s, c, 0], [0, 0, 1]]`` with ``c = cos(angle)`` and
        ``s = sin(angle)``.  If given as an array, the last dimension should
        be 3; it will be broadcast against ``angle``.  An array axis is
        normalized to unit length before use.
    unit : unit-like, optional
        If ``angle`` does not have associated units, they are in this
        unit.  If neither are provided, it is assumed to be degrees.

    Returns
    -------
    rmat : `numpy.ndarray`
        A unitary rotation matrix, or a stack of them with shape
        ``(..., 3, 3)`` when ``angle`` and/or ``axis`` are arrays.

    Raises
    ------
    ValueError
        If ``axis`` is a string other than ``'x'``, ``'y'`` or ``'z'``.
    """
    # Validate a string axis before doing anything else.  A bare
    # ``'xyz'.index(axis)`` performs substring matching, so ``''``, ``'xy'``,
    # ``'yz'`` and ``'xyz'`` would otherwise be silently treated as ``'x'``
    # or ``'y'``, and other strings would fail with "substring not found".
    named_axis = isinstance(axis, str)
    if named_axis and axis not in ('x', 'y', 'z'):
        raise ValueError(
            "axis must be 'x', 'y', 'z' or an array-like (x, y, z); "
            "got {!r}".format(axis))

    # Bring the angle to plain radians.
    if isinstance(angle, u.Quantity):
        angle = angle.to_value(u.radian)
    elif unit is None:
        angle = np.deg2rad(angle)
    else:
        angle = u.Unit(unit).to(u.rad, angle)

    s = np.sin(angle)
    c = np.cos(angle)

    if named_axis:
        # Optimized implementation for the coordinate axes: axis ``i`` is
        # left untouched and the rotation acts in the (a1, a2) plane.
        i = 'xyz'.index(axis)
        a1 = (i + 1) % 3
        a2 = (i + 2) % 3
        R = np.zeros(getattr(angle, 'shape', ()) + (3, 3))
        R[..., i, i] = 1.
        R[..., a1, a1] = c
        R[..., a1, a2] = s
        R[..., a2, a1] = -s
        R[..., a2, a2] = c
    else:
        # General axis: normalize it, start from the outer product
        # axis * axis^T * (1 - cos), then add the cos term on the diagonal
        # and the antisymmetric sin terms off the diagonal.
        axis = np.asarray(axis)
        axis = axis / np.sqrt((axis * axis).sum(axis=-1, keepdims=True))
        R = (axis[..., np.newaxis] * axis[..., np.newaxis, :] *
             (1. - c)[..., np.newaxis, np.newaxis])

        for i in range(0, 3):
            R[..., i, i] += c
            a1 = (i + 1) % 3
            a2 = (i + 2) % 3
            R[..., a1, a2] += axis[..., i] * s
            R[..., a2, a1] -= axis[..., i] * s

    return R
102
103
def angle_axis(matrix):
    """
    Recover the rotation angle and rotation axis from a rotation matrix.

    Parameters
    ----------
    matrix : array-like
        A 3 x 3 unitary rotation matrix (or stack of matrices).

    Returns
    -------
    angle : `~astropy.coordinates.Angle`
        The angle of rotation, in radians.
    axis : array
        The (normalized) axis of rotation (with last dimension 3).

    Raises
    ------
    ValueError
        If the last two dimensions of ``matrix`` are not ``(3, 3)``.
    """
    rot = np.asanyarray(matrix)
    if rot.shape[-2:] != (3, 3):
        raise ValueError('matrix is not 3x3')

    # Differences of the off-diagonal pairs; component k pairs element
    # (row, col) with its mirror (col, row).
    antisym = np.zeros(rot.shape[:-1])
    for k, (row, col) in enumerate(((2, 1), (0, 2), (1, 0))):
        antisym[..., k] = rot[..., row, col] - rot[..., col, row]

    # Length of that vector, kept with a trailing axis so it broadcasts
    # against ``antisym`` when normalizing.
    length = np.sqrt((antisym * antisym).sum(-1, keepdims=True))
    trace = rot[..., 0, 0] + rot[..., 1, 1] + rot[..., 2, 2]
    rotation_angle = np.arctan2(length[..., 0], trace - 1.)
    return Angle(rotation_angle, u.radian), -antisym / length
132
133
def is_O3(matrix, atol=1e-15):
    """Check whether a matrix is in the length-preserving group O(3).

    The test is that ``matrix @ matrix^T`` equals the identity, element by
    element, within tolerance.

    Parameters
    ----------
    matrix : (..., N, N) array-like
        Must have attribute ``.shape`` and method ``.swapaxes()`` and not error
        when using `~numpy.isclose`.
    atol : float, optional
        Absolute tolerance passed to `~numpy.isclose` for the comparison
        with the identity.  The relative tolerance is left at the
        `~numpy.isclose` default.  Defaults to ``1e-15``.

    Returns
    -------
    is_o3 : bool or array of bool
        If the matrix has more than two axes, the O(3) check is performed on
        slices along the last two axes -- (M, N, N) => (M, ) bool array.

    Notes
    -----
    The orthogonal group O(3) preserves lengths, but is not guaranteed to keep
    orientations. Rotations and reflections are in this group.
    For more information, see https://en.wikipedia.org/wiki/Orthogonal_group

    """
    # matrix is in O(3) (rotations, proper and improper) when M M^T == 1.
    identity = np.identity(matrix.shape[-1])
    gram = matrix @ matrix.swapaxes(-2, -1)
    # Reduce over the two matrix axes only, so stacks give one bool each.
    is_o3 = np.all(np.isclose(gram, identity, atol=atol), axis=(-2, -1))

    return is_o3
162
163
def is_rotation(matrix, allow_improper=False):
    """Check whether a matrix is a rotation, proper or improper.

    A matrix passes when it is orthogonal (see ``is_O3``) and its
    determinant is close to +1, or close to +/-1 if improper rotations are
    allowed.

    Parameters
    ----------
    matrix : (..., N, N) array-like
        Must have attribute ``.shape`` and method ``.swapaxes()`` and not error
        when using `~numpy.isclose` and `~numpy.linalg.det`.
    allow_improper : bool, optional
        If `False` (default), only proper rotations, i.e. members of SO(3),
        are accepted.  If `True`, improper rotations (determinant -1) are
        accepted as well.

    Returns
    -------
    isrot : bool or array of bool
        If the matrix has more than two axes, the checks are performed on
        slices along the last two axes -- (M, N, N) => (M, ) bool array.

    See Also
    --------
    `~astropy.coordinates.matrix_utilities.is_O3`
        For the less restrictive check that a matrix is in the group O(3).

    Notes
    -----
    The group SO(3) is the rotation group. It is O(3), with determinant 1.
    Rotations with determinant -1 are improper rotations, combining both a
    rotation and a reflection.
    For more information, see https://en.wikipedia.org/wiki/Orthogonal_group

    """
    # Orthogonality first: the matrix has to be in O(3) at all.
    orthogonal = is_O3(matrix)

    # Then the determinant; dropping its sign admits improper rotations.
    determinant = np.linalg.det(matrix)
    if allow_improper:
        determinant = np.abs(determinant)
    unit_determinant = np.isclose(determinant, 1.0)

    return orthogonal & unit_determinant
206