1"""Test for correctness of color distance functions"""
2
3import numpy as np
4import pytest
5from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
6
7from skimage._shared.testing import fetch
8from skimage._shared.utils import _supported_float_type
9from skimage.color.delta_e import (deltaE_cie76, deltaE_ciede94,
10                                   deltaE_ciede2000, deltaE_cmc)
11
12
@pytest.mark.parametrize("channel_axis", [0, 1, -1])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_ciede2000_dE(dtype, channel_axis):
    """Compare deltaE_ciede2000 with the reference ``dE`` column.

    Both color sets are stored with the channel dimension placed at
    ``channel_axis``, and the result dtype must follow the input dtype.
    """
    ref = load_ciede2000_data()
    # One row per color, channels (L, a, b) last, cast to the dtype under test.
    first = np.column_stack([ref['L1'], ref['a1'], ref['b1']]).astype(dtype)
    second = np.column_stack([ref['L2'], ref['a2'], ref['b2']]).astype(dtype)

    first, second = (
        np.moveaxis(arr, source=-1, destination=channel_axis)
        for arr in (first, second)
    )
    result = deltaE_ciede2000(first, second, channel_axis=channel_axis)
    assert result.dtype == _supported_float_type(dtype)

    # float32 inputs are checked with a looser relative tolerance.
    if dtype == np.float32:
        tolerance = 1e-2
    else:
        tolerance = 1e-4
    assert_allclose(result, ref['dE'], rtol=tolerance)
35
36
def load_ciede2000_data():
    """Load the CIEDE2000 reference table used by the tests in this module.

    Returns
    -------
    data : structured ndarray
        One record per row of ``ciede2000_test_data.txt``. Each record holds
        the two input colors (``L1``, ``a1``, ``b1`` and ``L2``, ``a2``,
        ``b2``) and the expected distance ``dE``. It also holds intermediate
        quantities of the CIEDE2000 computation (``a1_prime``, ``C1_prime``,
        ``h1_prime``, ``G``, ``T``, ``SL``, ...).
    """
    dtype = [('pair', int),
             ('1', int),
             ('L1', float),
             ('a1', float),
             ('b1', float),
             ('a1_prime', float),
             ('C1_prime', float),
             ('h1_prime', float),
             ('hbar_prime', float),
             ('G', float),
             ('T', float),
             ('SL', float),
             ('SC', float),
             ('SH', float),
             ('RT', float),
             ('dE', float),
             ('2', int),
             ('L2', float),
             ('a2', float),
             ('b2', float),
             ('a2_prime', float),
             ('C2_prime', float),
             ('h2_prime', float),
             ]

    # note: ciede2000_test_data.txt contains several intermediate quantities.
    # Every column is listed above, in file order, so that ``np.loadtxt`` can
    # map each whitespace-separated value onto a named field.
    path = fetch('color/tests/ciede2000_test_data.txt')
    return np.loadtxt(path, dtype=dtype)
66
67
@pytest.mark.parametrize("channel_axis", [0, 1, -1])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_cie76(dtype, channel_axis):
    """Check deltaE_cie76 against precomputed values for the reference pairs.

    Inputs are laid out with channels at ``channel_axis``; the output dtype
    must follow the input dtype.
    """
    ref = load_ciede2000_data()
    # One row per color, channels (L, a, b) last, cast to the dtype under test.
    first = np.column_stack([ref['L1'], ref['a1'], ref['b1']]).astype(dtype)
    second = np.column_stack([ref['L2'], ref['a2'], ref['b2']]).astype(dtype)

    first, second = (
        np.moveaxis(arr, source=-1, destination=channel_axis)
        for arr in (first, second)
    )
    result = deltaE_cie76(first, second, channel_axis=channel_axis)
    assert result.dtype == _supported_float_type(dtype)

    expected = np.array([
        4.00106328, 6.31415011, 9.1776999, 2.06270077, 2.36957073,
        2.91529271, 2.23606798, 2.23606798, 4.98000036, 4.9800004,
        4.98000044, 4.98000049, 4.98000036, 4.9800004, 4.98000044,
        3.53553391, 36.86800781, 31.91002977, 30.25309901, 27.40894015,
        0.89242934, 0.7972, 0.8583065, 0.82982507, 3.1819238,
        2.21334297, 1.53890382, 4.60630929, 6.58467989, 3.88641412,
        1.50514845, 2.3237848, 0.94413208, 1.31910843
    ])
    # float32 inputs are checked with a looser relative tolerance.
    if dtype == np.float32:
        tolerance = 1e-5
    else:
        tolerance = 1e-8
    assert_allclose(result, expected, rtol=tolerance)
98
99
@pytest.mark.parametrize("channel_axis", [0, 1, -1])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_ciede94(dtype, channel_axis):
    """Check deltaE_ciede94 against precomputed values for the reference pairs.

    Inputs are laid out with channels at ``channel_axis``; the output dtype
    must follow the input dtype.
    """
    ref = load_ciede2000_data()
    # One row per color, channels (L, a, b) last, cast to the dtype under test.
    first = np.column_stack([ref['L1'], ref['a1'], ref['b1']]).astype(dtype)
    second = np.column_stack([ref['L2'], ref['a2'], ref['b2']]).astype(dtype)

    first, second = (
        np.moveaxis(arr, source=-1, destination=channel_axis)
        for arr in (first, second)
    )
    result = deltaE_ciede94(first, second, channel_axis=channel_axis)
    assert result.dtype == _supported_float_type(dtype)

    expected = np.array([
        1.39503887, 1.93410055, 2.45433566, 0.68449187, 0.6695627,
        0.69194527, 2.23606798, 2.03163832, 4.80069441, 4.80069445,
        4.80069449, 4.80069453, 4.80069441, 4.80069445, 4.80069449,
        3.40774352, 34.6891632, 29.44137328, 27.91408781, 24.93766082,
        0.82213163, 0.71658427, 0.8048753, 0.75284394, 1.39099471,
        1.24808929, 1.29795787, 1.82045088, 2.55613309, 1.42491303,
        1.41945261, 2.3225685, 0.93853308, 1.30654464
    ])
    # float32 inputs are checked with a looser relative tolerance.
    if dtype == np.float32:
        tolerance = 1e-5
    else:
        tolerance = 1e-8
    assert_allclose(result, expected, rtol=tolerance)
130
131
@pytest.mark.parametrize("channel_axis", [0, 1, -1])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_cmc(dtype, channel_axis):
    """Check deltaE_cmc against reference values and on (near-)identical colors.

    Inputs are laid out with channels at ``channel_axis``; the output dtype
    must follow the input dtype. Identical and minimally different colors
    must give distances of ~0 (and never NaN).
    """
    data = load_ciede2000_data()
    N = len(data)
    lab1 = np.zeros((N, 3), dtype=dtype)
    lab1[:, 0] = data['L1']
    lab1[:, 1] = data['a1']
    lab1[:, 2] = data['b1']

    lab2 = np.zeros((N, 3), dtype=dtype)
    lab2[:, 0] = data['L2']
    lab2[:, 1] = data['a2']
    lab2[:, 2] = data['b2']

    lab1 = np.moveaxis(lab1, source=-1, destination=channel_axis)
    lab2 = np.moveaxis(lab2, source=-1, destination=channel_axis)
    dE2 = deltaE_cmc(lab1, lab2, channel_axis=channel_axis)
    assert dE2.dtype == _supported_float_type(dtype)
    oracle = np.array([
        1.73873611, 2.49660844, 3.30494501, 0.85735576, 0.88332927,
        0.97822692, 3.50480874, 2.87930032, 6.5783807, 6.57838075,
        6.5783808, 6.57838086, 6.67492321, 6.67492326, 6.67492331,
        4.66852997, 42.10875485, 39.45889064, 38.36005919, 33.93663807,
        1.14400168, 1.00600419, 1.11302547, 1.05335328, 1.42822951,
        1.2548143, 1.76838061, 2.02583367, 3.08695508, 1.74893533,
        1.90095165, 1.70258148, 1.80317207, 2.44934417
    ])
    rtol = 1e-5 if dtype == np.float32 else 1e-8
    assert_allclose(dE2, oracle, rtol=rtol)

    # Equal or close colors make `delta_e.get_dH2` function to return
    # negative values resulting in NaNs when passed to sqrt (see #1908
    # issue on Github):
    # Take an independent copy so that perturbing `lab2` below does not also
    # perturb `lab1` (plain assignment would alias the two arrays).
    lab1 = lab2.copy()
    expected = np.zeros_like(oracle)
    assert_almost_equal(
        deltaE_cmc(lab1, lab2, channel_axis=channel_axis), expected, decimal=6
    )

    # Move one lightness value by a single ULP of the working dtype so the two
    # colors really differ, but only minimally. Adding float64 machine epsilon
    # instead would be absorbed by rounding for magnitudes of a few units or
    # more, and always for float32 storage.
    lab2[0, 0] += np.spacing(lab2[0, 0])
    # One float32 ULP of a typical lightness gives a distance of order 1e-6,
    # so float32 gets one fewer decimal of agreement.
    decimal = 5 if dtype == np.float32 else 6
    assert_almost_equal(
        deltaE_cmc(lab1, lab2, channel_axis=channel_axis), expected,
        decimal=decimal
    )
176
177
def test_cmc_single_item():
    """deltaE_cmc on one color pair: identical gives 0, near-identical ~0."""
    # Single item case. `lab2` is an independent copy so that perturbing it
    # below leaves `lab1` unchanged; chained assignment would alias both names
    # to one array and the perturbed check would compare a color with itself.
    lab1 = np.array([0., 1.59607713, 0.87755709])
    lab2 = lab1.copy()
    assert_equal(deltaE_cmc(lab1, lab2), 0)

    # With L == 0, adding machine epsilon produces a genuinely different
    # (representable) value. The distance must stay ~0 and in particular must
    # not become NaN (see GitHub issue #1908).
    lab2[0] += np.finfo(float).eps
    assert_almost_equal(deltaE_cmc(lab1, lab2), 0, decimal=6)
185
186
def test_single_color_cie76():
    """A single pair of colors given as tuples yields the Euclidean distance."""
    lab1 = (0.5, 0.5, 0.5)
    lab2 = (0.4, 0.4, 0.4)
    dE = deltaE_cie76(lab1, lab2)
    # Each of the three channels differs by 0.1.
    assert_allclose(dE, np.sqrt(3 * 0.1 ** 2))
191
192
def test_single_color_ciede94():
    """A single pair of distinct colors given as tuples yields a valid distance."""
    lab1 = (0.5, 0.5, 0.5)
    lab2 = (0.4, 0.4, 0.4)
    dE = deltaE_ciede94(lab1, lab2)
    # The colors differ, so the distance must be a finite positive number.
    assert np.isfinite(dE)
    assert dE > 0
197
198
def test_single_color_ciede2000():
    """A single pair of distinct colors given as tuples yields a valid distance."""
    lab1 = (0.5, 0.5, 0.5)
    lab2 = (0.4, 0.4, 0.4)
    dE = deltaE_ciede2000(lab1, lab2)
    # The colors differ, so the distance must be a finite positive number.
    assert np.isfinite(dE)
    assert dE > 0
203
204
def test_single_color_cmc():
    """A single pair of distinct colors given as tuples yields a valid distance."""
    lab1 = (0.5, 0.5, 0.5)
    lab2 = (0.4, 0.4, 0.4)
    dE = deltaE_cmc(lab1, lab2)
    # The colors differ, so the distance must be a finite positive number.
    assert np.isfinite(dE)
    assert dE > 0
209