1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2# pylint: disable=invalid-name, no-member
3
4import pytest
5import numpy as np
6from numpy.testing import assert_allclose
7
8from astropy.wcs import wcs
9from astropy.modeling import models
10from astropy import units as u
11from astropy.tests.helper import assert_quantity_allclose
12
13
@pytest.mark.parametrize(
    "inp",
    [
        (0, 0),
        (4000, -20.56),
        (-2001.5, 45.9),
        (0, 90),
        (0, -90),
        np.mgrid[:4, :6],
    ],
)
def test_against_wcslib(inp):
    """Compare unit-aware TAN + sky-rotation models with ``astropy.wcs``.

    The forward compound model (``Pix2Sky_TAN | RotateNative2Celestial``) is
    checked against ``WCS.wcs_pix2world`` and the reverse compound model
    (``RotateCelestial2Native | Pix2Sky_TAN().inverse``) against
    ``WCS.wcs_world2pix``, both with an absolute tolerance of 1e-12.
    """
    ra0, dec0 = 202.4823228, 47.17511893

    # Reference transformation provided by astropy.wcs.
    reference = wcs.WCS()
    reference.wcs.crval = [ra0, dec0]
    reference.wcs.ctype = ['RA---TAN', 'DEC--TAN']

    # Both rotation models are built from the same Quantity angles.
    sky_angles = (ra0 * u.deg, dec0 * u.deg, 180 * u.deg)
    projection = models.Pix2Sky_TAN()
    forward = projection | models.RotateNative2Celestial(*sky_angles)
    backward = models.RotateCelestial2Native(*sky_angles) | projection.inverse

    expected_sky = reference.wcs_pix2world(*inp, 1)
    expected_pix = reference.wcs_world2pix(*expected_sky, 1)

    assert_allclose(forward(*inp), expected_sky, atol=1e-12)
    assert_allclose(backward(*expected_sky), expected_pix, atol=1e-12)
34
35
@pytest.mark.parametrize(
    "inp",
    [
        (40 * u.deg, -0.057 * u.rad),
        (21.5 * u.arcsec, 45.9 * u.deg),
    ],
)
def test_roundtrip_sky_rotation(inp):
    """Check that each sky rotation followed by its inverse returns the input.

    The rotation angles are deliberately given in three different angular
    units (deg, arcsec, rad); the round trip must agree with ``inp`` to
    within 1e-13 deg.
    """
    angles = (42 * u.deg, (43 * u.deg).to(u.arcsec), (44 * u.deg).to(u.rad))
    tolerance = 1e-13 * u.deg

    for rotation_cls in (models.RotateNative2Celestial, models.RotateCelestial2Native):
        rotation = rotation_cls(*angles)
        recovered = rotation.inverse(*rotation(*inp))
        assert_quantity_allclose(recovered, inp, atol=tolerance)
43
44
def test_Rotation2D():
    """Check that a 90 deg ``Rotation2D`` maps (1, 0) deg onto (0, 1) deg."""
    quarter_turn = models.Rotation2D(angle=90 * u.deg)
    x_in, y_in = 1 * u.deg, 0 * u.deg

    x_out, y_out = quarter_turn(x_in, y_in)

    expected = [0 * u.deg, 1 * u.deg]
    assert_quantity_allclose([x_out, y_out], expected, atol=1e-10 * u.deg)
50
51
def test_Rotation2D_inverse():
    """Check that ``Rotation2D.inverse`` undoes the forward rotation.

    The point (1, 0) deg is rotated by 234.23494 deg and then passed through
    the inverse model; the result must match the start to within 1e-10 deg.
    """
    rotation = models.Rotation2D(angle=234.23494 * u.deg)

    rotated = rotation(1 * u.deg, 0 * u.deg)
    x_back, y_back = rotation.inverse(*rotated)

    expected = [1 * u.deg, 0 * u.deg]
    assert_quantity_allclose([x_back, y_back], expected, atol=1e-10 * u.deg)
56
57
def test_euler_angle_rotations():
    """Check a 'zxz' Euler rotation with Quantity angles given in rad and in deg.

    With angles (0, 90 deg, 0) the input point (0, 90) is expected to land on
    (90, 0), both for bare-float inputs and for Quantity inputs in degrees.
    """
    start = (0, 90)
    target = (90, 0)
    target_deg = (90 * u.deg, 0 * u.deg)

    # Angles supplied in radians, evaluated on plain numbers.
    rotation_rad = models.EulerAngleRotation(0 * u.rad, np.pi / 2 * u.rad, 0 * u.rad, 'zxz')
    assert_allclose(rotation_rad(*start), target, atol=10**-12)

    # Angles supplied in degrees, evaluated on Quantity inputs.
    rotation_deg = models.EulerAngleRotation(0 * u.deg, 90 * u.deg, 0 * u.deg, 'zxz')
    start_deg = start * u.deg
    assert_quantity_allclose(rotation_deg(*start_deg), target_deg, atol=10**-12 * u.deg)
68
69
@pytest.mark.parametrize(
    "params",
    [
        (60, 10, 25),
        (60 * u.deg, 10 * u.deg, 25 * u.deg),
        tuple((angle * u.deg).to(u.rad) for angle in (60, 10, 25)),
    ],
)
def test_euler_rotations_with_units(params):
    """Check an 'xyz' Euler rotation for several angle/input unit combinations.

    The rotation angles arrive as bare numbers, degrees or radians, and the
    model is evaluated on bare floats, on degrees and on radians.  Every
    combination is compared against the same reference pair of values.
    """
    expected = (-23.614457631192547, 9.631254579686113)
    expected_deg = tuple(value * u.deg for value in expected)

    lon = 1 * u.deg
    lat = 1 * u.deg
    phi, theta, psi = params
    rotation = models.EulerAngleRotation(phi, theta, psi, axes_order='xyz')

    # Plain-number inputs are compared as plain numbers.
    assert_allclose(rotation(lon.value, lat.value), expected)

    # Quantity inputs, in degrees and then in radians, are compared in degrees.
    for lon_in, lat_in in ((lon, lat), (lon.to(u.rad), lat.to(u.rad))):
        assert_quantity_allclose(rotation(lon_in, lat_in), expected_deg)
87
88
def test_attributes():
    """Check the value/unit bookkeeping of ``RotateNative2Celestial`` parameters.

    Each parameter is expected to report ``value`` and ``unit`` as given at
    construction (arcsec, deg, rad), while ``_raw_value`` and
    ``internal_unit`` are expected to be expressed in radians.
    """
    n2c = models.RotateNative2Celestial(20016 * u.arcsec, -72.3 * u.deg, np.pi * u.rad)

    # (parameter, expected value, expected raw value, expected unit)
    expectations = [
        (n2c.lat, -72.3, -1.2618730491919001, u.Unit("deg")),
        (n2c.lon, 20016, 0.09704030641088472, u.Unit("arcsec")),
        (n2c.lon_pole, np.pi, np.pi, u.Unit("rad")),
    ]
    radian = u.Unit("rad")

    for parameter, value, raw_value, unit in expectations:
        assert_allclose(parameter.value, value)
        assert_allclose(parameter._raw_value, raw_value)
        assert parameter.unit is unit
        assert parameter.internal_unit is radian
103