1import numpy as np
2import pytest
3
4from skimage._shared._warnings import expected_warnings
5from skimage._shared.utils import _supported_float_type
6from skimage.filters import unsharp_mask
7
8
@pytest.mark.parametrize("shape,multichannel",
                         [((29,), False),
                          ((40, 4), True),
                          ((32, 32), False),
                          ((29, 31, 3), True),
                          ((13, 17, 4, 8), False)])
@pytest.mark.parametrize("dtype", [np.uint8, np.int8,
                                   np.uint16, np.int16,
                                   np.uint32, np.int32,
                                   np.uint64, np.int64,
                                   np.float32, np.float64])
@pytest.mark.parametrize("radius", [0, 0.1, 2.0])
@pytest.mark.parametrize("amount", [0.0, 0.5, 2.0, -1.0])
@pytest.mark.parametrize("offset", [-1.0, 0.0, 1.0])
@pytest.mark.parametrize("preserve", [False, True])
def test_unsharp_masking_output_type_and_shape(
        radius, amount, shape, multichannel, dtype, offset, preserve):
    """Unsharp masking returns float output with the input's shape.

    Random data is shifted by ``offset``, scaled by 128 and cast to
    ``dtype``. Only the output dtype and shape are checked, so the exact
    values produced by the cast do not matter here.
    NOTE(review): for narrow or unsigned integer dtypes some scaled values
    are out of range for the cast; presumably acceptable given the weak
    assertions — confirm no cast warnings are promoted to errors.
    """
    scaled = (np.random.random(shape) + offset) * 128
    image = scaled.astype(dtype)
    is_float_input = dtype in [np.float32, np.float64]
    if is_float_input and not preserve:
        # Float input is rescaled into [-1, 1] when the range is not
        # preserved, matching the img_as_float convention for floats.
        image /= max(np.abs(image).max(), 1.0)
    axis = -1 if multichannel else None
    result = unsharp_mask(image, radius, amount, preserve_range=preserve,
                          channel_axis=axis)
    assert result.dtype in [np.float32, np.float64]
    assert result.shape == shape
35
36
@pytest.mark.parametrize("shape,multichannel",
                         [((32, 32), False),
                          ((15, 15, 2), True),
                          ((17, 19, 3), True)])
@pytest.mark.parametrize("radius", [(0.0, 0.0), (1.0, 1.0), (2.0, 1.5)])
@pytest.mark.parametrize("preserve", [False, True])
def test_unsharp_masking_with_different_radii(radius, shape,
                                              multichannel, preserve):
    """A per-axis ``radius`` tuple yields float output of unchanged shape."""
    image = np.random.random(shape) * 96
    image = image.astype(np.float64)
    if not preserve:
        # Bring float data into [-1, 1] when the range is not preserved.
        image /= max(np.abs(image).max(), 1.0)
    result = unsharp_mask(image, radius, 1.0, preserve_range=preserve,
                          channel_axis=-1 if multichannel else None)
    assert result.dtype in [np.float32, np.float64]
    assert result.shape == shape
55
56
@pytest.mark.parametrize("shape,channel_axis",
                         [((16, 16), None),
                          ((15, 15, 2), -1),
                          ((13, 17, 3), -1),
                          ((2, 15, 15), 0),
                          ((3, 13, 17), 0)])
@pytest.mark.parametrize("offset", [-5, 0, 5])
@pytest.mark.parametrize("preserve", [False, True])
def test_unsharp_masking_with_different_ranges(shape, offset, channel_axis,
                                               preserve):
    """Check output range, dtype and shape for int16 input.

    ``offset`` shifts the random values so that inputs both with and without
    negative entries are exercised. Without ``preserve_range`` every output
    value must lie in [-1, 1], and in [0, 1] when the input has no negative
    entries.
    """
    radius = 2.0
    amount = 1.0
    dtype = np.int16
    array = (np.random.random(shape) * 5 + offset).astype(dtype)
    # np.any returns np.bool_; an identity test such as ``is False`` never
    # matches it, so convert to a Python bool and test truthiness instead.
    negative = bool(np.any(array < 0))
    output = unsharp_mask(array, radius, amount, preserve_range=preserve,
                          channel_axis=channel_axis)
    if not preserve:
        # The range must hold for every element, not merely for one.
        assert np.all(output <= 1)
        assert np.all(output >= -1)
        if not negative:
            assert np.all(output >= 0)
    assert output.dtype in [np.float32, np.float64]
    assert output.shape == shape
81
82
@pytest.mark.parametrize("shape,multichannel",
                         [((16, 16), False),
                          ((15, 15, 2), True),
                          ((13, 17, 3), True)])
@pytest.mark.parametrize("offset", [-5, 0, 5])
@pytest.mark.parametrize("preserve", [False, True])
def test_unsharp_masking_with_different_ranges_deprecated(shape, offset,
                                                          multichannel,
                                                          preserve):
    """Same range checks as the channel_axis test, via ``multichannel``.

    Passing ``multichannel`` by keyword or positionally must emit a
    deprecation warning while still producing valid output.
    """
    radius = 2.0
    amount = 1.0
    dtype = np.int16
    array = (np.random.random(shape) * 5 + offset).astype(dtype)
    # np.any returns np.bool_; an identity test such as ``is False`` never
    # matches it, so convert to a Python bool and test truthiness instead.
    negative = bool(np.any(array < 0))
    with expected_warnings(["`multichannel` is a deprecated argument"]):
        output = unsharp_mask(array, radius, amount, multichannel=multichannel,
                              preserve_range=preserve)
    if not preserve:
        # The range must hold for every element, not merely for one.
        assert np.all(output <= 1)
        assert np.all(output >= -1)
        if not negative:
            assert np.all(output >= 0)
    assert output.dtype in [np.float32, np.float64]
    assert output.shape == shape

    # providing multichannel positionally also raises a warning
    with expected_warnings(["Providing the `multichannel`"]):
        output = unsharp_mask(array, radius, amount, multichannel, preserve)
    # The positional form must still return a result of the input shape.
    assert output.shape == shape
111
112
@pytest.mark.parametrize("shape,channel_axis",
                         [((16, 16), None),
                          ((15, 15, 2), -1),
                          ((13, 17, 3), -1)])
@pytest.mark.parametrize("preserve", [False, True])
@pytest.mark.parametrize("dtype", [np.uint8, np.float16, np.float32, np.float64])
def test_unsharp_masking_dtypes(shape, channel_axis, preserve, dtype):
    """Output dtype matches ``_supported_float_type`` of the input dtype.

    Without ``preserve_range`` the output values are also range-checked.
    """
    radius = 2.0
    amount = 1.0
    array = (np.random.random(shape) * 10).astype(dtype, copy=False)
    # np.any returns np.bool_; an identity test such as ``is False`` never
    # matches it, so convert to a Python bool and test truthiness instead.
    negative = bool(np.any(array < 0))
    output = unsharp_mask(array, radius, amount, preserve_range=preserve,
                          channel_axis=channel_axis)
    if not preserve:
        # The range must hold for every element, not merely for one.
        assert np.all(output <= 1)
        assert np.all(output >= -1)
        if not negative:
            assert np.all(output >= 0)
    assert output.dtype == _supported_float_type(dtype)
    assert output.shape == shape
133