# Licensed under a 3-clause BSD style license - see LICENSE.rst
# pylint: disable=invalid-name
"""Tests for models built with ``convolve_models_fft``.

These exercise the convolution model's result caching, the shapes it
returns for scalar, array and grid inputs, and its private
``_convolution_inputs`` helper. Every test is skipped when scipy is not
installed.
"""
import pytest
import numpy as np

from astropy.convolution import convolve_models_fft
from astropy.modeling.models import Const1D, Const2D

try:
    import scipy  # pylint: disable=W0611 # noqa
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True

# Boolean condition plus an explicit reason, instead of the legacy string
# form ``'not HAS_SCIPY'`` that pytest would have to eval in module globals.
requires_scipy = pytest.mark.skipif(not HAS_SCIPY, reason='requires scipy')


@requires_scipy
def test_clear_cache():
    """Evaluating the model fills its cache; ``clear_cache`` empties it."""
    m1 = Const1D()
    m2 = Const1D()

    model = convolve_models_fft(m1, m2, (-1, 1), 0.01)
    # Nothing is cached before the first evaluation.
    assert model._kwargs is None
    assert model._convolution is None

    results = model(0)
    # Sanity check only: every element of the evaluated result is nonzero.
    assert np.all(results)
    assert model._kwargs is not None
    assert model._convolution is not None

    model.clear_cache()
    assert model._kwargs is None
    assert model._convolution is None


@requires_scipy
def test_input_shape_1d():
    """A scalar input yields shape (1,); an array input keeps its shape."""
    m1 = Const1D()
    m2 = Const1D()

    model = convolve_models_fft(m1, m2, (-1, 1), 0.01)

    results = model(0)
    assert results.shape == (1,)

    x = np.arange(-1, 1, 0.1)
    results = model(x)
    assert results.shape == x.shape


@requires_scipy
def test_input_shape_2d():
    """Scalar/array mixes and full meshgrids map to the expected shapes."""
    m1 = Const2D()
    m2 = Const2D()

    model = convolve_models_fft(m1, m2, ((-1, 1), (-1, 1)), 0.01)

    results = model(0, 0)
    assert results.shape == (1,)

    x = np.arange(-1, 1, 0.1)
    # One array argument and one scalar: the array's shape wins.
    results = model(x, 0)
    assert results.shape == x.shape
    results = model(0, x)
    assert results.shape == x.shape

    grid = np.meshgrid(x, x)
    results = model(*grid)
    assert results.shape == grid[0].shape
    assert results.shape == grid[1].shape


@requires_scipy
def test__convolution_inputs():
    """``_convolution_inputs`` stacks flattened inputs and reports the shape.

    For scalar inputs it returns a 1-element array and the shape ``(1,)``.
    For array inputs it returns the flattened inputs stacked column-wise
    together with the common input shape. It raises ``ValueError`` when
    the array inputs' shapes differ.
    """
    m1 = Const2D()
    m2 = Const2D()

    model = convolve_models_fft(m1, m2, ((-1, 1), (-1, 1)), 0.01)

    x = np.arange(-1, 1, 0.1)
    y = np.arange(-2, 2, 0.1)
    grid0 = np.meshgrid(x, x)
    grid1 = np.meshgrid(y, y)

    # Scalar inputs. Unpack and compare the array with numpy instead of
    # tuple ``==``, which only works while the comparison array has exactly
    # one element (its truth value is otherwise ambiguous or empty).
    inputs, shape = model._convolution_inputs(1)
    np.testing.assert_array_equal(inputs, np.array([1]))
    assert shape == (1,)

    # Multiple inputs: compute once per grid and check both return values.
    for grid in (grid0, grid1):
        inputs, shape = model._convolution_inputs(*grid)
        expected = np.reshape([grid[0], grid[1]], (2, -1)).T
        np.testing.assert_array_equal(inputs, expected)
        assert shape == grid[0].shape

    # Error: array inputs with differing shapes are rejected. Anchored so
    # the message must match exactly, as the original equality check did.
    with pytest.raises(ValueError, match=r'^Values have differing shapes$'):
        model._convolution_inputs(grid0[0], grid1[1])