1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2# pylint: disable=invalid-name
3import pytest
4import numpy as np
5
6from astropy.convolution import convolve_models_fft
7from astropy.modeling.models import Const1D, Const2D
8
# Record whether the optional scipy dependency can be imported; the tests in
# this module are skipped when it cannot.
HAS_SCIPY = True
try:
    import scipy  # pylint: disable=W0611 # noqa
except ImportError:
    HAS_SCIPY = False
15
16
@pytest.mark.skipif('not HAS_SCIPY')
def test_clear_cache():
    """Check that ``clear_cache`` resets the model's cached state.

    The ``_kwargs`` and ``_convolution`` attributes must be ``None`` on a
    fresh model, set after an evaluation, and ``None`` again after
    ``clear_cache`` is called.
    """
    model = convolve_models_fft(Const1D(), Const1D(), (-1, 1), 0.01)

    # Nothing is cached before the first evaluation.
    assert model._kwargs is None
    assert model._convolution is None

    # Evaluating the model populates both cache attributes.
    results = model(0)
    # NOTE: both sides are reduced to a single boolean by ``.all()``, so this
    # only verifies that every element of the result is non-zero.
    expected = np.array([1.])
    assert results.all() == expected.all()
    assert model._kwargs is not None
    assert model._convolution is not None

    # Clearing the cache returns the model to its initial state.
    model.clear_cache()
    assert model._kwargs is None
    assert model._convolution is None
34
35
@pytest.mark.skipif('not HAS_SCIPY')
def test_input_shape_1d():
    """Check the output shape of a 1D convolution model.

    A scalar input must give a result of shape ``(1,)``; an array input must
    give a result with the same shape as the input.
    """
    model = convolve_models_fft(Const1D(), Const1D(), (-1, 1), 0.01)

    # Scalar input.
    assert model(0).shape == (1,)

    # Array input.
    x = np.arange(-1, 1, 0.1)
    assert model(x).shape == x.shape
49
50
@pytest.mark.skipif('not HAS_SCIPY')
def test_input_shape_2d():
    """Check the output shape of a 2D convolution model.

    Covers all-scalar inputs, one scalar mixed with one array (in either
    position), and a pair of meshgrid arrays.
    """
    bounding_box = ((-1, 1), (-1, 1))
    model = convolve_models_fft(Const2D(), Const2D(), bounding_box, 0.01)

    # Two scalars give a result of shape (1,).
    assert model(0, 0).shape == (1,)

    # A scalar paired with an array takes on the array's shape, whichever
    # argument position the array is in.
    x = np.arange(-1, 1, 0.1)
    for args in ((x, 0), (0, x)):
        assert model(*args).shape == x.shape

    # Two arrays of equal shape give a result of that same shape.
    xx, yy = np.meshgrid(x, x)
    output = model(xx, yy)
    assert output.shape == xx.shape
    assert output.shape == yy.shape
71
72
@pytest.mark.skipif('not HAS_SCIPY')
def test__convolution_inputs():
    """Check ``_convolution_inputs`` for scalar, array and mismatched inputs.

    The helper is expected to return a pair of (stacked inputs, output shape)
    and to raise ``ValueError`` when array inputs have differing shapes.
    """
    bounding_box = ((-1, 1), (-1, 1))
    model = convolve_models_fft(Const2D(), Const2D(), bounding_box, 0.01)

    x = np.arange(-1, 1, 0.1)
    y = np.arange(-2, 2, 0.1)
    grid0 = np.meshgrid(x, x)
    grid1 = np.meshgrid(y, y)

    # scalar inputs
    scalar_result = model._convolution_inputs(1)
    assert (np.array([1]), (1,)) == scalar_result

    # Multiple inputs: each grid is flattened into one row per point, and
    # the reported output shape is the shape of the input arrays.
    for grid in (grid0, grid1):
        inputs, output_shape = model._convolution_inputs(*grid)
        flattened = np.reshape([grid[0], grid[1]], (2, -1)).T
        assert np.all(inputs == flattened)
        assert output_shape == grid[0].shape

    # Error: the two arrays come from grids of different sizes.
    with pytest.raises(ValueError) as err:
        model._convolution_inputs(grid0[0], grid1[1])
    message = str(err.value)
    assert message == "Values have differing shapes"
101