1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 #include "lib/jxl/quant_weights.h"
6
7 #include <stdlib.h>
8
9 #include <algorithm>
10 #include <cmath>
11 #include <hwy/base.h> // HWY_ALIGN_MAX
12 #include <hwy/tests/test_util-inl.h>
13 #include <numeric>
14 #include <random>
15
16 #include "lib/jxl/dct_for_test.h"
17 #include "lib/jxl/dec_transforms_testonly.h"
18 #include "lib/jxl/enc_modular.h"
19 #include "lib/jxl/enc_quant_weights.h"
20 #include "lib/jxl/enc_transforms.h"
21
22 namespace jxl {
23 namespace {
24
// Primary template: the roundtripped value must equal the original exactly.
// (float has its own, tolerance-based specialization.)
template <class T>
void CheckSimilar(T a, T b) {
  EXPECT_EQ(a, b);
}
29 // minimum exponent = -15.
30 template <>
CheckSimilar(float a,float b)31 void CheckSimilar(float a, float b) {
32 float m = std::max(std::abs(a), std::abs(b));
33 // 10 bits of precision are used in the format. Relative error should be
34 // below 2^-10.
35 EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b;
36 }
37
// Installs custom per-channel DC quantization factors and checks that
// InvDCQuant() reports each of them back (within CheckSimilar's tolerance).
TEST(QuantWeightsTest, DC) {
  DequantMatrices matrices;
  float expected_dc[3] = {1e+5, 1e+3, 1e+1};
  DequantMatricesSetCustomDC(&matrices, expected_dc);

  size_t channel = 0;
  for (const float expected : expected_dc) {
    CheckSimilar(matrices.InvDCQuant(channel), expected);
    ++channel;
  }
}
46
RoundtripMatrices(const std::vector<QuantEncoding> & encodings)47 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) {
48 ASSERT_TRUE(encodings.size() == DequantMatrices::kNum);
49 DequantMatrices mat;
50 CodecMetadata metadata;
51 FrameHeader frame_header(&metadata);
52 ModularFrameEncoder encoder(frame_header, CompressParams{});
53 DequantMatricesSetCustom(&mat, encodings, &encoder);
54 const std::vector<QuantEncoding>& encodings_dec = mat.encodings();
55 for (size_t i = 0; i < encodings.size(); i++) {
56 const QuantEncoding& e = encodings[i];
57 const QuantEncoding& d = encodings_dec[i];
58 // Check values roundtripped correctly.
59 EXPECT_EQ(e.mode, d.mode);
60 EXPECT_EQ(e.predefined, d.predefined);
61 EXPECT_EQ(e.source, d.source);
62
63 EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands),
64 static_cast<uint64_t>(d.dct_params.num_distance_bands));
65 for (size_t c = 0; c < 3; c++) {
66 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
67 CheckSimilar(e.dct_params.distance_bands[c][j],
68 d.dct_params.distance_bands[c][j]);
69 }
70 }
71
72 if (e.mode == QuantEncoding::kQuantModeRAW) {
73 EXPECT_FALSE(!e.qraw.qtable);
74 EXPECT_FALSE(!d.qraw.qtable);
75 EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size());
76 for (size_t j = 0; j < e.qraw.qtable->size(); j++) {
77 EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]);
78 }
79 EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f);
80 } else {
81 // modes different than kQuantModeRAW use one of the other fields used
82 // here, which all happen to be arrays of floats.
83 for (size_t c = 0; c < 3; c++) {
84 for (size_t j = 0; j < 3; j++) {
85 CheckSimilar(e.idweights[c][j], d.idweights[c][j]);
86 }
87 for (size_t j = 0; j < 6; j++) {
88 CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]);
89 }
90 for (size_t j = 0; j < 2; j++) {
91 CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]);
92 }
93 CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]);
94 for (size_t j = 0; j < 9; j++) {
95 CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]);
96 }
97 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
98 CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j],
99 d.dct_params_afv_4x4.distance_bands[c][j]);
100 }
101 }
102 }
103 }
104 }
105
// Roundtrips a configuration in which every quant table uses library
// encoding 0.
TEST(QuantWeightsTest, AllDefault) {
  const std::vector<QuantEncoding> all_library0(DequantMatrices::kNum,
                                                QuantEncoding::Library(0));
  RoundtripMatrices(all_library0);
}
111
TestSingleQuantMatrix(DequantMatrices::QuantTable kind)112 void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) {
113 std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
114 QuantEncoding::Library(0));
115 encodings[kind] = DequantMatrices::Library()[kind];
116 RoundtripMatrices(encodings);
117 }
118
// Ensure we can reasonably represent the default (library) quant tables.
// Each test below is named QuantWeightsTest.<table> and roundtrips that one
// table's library encoding via TestSingleQuantMatrix. The macro only stamps
// out the repetitive TEST definitions and is undefined right afterwards.
#define JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(table) \
  TEST(QuantWeightsTest, table) {                  \
    TestSingleQuantMatrix(DequantMatrices::table); \
  }

JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(IDENTITY)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT2X2)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT4X4)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT16X16)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT32X32)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT8X16)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT8X32)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT16X32)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(DCT4X8)
JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST(AFV0)

#undef JXL_QUANT_WEIGHTS_SINGLE_TABLE_TEST
// Roundtrips an explicit (RAW) table for the quant table used by DCT32X32:
// 3 * 32 * 32 deterministic pseudo-random entries in [1, 255] (default-seeded
// mt19937), passed to QuantEncoding::RAW with a second argument of 2. All
// other tables use library encoding 0.
TEST(QuantWeightsTest, RAW) {
  std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
                                       QuantEncoding::Library(0));

  std::mt19937 generator;
  std::uniform_int_distribution<size_t> entry_dist(1, 255);
  std::vector<int> table(3 * 32 * 32);
  // std::generate fills the range front to back, so entries are drawn in the
  // same order as with an explicit index loop.
  std::generate(table.begin(), table.end(), [&generator, &entry_dist]() {
    return static_cast<int>(entry_dist(generator));
  });

  const auto table_index = DequantMatrices::kQuantTable[AcStrategy::DCT32X32];
  encodings[table_index] = QuantEncoding::RAW(table, 2);
  RoundtripMatrices(encodings);
}
160
161 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {};
162 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest);
163
TEST_P(QuantWeightsTargetTest,DCTUniform)164 TEST_P(QuantWeightsTargetTest, DCTUniform) {
165 constexpr float kUniformQuant = 4;
166 float weights[3][2] = {{1.0f / kUniformQuant, 0},
167 {1.0f / kUniformQuant, 0},
168 {1.0f / kUniformQuant, 0}};
169 DctQuantWeightParams dct_params(weights);
170 std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
171 QuantEncoding::DCT(dct_params));
172 DequantMatrices dequant_matrices;
173 CodecMetadata metadata;
174 FrameHeader frame_header(&metadata);
175 ModularFrameEncoder encoder(frame_header, CompressParams{});
176 DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder);
177
178 const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant,
179 1.0f / kUniformQuant};
180 DequantMatricesSetCustomDC(&dequant_matrices, dc_quant);
181
182 HWY_ALIGN_MAX float scratch_space[16 * 16 * 2];
183
184 // DCT8
185 {
186 HWY_ALIGN_MAX float pixels[64];
187 std::iota(std::begin(pixels), std::end(pixels), 0);
188 HWY_ALIGN_MAX float coeffs[64];
189 const AcStrategy::Type dct = AcStrategy::DCT;
190 TransformFromPixels(dct, pixels, 8, coeffs, scratch_space);
191 HWY_ALIGN_MAX double slow_coeffs[64];
192 for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i];
193 DCTSlow<8>(slow_coeffs);
194
195 for (size_t i = 0; i < 64; i++) {
196 // DCTSlow doesn't multiply/divide by 1/N, so we do it manually.
197 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
198 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
199 dequant_matrices.Matrix(dct, 0)[i];
200 }
201 IDCTSlow<8>(slow_coeffs);
202 TransformToPixels(dct, coeffs, pixels, 8, scratch_space);
203 for (size_t i = 0; i < 64; i++) {
204 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
205 }
206 }
207
208 // DCT16
209 {
210 HWY_ALIGN_MAX float pixels[64 * 4];
211 std::iota(std::begin(pixels), std::end(pixels), 0);
212 HWY_ALIGN_MAX float coeffs[64 * 4];
213 const AcStrategy::Type dct = AcStrategy::DCT16X16;
214 TransformFromPixels(dct, pixels, 16, coeffs, scratch_space);
215 HWY_ALIGN_MAX double slow_coeffs[64 * 4];
216 for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i];
217 DCTSlow<16>(slow_coeffs);
218
219 for (size_t i = 0; i < 64 * 4; i++) {
220 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
221 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
222 dequant_matrices.Matrix(dct, 0)[i];
223 }
224
225 IDCTSlow<16>(slow_coeffs);
226 TransformToPixels(dct, coeffs, pixels, 16, scratch_space);
227 for (size_t i = 0; i < 64 * 4; i++) {
228 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
229 }
230 }
231
232 // Check that all matrices have the same DC quantization, i.e. that they all
233 // have the same scaling.
234 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
235 EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6);
236 }
237 }
238
239 } // namespace
240 } // namespace jxl
241