1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3import numpy as np
4from numpy.testing import assert_allclose, assert_array_equal
5
6from astropy import units as u
7from astropy.coordinates.matrix_utilities import (rotation_matrix, angle_axis,
8                                                  is_O3, is_rotation)
9
10
def test_rotation_matrix():
    """Check ``rotation_matrix`` against hand-written matrices and axis forms.

    Covers the zero-angle identity, quarter turns about named axes, the
    equivalence between naming an axis by letter and passing the matching
    unit vector, a half turn about an unnormalized axis, and a tiny angle.
    """
    # Zero rotation must reproduce the identity exactly (no tolerance).
    assert_array_equal(rotation_matrix(0*u.deg, 'x'), np.eye(3))

    # Quarter turns about single named axes, compared with literal matrices.
    expected_y_plus90 = [[0, 0, -1],
                         [0, 1, 0],
                         [1, 0, 0]]
    assert_allclose(rotation_matrix(90*u.deg, 'y'), expected_y_plus90,
                    atol=1e-12)

    expected_z_minus90 = [[0, -1, 0],
                          [1, 0, 0],
                          [0, 0, 1]]
    assert_allclose(rotation_matrix(-90*u.deg, 'z'), expected_z_minus90,
                    atol=1e-12)

    # Naming an axis by letter must agree with passing its unit vector.
    letter_vs_vector = ((45*u.deg, 'x', [1, 0, 0]),
                        (125*u.deg, 'y', [0, 1, 0]),
                        (-30*u.deg, 'z', [0, 0, 1]))
    for angle, letter, vector in letter_vs_vector:
        assert_allclose(rotation_matrix(angle, letter),
                        rotation_matrix(angle, vector))

    # A half turn about the unnormalized x=y diagonal sends x-hat to y-hat.
    half_turn = rotation_matrix(180*u.deg, [1, 1, 0])
    assert_allclose(half_turn.dot([1, 0, 0]), [0, 1, 0], atol=1e-12)

    # make sure it also works for very small angles
    tiny = 0.000001*u.deg
    assert_allclose(rotation_matrix(tiny, 'x'),
                    rotation_matrix(tiny, [1, 0, 0]))
35
36
def test_angle_axis():
    """Recover angle and axis from matrices built by ``rotation_matrix``."""
    # Rotation about the x axis: angle and axis should come back unchanged.
    m1 = rotation_matrix(35*u.deg, 'x')
    an1, ax1 = angle_axis(m1)

    # Compare the absolute difference.  A one-sided ``an1 - 35*u.deg < tol``
    # would also pass for any returned angle smaller than 35 deg.
    assert abs(an1 - 35*u.deg) < 1e-10*u.deg
    assert_allclose(ax1, [1, 0, 0])

    # A negative rotation about the (unnormalized) x=y diagonal is expected
    # back as a positive 89 deg about the opposite, normalized axis.
    m2 = rotation_matrix(-89*u.deg, [1, 1, 0])
    an2, ax2 = angle_axis(m2)

    assert abs(an2 - 89*u.deg) < 1e-10*u.deg
    assert_allclose(ax2, [-2**-0.5, -2**-0.5, 0])
49
50
def test_is_O3():
    """Test the matrix checker ``is_O3`` on single and stacked matrices.

    Single 3x3 inputs should give a scalar truth value. Stacks of shape
    (M, 3, 3) should give one boolean per matrix.
    """
    # Normal rotation matrix
    rot = rotation_matrix(35*u.deg, 'x')
    assert is_O3(rot)

    # reflection: negate the (0, 0) entry of the rotation; still in O(3)
    refl = rot.copy()
    refl[0, 0] *= -1
    assert is_O3(refl)

    # Not any sort of O(3)
    junk = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    assert not is_O3(junk)

    # (M, 3, 3) stacks: show the broadcasting, one result per matrix
    stacked_cases = (
        (np.tile(rot, (2, 1, 1)), (True, True)),
        (np.stack((rot, refl)), (True, True)),
        (np.stack((rot, junk)), (True, False)),
    )
    for stack, expected in stacked_cases:
        assert tuple(is_O3(stack)) == expected
74
75
def test_is_rotation():
    """Test the rotation matrix checker ``is_rotation``.

    Each case is checked both with the default (proper rotations only) and
    with ``allow_improper=True``, for single matrices and for (M, 3, 3)
    stacks, where one boolean per matrix is expected.
    """
    # Normal rotation matrix
    m1 = rotation_matrix(35*u.deg, 'x')
    assert is_rotation(m1)
    assert is_rotation(m1, allow_improper=True)  # (a less restrictive test)
    # and (M, 3, 3)
    n1 = np.tile(m1, (2, 1, 1))
    assert tuple(is_rotation(n1)) == (True, True)  # (show the broadcasting)
    assert tuple(is_rotation(n1, allow_improper=True)) == (True, True)

    # Improper rotation (unit rotation + reflection)
    m2 = np.identity(3)
    m2[0, 0] = -1
    assert not is_rotation(m2)
    assert is_rotation(m2, allow_improper=True)
    # and (M, 3, 3): the improper member flips only when it is allowed
    n2 = np.stack((m1, m2))
    assert tuple(is_rotation(n2)) == (True, False)  # (show the broadcasting)
    assert tuple(is_rotation(n2, allow_improper=True)) == (True, True)

    # Not any sort of rotation
    m3 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    assert not is_rotation(m3)
    assert not is_rotation(m3, allow_improper=True)
    # and (M, 3, 3): a non-orthogonal member is rejected either way
    n3 = np.stack((m1, m3))
    assert tuple(is_rotation(n3)) == (True, False)  # (show the broadcasting)
    assert tuple(is_rotation(n3, allow_improper=True)) == (True, False)
102