1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import itertools
4
5import pytest
6import numpy as np
7from numpy.testing import assert_allclose
8
9from astropy.convolution.utils import discretize_model
10from astropy.modeling.functional_models import (
11    Gaussian1D, Box1D, RickerWavelet1D, Gaussian2D, Box2D, RickerWavelet2D)
12from astropy.modeling.tests.example_models import models_1D, models_2D
13from astropy.modeling.tests.test_models import create_model
14from astropy.utils.compat.optional_deps import HAS_SCIPY  # noqa
15
16
# Discretization modes shared by the parametrized tests below. The
# 'integrate' mode is exercised separately, in tests that are skipped when
# scipy is unavailable.
modes = [
    'center',
    'linear_interp',
    'oversample',
]

# Model classes whose pixel sums are compared against the reference entries
# in ``models_1D`` / ``models_2D``.
test_models_1D = [
    Gaussian1D,
    Box1D,
    RickerWavelet1D,
]
test_models_2D = [
    Gaussian2D,
    Box2D,
    RickerWavelet2D,
]
20
21
@pytest.mark.parametrize(('model_class', 'mode'), list(itertools.product(test_models_1D, modes)))
def test_pixel_sum_1D(model_class, mode):
    """
    Test if the sum of all pixels corresponds nearly to the integral.

    The ``models_1D`` entry for ``model_class`` supplies the arguments for
    ``create_model`` as well as the ``x_lim`` range to discretize over and
    the expected ``integral``.
    """
    # 'center' is a non-integrating mode, so the integral check is skipped
    # for the Box model in that mode.
    if model_class == Box1D and mode == "center":
        pytest.skip("Non integrating mode. Skip integral test.")

    # Look up the reference entry once and reuse it for every field below.
    parameters = models_1D[model_class]
    model = create_model(model_class, parameters)

    values = discretize_model(model, parameters['x_lim'], mode=mode)
    assert_allclose(values.sum(), parameters['integral'], atol=0.0001)
34
35
@pytest.mark.parametrize('mode', modes)
def test_gaussian_eval_1D(mode):
    """
    Compare each discretization mode of a wide Gaussian1D against direct
    evaluation of the model on the integer grid -100..100.
    """
    gauss = Gaussian1D(1, 0, 20)
    expected = gauss(np.arange(-100, 101))
    assert_allclose(expected,
                    discretize_model(gauss, (-100, 101), mode=mode),
                    atol=0.001)
47
48
@pytest.mark.parametrize(('model_class', 'mode'), list(itertools.product(test_models_2D, modes)))
def test_pixel_sum_2D(model_class, mode):
    """
    Test if the sum of all pixels corresponds nearly to the integral.

    The ``models_2D`` entry for ``model_class`` supplies the arguments for
    ``create_model`` as well as the ``x_lim``/``y_lim`` ranges to
    discretize over and the expected ``integral``.
    """
    # 'center' is a non-integrating mode, so the integral check is skipped
    # for the Box model in that mode.
    if model_class == Box2D and mode == "center":
        pytest.skip("Non integrating mode. Skip integral test.")

    # Look up the reference entry once and reuse it for every field below.
    parameters = models_2D[model_class]
    model = create_model(model_class, parameters)

    values = discretize_model(model, parameters['x_lim'],
                              parameters['y_lim'], mode=mode)
    assert_allclose(values.sum(), parameters['integral'], atol=0.0001)
63
64
@pytest.mark.parametrize('mode', modes)
def test_gaussian_eval_2D(mode):
    """
    Compare each discretization mode of a small-amplitude Gaussian2D
    against direct evaluation of the model on a 5x5 integer grid.
    """
    gauss = Gaussian2D(0.01, 0, 0, 1, 1)

    grid = np.arange(-2, 3)
    xx, yy = np.meshgrid(grid, grid)

    expected = gauss(xx, yy)
    actual = discretize_model(gauss, (-2, 3), (-2, 3), mode=mode)
    assert_allclose(expected, actual, atol=1e-2)
81
82
@pytest.mark.skipif('not HAS_SCIPY')
def test_gaussian_eval_2D_integrate_mode():
    """
    Compare the 'integrate' mode against direct evaluation for a symmetric
    Gaussian2D and for two Gaussian2D models elongated along x and along y.
    """
    grid = np.arange(-2, 3)
    xx, yy = np.meshgrid(grid, grid)

    # (x_stddev, y_stddev) pairs; amplitude and centre are shared.
    widths = [(2, 2), (1, 2), (2, 1)]
    for x_stddev, y_stddev in widths:
        gauss = Gaussian2D(.01, 0, 0, x_stddev, y_stddev)
        expected = gauss(xx, yy)
        actual = discretize_model(gauss, (-2, 3), (-2, 3), mode='integrate')
        assert_allclose(expected, actual, atol=1e-2)
101
102
@pytest.mark.skipif('not HAS_SCIPY')
def test_subpixel_gauss_1D():
    """
    Check that 'integrate' mode recovers the analytic integral of a
    unit-amplitude Gaussian1D that is much narrower than one pixel.
    """
    stddev = 0.1
    narrow = Gaussian1D(1, 0, stddev)
    pixels = discretize_model(narrow, (-1, 2), mode='integrate', factor=100)
    # Analytic integral of a unit-amplitude Gaussian: sqrt(2 * pi) * stddev.
    assert_allclose(pixels.sum(), np.sqrt(2 * np.pi) * stddev, atol=0.00001)
111
112
@pytest.mark.skipif('not HAS_SCIPY')
def test_subpixel_gauss_2D():
    """
    Check that 'integrate' mode recovers the analytic integral of a
    unit-amplitude Gaussian2D that is much narrower than one pixel.
    """
    narrow = Gaussian2D(1, 0, 0, 0.1, 0.1)
    pixels = discretize_model(narrow, (-1, 2), (-1, 2),
                              mode='integrate', factor=100)
    # Analytic integral: 2 * pi * x_stddev * y_stddev, with 0.1 * 0.1 = 0.01.
    assert_allclose(pixels.sum(), 2 * np.pi * 0.01, atol=0.00001)
121
122
def test_discretize_callable_1d():
    """
    Check that a plain 1-d Python function (not a Model) is sampled at the
    integer points of ``x_range`` in the default mode.
    """
    def parabola(x):
        return x ** 2

    sampled = discretize_model(parabola, (-5, 6))
    assert_allclose(sampled, np.arange(-5, 6) ** 2)
131
132
def test_discretize_callable_2d():
    """
    Check that a plain 2-d Python function (not a Model) is sampled on the
    integer grid spanned by ``x_range`` and ``y_range``, indexed (y, x).
    """
    def radius_squared(x, y):
        return x ** 2 + y ** 2

    actual = discretize_model(radius_squared, (-5, 6), (-5, 6))
    yy, xx = np.indices((11, 11)) - 5
    assert_allclose(actual, xx ** 2 + yy ** 2)
143
144
def test_type_exception():
    """
    Check that a non-callable model is rejected with a TypeError.
    """
    not_callable = float(0)
    with pytest.raises(TypeError) as excinfo:
        discretize_model(not_callable, (-10, 11))
    message = excinfo.value.args[0]
    assert message == 'Model must be callable.'
152
153
def test_dim_exception_1d():
    """
    Check that giving a y range for a 1-d function raises a ValueError.
    """
    def parabola(x):
        return x ** 2

    with pytest.raises(ValueError) as excinfo:
        discretize_model(parabola, (-10, 11), (-10, 11))
    message = excinfo.value.args[0]
    assert message == "y range specified, but model is only 1-d."
163
164
def test_dim_exception_2d():
    """
    Check that omitting the y range for a 2-d function raises a ValueError.
    """
    def radius_squared(x, y):
        return x ** 2 + y ** 2

    with pytest.raises(ValueError) as excinfo:
        discretize_model(radius_squared, (-10, 11))
    message = excinfo.value.args[0]
    assert message == "y range not specified, but model is 2-d"
174
175
def test_float_x_range_exception():
    """
    Test that an ``x_range`` whose width is not a whole number raises a
    ValueError.

    A 1-d callable is used so that the x_range check is the only error
    condition in play. With a 2-d callable and no ``y_range``, the call
    would also trigger the "y range not specified" error, and the test
    would silently depend on the order in which ``discretize_model``
    performs its checks.
    """
    def f(x):
        return x ** 2
    with pytest.raises(ValueError) as exc:
        discretize_model(f, (-10.002, 11.23))
    assert exc.value.args[0] == ("The difference between the upper and lower"
                                 " limit of 'x_range' must be a whole number.")
183
184
def test_float_y_range_exception():
    """
    Check that a ``y_range`` whose width is not a whole number raises a
    ValueError for a 2-d function.
    """
    def radius_squared(x, y):
        return x ** 2 + y ** 2

    expected = ("The difference between the upper and lower"
                " limit of 'y_range' must be a whole number.")
    with pytest.raises(ValueError) as excinfo:
        discretize_model(radius_squared, (-10, 11), (-10.002, 11.23))
    assert excinfo.value.args[0] == expected
192
193
def test_discretize_oversample():
    """
    Test the 'oversample' mode on an elongated 2-d Gaussian.

    Checks the output shape and the value and location of the peak for
    ``factor=10``, and that ``factor=1`` reproduces the 'center' mode.
    """
    gauss_2D = Gaussian2D(amplitude=1.0, x_mean=5., y_mean=125.,
                          x_stddev=0.75, y_stddev=3)
    x_range = (0, 10)
    y_range = (100, 135)

    values = discretize_model(gauss_2D, x_range=x_range, y_range=y_range,
                              mode='oversample', factor=10)
    vmax = np.max(values)
    # Output is indexed (y, x), relative to the lower limit of each range.
    vmax_yx = np.unravel_index(values.argmax(), values.shape)

    # A single sample per pixel is expected to match 'center' mode exactly
    # (asserted below).
    values_osf1 = discretize_model(gauss_2D, x_range=x_range, y_range=y_range,
                                   mode='oversample', factor=1)
    values_center = discretize_model(gauss_2D, x_range=x_range, y_range=y_range,
                                     mode='center')

    assert values.shape == (35, 10)
    assert_allclose(vmax, 0.927, atol=1e-3)
    assert vmax_yx == (25, 5)
    assert_allclose(values_center, values_osf1)
215