# -*- coding: utf-8 -*-
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

from astropy import units as u
from astropy.coordinates.matrix_utilities import (rotation_matrix, angle_axis,
                                                  is_O3, is_rotation)


def test_rotation_matrix():
    """Test ``rotation_matrix`` with named axes and explicit axis vectors."""
    assert_array_equal(rotation_matrix(0*u.deg, 'x'), np.eye(3))

    assert_allclose(rotation_matrix(90*u.deg, 'y'), [[0, 0, -1],
                                                     [0, 1, 0],
                                                     [1, 0, 0]], atol=1e-12)

    assert_allclose(rotation_matrix(-90*u.deg, 'z'), [[0, -1, 0],
                                                      [1, 0, 0],
                                                      [0, 0, 1]], atol=1e-12)

    # a named axis must give the same matrix as the matching unit vector
    assert_allclose(rotation_matrix(45*u.deg, 'x'),
                    rotation_matrix(45*u.deg, [1, 0, 0]))
    assert_allclose(rotation_matrix(125*u.deg, 'y'),
                    rotation_matrix(125*u.deg, [0, 1, 0]))
    assert_allclose(rotation_matrix(-30*u.deg, 'z'),
                    rotation_matrix(-30*u.deg, [0, 0, 1]))

    assert_allclose(np.dot(rotation_matrix(180*u.deg, [1, 1, 0]), [1, 0, 0]),
                    [0, 1, 0], atol=1e-12)

    # make sure it also works for very small angles
    assert_allclose(rotation_matrix(0.000001*u.deg, 'x'),
                    rotation_matrix(0.000001*u.deg, [1, 0, 0]))


def test_angle_axis():
    """Test that ``angle_axis`` recovers the angle and axis of a rotation."""
    m1 = rotation_matrix(35*u.deg, 'x')
    an1, ax1 = angle_axis(m1)

    # abs() makes the tolerance two-sided: without it, any angle smaller
    # than the expected one (even 0 deg) would satisfy the comparison.
    assert abs(an1 - 35*u.deg) < 1e-10*u.deg
    assert_allclose(ax1, [1, 0, 0])

    # A -89 deg rotation about [1, 1, 0] is expected back as +89 deg about
    # the flipped (and normalized) axis.
    m2 = rotation_matrix(-89*u.deg, [1, 1, 0])
    an2, ax2 = angle_axis(m2)

    assert abs(an2 - 89*u.deg) < 1e-10*u.deg
    assert_allclose(ax2, [-2**-0.5, -2**-0.5, 0])


def test_is_O3():
    """Test the matrix checker ``is_O3``."""
    # Normal rotation matrix
    m1 = rotation_matrix(35*u.deg, 'x')
    assert is_O3(m1)
    # and (M, 3, 3)
    n1 = np.tile(m1, (2, 1, 1))
    assert tuple(is_O3(n1)) == (True, True)  # (show the broadcasting)

    # reflection
    m2 = m1.copy()
    m2[0, 0] *= -1
    assert is_O3(m2)
    # and (M, 3, 3)
    n2 = np.stack((m1, m2))
    assert tuple(is_O3(n2)) == (True, True)  # (show the broadcasting)

    # Not any sort of O(3)
    m3 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    assert not is_O3(m3)
    # and (M, 3, 3)
    n3 = np.stack((m1, m3))
    assert tuple(is_O3(n3)) == (True, False)  # (show the broadcasting)


def test_is_rotation():
    """Test the rotation matrix checker ``is_rotation``."""
    # Normal rotation matrix
    m1 = rotation_matrix(35*u.deg, 'x')
    assert is_rotation(m1)
    assert is_rotation(m1, allow_improper=True)  # (a less restrictive test)
    # and (M, 3, 3)
    n1 = np.tile(m1, (2, 1, 1))
    assert tuple(is_rotation(n1)) == (True, True)  # (show the broadcasting)

    # Improper rotation (unit rotation + reflection)
    m2 = np.identity(3)
    m2[0, 0] = -1
    assert not is_rotation(m2)
    assert is_rotation(m2, allow_improper=True)
    # and (M, 3, 3)
    n2 = np.stack((m1, m2))
    assert tuple(is_rotation(n2)) == (True, False)  # (show the broadcasting)

    # Not any sort of rotation
    m3 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    assert not is_rotation(m3)
    assert not is_rotation(m3, allow_improper=True)
    # and (M, 3, 3)
    n3 = np.stack((m1, m3))
    assert tuple(is_rotation(n3)) == (True, False)  # (show the broadcasting)