1from functools import partial
2
3import numpy as np
4
5from skimage import img_as_float, img_as_uint
6from skimage import color, data, filters
7from skimage.color.adapt_rgb import adapt_rgb, each_channel, hsv_value
8
# Shared inputs: the sample images are sub-sampled so the tests run quickly.
GRAY_IMAGE = data.camera()[::5, ::5]
COLOR_IMAGE = data.astronaut()[::5, ::6]

# Gaussian width used by the smoothing tests, plus a reference smoothing
# filter that already has this width bound.
SIGMA = 3
smooth = partial(filters.gaussian, sigma=SIGMA)

# Array comparison with a small default absolute tolerance; individual
# tests may still override `atol`/`rtol` per call.
assert_allclose = partial(np.testing.assert_allclose, atol=1e-8)
16
17
@adapt_rgb(each_channel)
def edges_each(image):
    """Sobel edge filter, adapted for RGB input with `each_channel`."""
    edge_map = filters.sobel(image)
    return edge_map
21
22
@adapt_rgb(each_channel)
def smooth_each(image, sigma):
    """Gaussian filter of width `sigma`, adapted with `each_channel`."""
    blurred = filters.gaussian(image, sigma=sigma)
    return blurred
26
27
@adapt_rgb(each_channel)
def mask_each(image, mask):
    """Return a copy of `image` in which the entries selected by `mask` are 0.

    Adapted with `each_channel`; `mask` is an extra, array-valued filter
    argument.
    """
    # Work on a copy so the caller's array is never modified.
    masked = image.copy()
    masked[mask] = 0
    return masked
33
34
@adapt_rgb(hsv_value)
def edges_hsv(image):
    """Sobel edge filter, adapted for RGB input with `hsv_value`."""
    edge_map = filters.sobel(image)
    return edge_map
38
39
@adapt_rgb(hsv_value)
def smooth_hsv(image, sigma):
    """Gaussian filter of width `sigma`, adapted with `hsv_value`."""
    blurred = filters.gaussian(image, sigma=sigma)
    return blurred
43
44
@adapt_rgb(hsv_value)
def edges_hsv_uint(image):
    """Sobel edge filter whose output is converted with `img_as_uint`.

    Adapted with `hsv_value`; used to exercise a filter that does not
    return a float image.
    """
    edge_map = filters.sobel(image)
    return img_as_uint(edge_map)
48
49
def test_gray_scale_image():
    """A gray-scale input gives the same result as the undecorated filter."""
    # One handler is enough here: gray-scale inputs are dealt with by
    # `adapt_rgb` itself, so `hsv_value` and `each_channel` need not both
    # be exercised.
    actual = edges_each(GRAY_IMAGE)
    expected = filters.sobel(GRAY_IMAGE)
    assert_allclose(actual, expected)
54
55
def test_each_channel():
    """Each output channel equals the Sobel filter of that input channel."""
    filtered = edges_each(COLOR_IMAGE)
    # Bring the channel axis to the front so iteration yields one 2D
    # channel at a time.
    channels = np.moveaxis(filtered, -1, 0)
    for idx, actual in enumerate(channels):
        reference = filters.sobel(COLOR_IMAGE[:, :, idx])
        assert_allclose(actual, img_as_float(reference))
61
62
def test_each_channel_with_filter_argument():
    """An extra filter argument (`sigma`) is applied to every channel."""
    filtered = smooth_each(COLOR_IMAGE, SIGMA)
    # Bring the channel axis to the front so iteration yields one 2D
    # channel at a time.
    channels = np.moveaxis(filtered, -1, 0)
    for idx, actual in enumerate(channels):
        expected = smooth(COLOR_IMAGE[:, :, idx])
        assert_allclose(actual, expected)
67
68
def test_each_channel_with_asymmetric_kernel():
    """An array-valued filter argument is passed through `each_channel`.

    An upper-triangular boolean mask is handed to `mask_each` as an extra
    argument.  Besides checking that the call does not raise, verify that
    the masked pixels are zeroed in every channel and that all other
    pixels are left unchanged.
    """
    mask = np.triu(np.ones(COLOR_IMAGE.shape[:2], dtype=bool))
    result = mask_each(COLOR_IMAGE, mask)

    assert result.shape == COLOR_IMAGE.shape
    # Indexing with the 2D mask selects whole pixels, i.e. all channels.
    assert not result[mask].any()
    np.testing.assert_array_equal(result[~mask], COLOR_IMAGE[~mask])
72
73
def test_hsv_value():
    """The output's HSV value channel equals the Sobel-filtered input value."""
    filtered = edges_hsv(COLOR_IMAGE)
    input_value = color.rgb2hsv(COLOR_IMAGE)[..., 2]
    output_value = color.rgb2hsv(filtered)[..., 2]
    assert_allclose(output_value, filters.sobel(input_value))
78
79
def test_hsv_value_with_filter_argument():
    """An extra filter argument (`sigma`) is applied to the value channel."""
    filtered = smooth_hsv(COLOR_IMAGE, SIGMA)
    input_value = color.rgb2hsv(COLOR_IMAGE)[..., 2]
    output_value = color.rgb2hsv(filtered)[..., 2]
    assert_allclose(output_value, smooth(input_value))
84
85
def test_hsv_value_with_non_float_output():
    """An `hsv_value` filter that returns a non-float image still works."""
    # `rgb2hsv` returns a float image and the filter output is inserted
    # into that HSV image, so make sure a uint-returning filter does not
    # cause a dtype mismatch.
    filtered = edges_hsv_uint(COLOR_IMAGE)
    output_value = color.rgb2hsv(filtered)[..., 2]
    input_value = color.rgb2hsv(COLOR_IMAGE)[..., 2]
    expected = filters.sobel(input_value)
    # Looser tolerances than the module default, to absorb the error
    # introduced by the dtype conversion.
    assert_allclose(output_value, expected, rtol=1e-5, atol=1e-5)
95