"""Tests for ``skimage.filters.gaussian`` and ``difference_of_gaussians``."""
import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_equal

from skimage._shared._warnings import expected_warnings
from skimage._shared.utils import _supported_float_type
from skimage.filters import difference_of_gaussians, gaussian


def test_negative_sigma():
    """A negative sigma, scalar or per-axis, raises ValueError."""
    a = np.zeros((3, 3))
    a[1, 1] = 1.
    with pytest.raises(ValueError):
        gaussian(a, sigma=-1.0)
    with pytest.raises(ValueError):
        gaussian(a, sigma=[-1.0, 1.0])
    with pytest.raises(ValueError):
        gaussian(a, sigma=np.asarray([-1.0, 1.0]))


def test_null_sigma():
    """sigma=0 returns the image unchanged (with preserve_range=True)."""
    a = np.zeros((3, 3))
    a[1, 1] = 1.
    assert np.all(gaussian(a, 0, preserve_range=True) == a)


def test_default_sigma():
    """Omitting sigma gives the same result as sigma=1."""
    a = np.zeros((3, 3))
    a[1, 1] = 1.
    assert_array_equal(
        gaussian(a, preserve_range=True),
        gaussian(a, preserve_range=True, sigma=1)
    )


@pytest.mark.parametrize(
    'dtype', [np.uint8, np.int32, np.float16, np.float32, np.float64]
)
def test_image_dtype(dtype):
    """The output dtype is the supported float type for the input dtype."""
    a = np.zeros((3, 3), dtype=dtype)
    assert gaussian(a).dtype == _supported_float_type(a.dtype)


def test_energy_decrease():
    """Smoothing a single bright pixel lowers the image's std."""
    a = np.zeros((3, 3))
    a[1, 1] = 1.
    gaussian_a = gaussian(a, preserve_range=True, sigma=1, mode='reflect')
    assert gaussian_a.std() < a.std()


@pytest.mark.parametrize('channel_axis', [0, 1, -1])
def test_multichannel(channel_axis):
    """Channels are filtered independently for any channel_axis position."""
    a = np.zeros((5, 5, 3))
    a[1, 1] = np.arange(1, 4)
    a = np.moveaxis(a, -1, channel_axis)
    gaussian_rgb_a = gaussian(a, sigma=1, mode='reflect', preserve_range=True,
                              channel_axis=channel_axis)
    # Check that the mean value is conserved in each channel
    # (color channels are not mixed together)
    spatial_axes = tuple(
        [ax for ax in range(a.ndim) if ax != channel_axis % a.ndim]
    )
    assert np.allclose(a.mean(axis=spatial_axes),
                       gaussian_rgb_a.mean(axis=spatial_axes))

    if channel_axis % a.ndim == 2:
        # Test legacy behavior equivalent to old (multichannel = None)
        with expected_warnings(['multichannel']):
            gaussian_rgb_a = gaussian(a, sigma=1, mode='reflect',
                                      preserve_range=True)

        # Check that the mean value is conserved in each channel
        # (color channels are not mixed together)
        assert np.allclose(a.mean(axis=spatial_axes),
                           gaussian_rgb_a.mean(axis=spatial_axes))
    # Iterable sigma: one value per spatial axis (the channel axis excluded)
    gaussian_rgb_a = gaussian(a, sigma=[1, 2], mode='reflect',
                              channel_axis=channel_axis,
                              preserve_range=True)
    assert np.allclose(a.mean(axis=spatial_axes),
                       gaussian_rgb_a.mean(axis=spatial_axes))


def test_deprecated_multichannel():
    """The deprecated multichannel argument warns and still keeps channels
    separate, whether passed by keyword or positionally."""
    a = np.zeros((5, 5, 3))
    a[1, 1] = np.arange(1, 4)
    with expected_warnings(["`multichannel` is a deprecated argument"]):
        gaussian_rgb_a = gaussian(a, sigma=1, mode='reflect',
                                  multichannel=True)
    # Check that the mean value is conserved in each channel
    # (color channels are not mixed together)
    assert np.allclose(a.mean(axis=(0, 1)), gaussian_rgb_a.mean(axis=(0, 1)))

    # check positional multichannel argument warning
    with expected_warnings(["Providing the `multichannel` argument"]):
        gaussian_rgb_a = gaussian(a, 1, None, 'reflect', 0, True)
    # The positional form must behave like the keyword form above.
    assert np.allclose(a.mean(axis=(0, 1)), gaussian_rgb_a.mean(axis=(0, 1)))


def test_preserve_range():
    """Test preserve_range parameter."""
    ones = np.ones((2, 2), dtype=np.int64)
    # Without preserve_range the int64 input is rescaled before filtering,
    # so a constant image of ones becomes a tiny constant.
    filtered_ones = gaussian(ones, preserve_range=False)
    assert np.all(filtered_ones == filtered_ones[0, 0])
    assert filtered_ones[0, 0] < 1e-10

    filtered_preserved = gaussian(ones, preserve_range=True)
    assert np.all(filtered_preserved == 1.)

    # A float image with values outside [0, 1] (including negatives) must be
    # filtered without error and without being clipped into [0, 1].
    img = np.array([[10.0, -10.0], [-4, 3]], dtype=np.float32)
    filtered_img = gaussian(img, 1)
    assert filtered_img.shape == img.shape
    assert np.all(np.isfinite(filtered_img))
    assert filtered_img.min() < 0


def test_1d_ok():
    """Testing Gaussian Filter for 1D array.
    With any array consisting of positive integers and only one zero - it
    should filter all values to be greater than 0.1
    """
    nums = np.arange(7)
    filtered = gaussian(nums, preserve_range=True)
    assert np.all(filtered > 0.1)


def test_4d_ok():
    """A 4D delta image keeps unit total mass after reflect-mode smoothing."""
    img = np.zeros((5,) * 4)
    img[2, 2, 2, 2] = 1
    res = gaussian(img, 1, mode='reflect', preserve_range=True)
    assert np.allclose(res.sum(), 1)


@pytest.mark.parametrize(
    "dtype", [np.float32, np.float64]
)
def test_preserve_output(dtype):
    """A float ``output`` array of matching dtype is returned as-is."""
    image = np.arange(9, dtype=dtype).reshape((3, 3))
    output = np.zeros_like(image, dtype=dtype)
    gaussian_image = gaussian(image, sigma=1, output=output,
                              preserve_range=True)
    assert gaussian_image is output


def test_output_error():
    """An integer ``output`` array is rejected with ValueError."""
    image = np.arange(9, dtype=np.float32).reshape((3, 3))
    output = np.zeros_like(image, dtype=np.uint8)
    with pytest.raises(ValueError):
        gaussian(image, sigma=1, output=output,
                 preserve_range=True)


@pytest.mark.parametrize("s", [1, (2, 3)])
@pytest.mark.parametrize("s2", [4, (5, 6)])
@pytest.mark.parametrize("channel_axis", [None, 0, 1, -1])
def test_difference_of_gaussians(s, s2, channel_axis):
    """difference_of_gaussians equals gaussian(s) - gaussian(s2)."""
    # Seeded generator so any failure is reproducible.
    rng = np.random.default_rng(0)
    image = rng.random((10, 10))
    if channel_axis is not None:
        n_channels = 5
        image = np.stack((image,) * n_channels, channel_axis)
    im1 = gaussian(image, s, preserve_range=True, channel_axis=channel_axis)
    im2 = gaussian(image, s2, preserve_range=True, channel_axis=channel_axis)
    dog = im1 - im2
    dog2 = difference_of_gaussians(image, s, s2, channel_axis=channel_axis)
    assert np.allclose(dog, dog2)


@pytest.mark.parametrize("s", [1, (1, 2)])
def test_auto_sigma2(s):
    """The high sigma defaults to 1.6 times the low sigma."""
    # Seeded generator so any failure is reproducible.
    rng = np.random.default_rng(0)
    image = rng.random((10, 10))
    im1 = gaussian(image, s, preserve_range=True)
    s2 = 1.6 * np.array(s)
    im2 = gaussian(image, s2, preserve_range=True)
    dog = im1 - im2
    dog2 = difference_of_gaussians(image, s, s2)
    assert np.allclose(dog, dog2)


def test_dog_invalid_sigma_dims():
    """Per-axis sigmas must match the number of spatial dimensions."""
    image = np.ones((5, 5, 3))
    with pytest.raises(ValueError):
        difference_of_gaussians(image, (1, 2))
    with pytest.raises(ValueError):
        difference_of_gaussians(image, 1, (3, 4))
    with pytest.raises(ValueError):
        with expected_warnings(["`multichannel` is a deprecated argument"]):
            difference_of_gaussians(image, (1, 2, 3), multichannel=True)
    with pytest.raises(ValueError):
        difference_of_gaussians(image, (1, 2, 3), channel_axis=-1)


def test_dog_invalid_sigma2():
    """high_sigma smaller than low_sigma on any axis raises ValueError."""
    image = np.ones((3, 3))
    with pytest.raises(ValueError):
        difference_of_gaussians(image, 3, 2)
    with pytest.raises(ValueError):
        difference_of_gaussians(image, (1, 5), (2, 4))