1import numpy as np
2import pytest
3from numpy.testing import assert_array_equal, assert_equal
4
5from skimage._shared._warnings import expected_warnings
6from skimage._shared.utils import _supported_float_type
7from skimage.filters import difference_of_gaussians, gaussian
8
9
def test_negative_sigma():
    """gaussian rejects a negative sigma given as a scalar, list or array."""
    impulse = np.zeros((3, 3))
    impulse[1, 1] = 1.
    bad_sigmas = (-1.0, [-1.0, 1.0], np.asarray([-1.0, 1.0]))
    for bad_sigma in bad_sigmas:
        with pytest.raises(ValueError):
            gaussian(impulse, sigma=bad_sigma)
19
20
def test_null_sigma():
    """A zero sigma leaves the image untouched when its range is preserved."""
    impulse = np.zeros((3, 3))
    impulse[1, 1] = 1.
    smoothed = gaussian(impulse, 0, preserve_range=True)
    assert np.all(smoothed == impulse)
25
26
def test_default_sigma():
    """Omitting sigma gives exactly the same result as passing sigma=1."""
    impulse = np.zeros((3, 3))
    impulse[1, 1] = 1.
    implicit = gaussian(impulse, preserve_range=True)
    explicit = gaussian(impulse, preserve_range=True, sigma=1)
    assert_array_equal(implicit, explicit)
34
35
@pytest.mark.parametrize(
    'dtype', [np.uint8, np.int32, np.float16, np.float32, np.float64]
)
def test_image_dtype(dtype):
    """The output dtype is the supported float type for the input dtype."""
    image = np.zeros((3, 3), dtype=dtype)
    expected_dtype = _supported_float_type(image.dtype)
    filtered = gaussian(image)
    assert filtered.dtype == expected_dtype
42
43
def test_energy_decrease():
    """Blurring an impulse spreads it out, lowering its standard deviation."""
    impulse = np.zeros((3, 3))
    impulse[1, 1] = 1.
    blurred = gaussian(impulse, preserve_range=True, sigma=1, mode='reflect')
    assert blurred.std() < impulse.std()
49
50
@pytest.mark.parametrize('channel_axis', [0, 1, -1])
def test_multichannel(channel_axis):
    """Per-channel means survive smoothing, so channels are never mixed."""
    image = np.zeros((5, 5, 3))
    image[1, 1] = np.arange(1, 4)
    image = np.moveaxis(image, -1, channel_axis)

    # Normalise a possibly negative channel axis so it can be compared with
    # the non-negative axis indices produced by range().
    chan = channel_axis % image.ndim
    spatial_axes = tuple(ax for ax in range(image.ndim) if ax != chan)

    def assert_channel_means_kept(filtered):
        # The mean over the spatial axes must match the input, per channel.
        assert np.allclose(image.mean(axis=spatial_axes),
                           filtered.mean(axis=spatial_axes))

    assert_channel_means_kept(
        gaussian(image, sigma=1, mode='reflect', preserve_range=True,
                 channel_axis=channel_axis)
    )

    if chan == 2:
        # Legacy path: no channel_axis given while the size-3 channel axis is
        # last; a warning mentioning `multichannel` is expected.
        with expected_warnings(['multichannel']):
            legacy = gaussian(image, sigma=1, mode='reflect',
                              preserve_range=True)
        assert_channel_means_kept(legacy)

    # A two-element sigma, one entry per spatial axis.
    assert_channel_means_kept(
        gaussian(image, sigma=[1, 2], mode='reflect',
                 channel_axis=channel_axis, preserve_range=True)
    )
82
83
def test_deprecated_multichannel():
    """The deprecated `multichannel` argument warns but still works.

    Both the keyword form and the positional form must emit their
    deprecation warning AND still filter each channel independently.
    """
    image = np.zeros((5, 5, 3))
    image[1, 1] = np.arange(1, 4)
    spatial_axes = (0, 1)

    with expected_warnings(["`multichannel` is a deprecated argument"]):
        gaussian_rgb_a = gaussian(image, sigma=1, mode='reflect',
                                  multichannel=True)
    # Check that the mean value is conserved in each channel
    # (color channels are not mixed together)
    assert np.allclose(image.mean(axis=spatial_axes),
                       gaussian_rgb_a.mean(axis=spatial_axes))

    # check positional multichannel argument warning; the trailing True is
    # presumably `multichannel` (matching the expected warning text)
    with expected_warnings(["Providing the `multichannel` argument"]):
        gaussian_rgb_a = gaussian(image, 1, None, 'reflect', 0, True)
    # Previously this result was discarded; the positional form must also
    # keep channels separate, exactly like the keyword form above.
    assert np.allclose(image.mean(axis=spatial_axes),
                       gaussian_rgb_a.mean(axis=spatial_axes))
97
98
def test_preserve_range():
    """Test preserve_range parameter."""
    ones = np.ones((2, 2), dtype=np.int64)
    # Without preserve_range, the constant int64 image becomes a constant
    # float image with a tiny value (checked below to be < 1e-10).
    filtered_ones = gaussian(ones, preserve_range=False)
    assert np.all(filtered_ones == filtered_ones[0, 0])
    assert filtered_ones[0, 0] < 1e-10

    # With preserve_range the original values are kept.
    filtered_preserved = gaussian(ones, preserve_range=True)
    assert np.all(filtered_preserved == 1.)

    # A float32 image with negative values must be accepted. Previously the
    # result was ignored; now check that shape and dtype come through
    # consistently with test_image_dtype.
    img = np.array([[10.0, -10.0], [-4, 3]], dtype=np.float32)
    filtered_img = gaussian(img, 1)
    assert filtered_img.shape == img.shape
    assert filtered_img.dtype == _supported_float_type(img.dtype)
111
112
def test_1d_ok():
    """Gaussian filtering works on a 1D array.

    The ramp 0..6 contains positive integers and a single zero, so after
    smoothing every sample should be greater than 0.1.
    """
    ramp = np.arange(7)
    smoothed = gaussian(ramp, preserve_range=True)
    assert (smoothed > 0.1).all()
121
122
def test_4d_ok():
    """A centred 4D impulse keeps unit total mass under reflect-mode blur."""
    volume = np.zeros((5, 5, 5, 5))
    volume[(2,) * 4] = 1
    smoothed = gaussian(volume, 1, mode='reflect', preserve_range=True)
    assert np.allclose(smoothed.sum(), 1)
128
129
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_preserve_output(dtype):
    """When an `output` array is supplied, that same object is returned."""
    image = np.arange(9, dtype=dtype).reshape((3, 3))
    out = np.zeros_like(image, dtype=dtype)
    result = gaussian(image, sigma=1, output=out, preserve_range=True)
    assert result is out
139
140
def test_output_error():
    """A uint8 `output` array for a float32 image raises ValueError."""
    image = np.arange(9, dtype=np.float32).reshape((3, 3))
    int_output = np.zeros_like(image, dtype=np.uint8)
    with pytest.raises(ValueError):
        gaussian(image, sigma=1, output=int_output, preserve_range=True)
147
148
@pytest.mark.parametrize("s", [1, (2, 3)])
@pytest.mark.parametrize("s2", [4, (5, 6)])
@pytest.mark.parametrize("channel_axis", [None, 0, 1, -1])
def test_difference_of_gaussians(s, s2, channel_axis):
    """difference_of_gaussians equals the difference of two gaussian calls."""
    # Seeded generator instead of global np.random so failures reproduce.
    rng = np.random.default_rng(0)
    image = rng.random((10, 10))
    if channel_axis is not None:
        n_channels = 5
        image = np.stack((image,) * n_channels, channel_axis)
    im1 = gaussian(image, s, preserve_range=True, channel_axis=channel_axis)
    im2 = gaussian(image, s2, preserve_range=True, channel_axis=channel_axis)
    dog = im1 - im2
    dog2 = difference_of_gaussians(image, s, s2, channel_axis=channel_axis)
    assert np.allclose(dog, dog2)
162
163
@pytest.mark.parametrize("s", [1, (1, 2)])
def test_auto_sigma2(s):
    """When the second sigma is omitted it defaults to 1.6 * the first."""
    # Seeded generator instead of global np.random so failures reproduce.
    rng = np.random.default_rng(0)
    image = rng.random((10, 10))
    im1 = gaussian(image, s, preserve_range=True)
    s2 = 1.6 * np.array(s)
    im2 = gaussian(image, s2, preserve_range=True)
    dog = im1 - im2
    # Do NOT pass s2: the point of this test is the automatic default.
    # Passing it explicitly would only repeat test_difference_of_gaussians.
    dog2 = difference_of_gaussians(image, s)
    assert np.allclose(dog, dog2)
173
174
def test_dog_invalid_sigma_dims():
    """Sigma sequences whose length mismatches the spatial rank are rejected."""
    image = np.ones((5, 5, 3))

    # No channel axis: the (5, 5, 3) image is filtered as 3D, so two-element
    # sigmas (first or second argument) are expected to fail.
    for sigma_args in (((1, 2),), (1, (3, 4))):
        with pytest.raises(ValueError):
            difference_of_gaussians(image, *sigma_args)

    # Channel axis set (via the deprecated flag or channel_axis): only two
    # spatial axes remain, so a three-element sigma is expected to fail.
    with pytest.raises(ValueError):
        with expected_warnings(["`multichannel` is a deprecated argument"]):
            difference_of_gaussians(image, (1, 2, 3), multichannel=True)
    with pytest.raises(ValueError):
        difference_of_gaussians(image, (1, 2, 3), channel_axis=-1)
186
187
def test_dog_invalid_sigma2():
    """A second sigma smaller than the first (on any axis) raises ValueError."""
    image = np.ones((3, 3))
    for low, high in ((3, 2), ((1, 5), (2, 4))):
        with pytest.raises(ValueError):
            difference_of_gaussians(image, low, high)
194