1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 #include "lib/jxl/quant_weights.h"
6 
7 #include <stdlib.h>
8 
9 #include <algorithm>
10 #include <cmath>
11 #include <hwy/base.h>  // HWY_ALIGN_MAX
12 #include <hwy/tests/test_util-inl.h>
13 #include <numeric>
14 #include <random>
15 
16 #include "lib/jxl/dct_for_test.h"
17 #include "lib/jxl/dec_transforms_testonly.h"
18 #include "lib/jxl/enc_modular.h"
19 #include "lib/jxl/enc_quant_weights.h"
20 #include "lib/jxl/enc_transforms.h"
21 
22 namespace jxl {
23 namespace {
24 
// Default comparison used by the roundtrip checks: values of any type without
// a dedicated specialization must come back exactly equal.
template <typename T>
void CheckSimilar(T a, T b) {
  EXPECT_EQ(a, b);
}
29 // minimum exponent = -15.
30 template <>
CheckSimilar(float a,float b)31 void CheckSimilar(float a, float b) {
32   float m = std::max(std::abs(a), std::abs(b));
33   // 10 bits of precision are used in the format. Relative error should be
34   // below 2^-10.
35   EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b;
36 }
37 
TEST(QuantWeightsTest, DC) {
  // Custom DC quantizers spanning several orders of magnitude; InvDCQuant(c)
  // must report back the value set for channel c within CheckSimilar's
  // tolerance.
  const float dc_quant[3] = {1e+5f, 1e+3f, 1e+1f};
  DequantMatrices mat;
  DequantMatricesSetCustomDC(&mat, dc_quant);
  for (size_t channel = 0; channel < 3; ++channel) {
    CheckSimilar(mat.InvDCQuant(channel), dc_quant[channel]);
  }
}
46 
RoundtripMatrices(const std::vector<QuantEncoding> & encodings)47 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) {
48   ASSERT_TRUE(encodings.size() == DequantMatrices::kNum);
49   DequantMatrices mat;
50   CodecMetadata metadata;
51   FrameHeader frame_header(&metadata);
52   ModularFrameEncoder encoder(frame_header, CompressParams{});
53   DequantMatricesSetCustom(&mat, encodings, &encoder);
54   const std::vector<QuantEncoding>& encodings_dec = mat.encodings();
55   for (size_t i = 0; i < encodings.size(); i++) {
56     const QuantEncoding& e = encodings[i];
57     const QuantEncoding& d = encodings_dec[i];
58     // Check values roundtripped correctly.
59     EXPECT_EQ(e.mode, d.mode);
60     EXPECT_EQ(e.predefined, d.predefined);
61     EXPECT_EQ(e.source, d.source);
62 
63     EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands),
64               static_cast<uint64_t>(d.dct_params.num_distance_bands));
65     for (size_t c = 0; c < 3; c++) {
66       for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
67         CheckSimilar(e.dct_params.distance_bands[c][j],
68                      d.dct_params.distance_bands[c][j]);
69       }
70     }
71 
72     if (e.mode == QuantEncoding::kQuantModeRAW) {
73       EXPECT_FALSE(!e.qraw.qtable);
74       EXPECT_FALSE(!d.qraw.qtable);
75       EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size());
76       for (size_t j = 0; j < e.qraw.qtable->size(); j++) {
77         EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]);
78       }
79       EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f);
80     } else {
81       // modes different than kQuantModeRAW use one of the other fields used
82       // here, which all happen to be arrays of floats.
83       for (size_t c = 0; c < 3; c++) {
84         for (size_t j = 0; j < 3; j++) {
85           CheckSimilar(e.idweights[c][j], d.idweights[c][j]);
86         }
87         for (size_t j = 0; j < 6; j++) {
88           CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]);
89         }
90         for (size_t j = 0; j < 2; j++) {
91           CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]);
92         }
93         CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]);
94         for (size_t j = 0; j < 9; j++) {
95           CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]);
96         }
97         for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
98           CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j],
99                        d.dct_params_afv_4x4.distance_bands[c][j]);
100         }
101       }
102     }
103   }
104 }
105 
TEST(QuantWeightsTest, AllDefault) {
  // Every quant table uses QuantEncoding::Library(0).
  const std::vector<QuantEncoding> all_library(DequantMatrices::kNum,
                                               QuantEncoding::Library(0));
  RoundtripMatrices(all_library);
}
111 
TestSingleQuantMatrix(DequantMatrices::QuantTable kind)112 void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) {
113   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
114                                        QuantEncoding::Library(0));
115   encodings[kind] = DequantMatrices::Library()[kind];
116   RoundtripMatrices(encodings);
117 }
118 
// Each test below roundtrips the library encoding of a single quant table
// (all other tables use QuantEncoding::Library(0)), to make sure the default
// tables can be reasonably represented.
TEST(QuantWeightsTest, DCT) {
  TestSingleQuantMatrix(DequantMatrices::DCT);
}

TEST(QuantWeightsTest, IDENTITY) {
  TestSingleQuantMatrix(DequantMatrices::IDENTITY);
}

TEST(QuantWeightsTest, DCT2X2) {
  TestSingleQuantMatrix(DequantMatrices::DCT2X2);
}

TEST(QuantWeightsTest, DCT4X4) {
  TestSingleQuantMatrix(DequantMatrices::DCT4X4);
}

TEST(QuantWeightsTest, DCT16X16) {
  TestSingleQuantMatrix(DequantMatrices::DCT16X16);
}

TEST(QuantWeightsTest, DCT32X32) {
  TestSingleQuantMatrix(DequantMatrices::DCT32X32);
}

TEST(QuantWeightsTest, DCT8X16) {
  TestSingleQuantMatrix(DequantMatrices::DCT8X16);
}

TEST(QuantWeightsTest, DCT8X32) {
  TestSingleQuantMatrix(DequantMatrices::DCT8X32);
}

TEST(QuantWeightsTest, DCT16X32) {
  TestSingleQuantMatrix(DequantMatrices::DCT16X32);
}

TEST(QuantWeightsTest, DCT4X8) {
  TestSingleQuantMatrix(DequantMatrices::DCT4X8);
}

TEST(QuantWeightsTest, AFV0) {
  TestSingleQuantMatrix(DequantMatrices::AFV0);
}
TEST(QuantWeightsTest, RAW) {
  // Deterministic (default-seeded) random table with entries in [1, 255],
  // holding 3 * 32 * 32 values.
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> dist(1, 255);
  std::vector<int> matrix(3 * 32 * 32);
  std::generate(matrix.begin(), matrix.end(),
                [&dist, &rng] { return static_cast<int>(dist(rng)); });

  // All tables default to library encoding 0, except the one used by
  // DCT32X32, which gets the RAW table.
  std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
                                       QuantEncoding::Library(0));
  encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] =
      QuantEncoding::RAW(matrix, 2);
  RoundtripMatrices(encodings);
}
160 
161 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {};
162 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest);
163 
TEST_P(QuantWeightsTargetTest,DCTUniform)164 TEST_P(QuantWeightsTargetTest, DCTUniform) {
165   constexpr float kUniformQuant = 4;
166   float weights[3][2] = {{1.0f / kUniformQuant, 0},
167                          {1.0f / kUniformQuant, 0},
168                          {1.0f / kUniformQuant, 0}};
169   DctQuantWeightParams dct_params(weights);
170   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
171                                        QuantEncoding::DCT(dct_params));
172   DequantMatrices dequant_matrices;
173   CodecMetadata metadata;
174   FrameHeader frame_header(&metadata);
175   ModularFrameEncoder encoder(frame_header, CompressParams{});
176   DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder);
177 
178   const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant,
179                              1.0f / kUniformQuant};
180   DequantMatricesSetCustomDC(&dequant_matrices, dc_quant);
181 
182   HWY_ALIGN_MAX float scratch_space[16 * 16 * 2];
183 
184   // DCT8
185   {
186     HWY_ALIGN_MAX float pixels[64];
187     std::iota(std::begin(pixels), std::end(pixels), 0);
188     HWY_ALIGN_MAX float coeffs[64];
189     const AcStrategy::Type dct = AcStrategy::DCT;
190     TransformFromPixels(dct, pixels, 8, coeffs, scratch_space);
191     HWY_ALIGN_MAX double slow_coeffs[64];
192     for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i];
193     DCTSlow<8>(slow_coeffs);
194 
195     for (size_t i = 0; i < 64; i++) {
196       // DCTSlow doesn't multiply/divide by 1/N, so we do it manually.
197       slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
198       coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
199                   dequant_matrices.Matrix(dct, 0)[i];
200     }
201     IDCTSlow<8>(slow_coeffs);
202     TransformToPixels(dct, coeffs, pixels, 8, scratch_space);
203     for (size_t i = 0; i < 64; i++) {
204       EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
205     }
206   }
207 
208   // DCT16
209   {
210     HWY_ALIGN_MAX float pixels[64 * 4];
211     std::iota(std::begin(pixels), std::end(pixels), 0);
212     HWY_ALIGN_MAX float coeffs[64 * 4];
213     const AcStrategy::Type dct = AcStrategy::DCT16X16;
214     TransformFromPixels(dct, pixels, 16, coeffs, scratch_space);
215     HWY_ALIGN_MAX double slow_coeffs[64 * 4];
216     for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i];
217     DCTSlow<16>(slow_coeffs);
218 
219     for (size_t i = 0; i < 64 * 4; i++) {
220       slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
221       coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
222                   dequant_matrices.Matrix(dct, 0)[i];
223     }
224 
225     IDCTSlow<16>(slow_coeffs);
226     TransformToPixels(dct, coeffs, pixels, 16, scratch_space);
227     for (size_t i = 0; i < 64 * 4; i++) {
228       EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
229     }
230   }
231 
232   // Check that all matrices have the same DC quantization, i.e. that they all
233   // have the same scaling.
234   for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
235     EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6);
236   }
237 }
238 
239 }  // namespace
240 }  // namespace jxl
241