1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include <stddef.h>
7 #include <stdint.h>
8
9 #include <vector>
10
11 #include "gtest/gtest.h"
12 #include "lib/jxl/ans_params.h"
13 #include "lib/jxl/aux_out_fwd.h"
14 #include "lib/jxl/base/random.h"
15 #include "lib/jxl/base/span.h"
16 #include "lib/jxl/dec_ans.h"
17 #include "lib/jxl/dec_bit_reader.h"
18 #include "lib/jxl/enc_ans.h"
19 #include "lib/jxl/enc_bit_writer.h"
20
21 namespace jxl {
22 namespace {
23
RoundtripTestcase(int n_histograms,int alphabet_size,const std::vector<Token> & input_values)24 void RoundtripTestcase(int n_histograms, int alphabet_size,
25 const std::vector<Token>& input_values) {
26 constexpr uint16_t kMagic1 = 0x9e33;
27 constexpr uint16_t kMagic2 = 0x8b04;
28
29 BitWriter writer;
30 // Space for magic bytes.
31 BitWriter::Allotment allotment_magic1(&writer, 16);
32 writer.Write(16, kMagic1);
33 ReclaimAndCharge(&writer, &allotment_magic1, 0, nullptr);
34
35 std::vector<uint8_t> context_map;
36 EntropyEncodingData codes;
37 std::vector<std::vector<Token>> input_values_vec;
38 input_values_vec.push_back(input_values);
39
40 BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec,
41 &codes, &context_map, &writer, 0, nullptr);
42 WriteTokens(input_values_vec[0], codes, context_map, &writer, 0, nullptr);
43
44 // Magic bytes + padding
45 BitWriter::Allotment allotment_magic2(&writer, 24);
46 writer.Write(16, kMagic2);
47 writer.ZeroPadToByte();
48 ReclaimAndCharge(&writer, &allotment_magic2, 0, nullptr);
49
50 // We do not truncate the output. Reading past the end reads out zeroes
51 // anyway.
52 BitReader br(writer.GetSpan());
53
54 ASSERT_EQ(br.ReadBits(16), kMagic1);
55
56 std::vector<uint8_t> dec_context_map;
57 ANSCode decoded_codes;
58 ASSERT_TRUE(
59 DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map));
60 ASSERT_EQ(dec_context_map, context_map);
61 ANSSymbolReader reader(&decoded_codes, &br);
62
63 for (const Token& symbol : input_values) {
64 uint32_t read_symbol =
65 reader.ReadHybridUint(symbol.context, &br, dec_context_map);
66 ASSERT_EQ(read_symbol, symbol.value);
67 }
68 ASSERT_TRUE(reader.CheckANSFinalState());
69
70 ASSERT_EQ(br.ReadBits(16), kMagic2);
71 EXPECT_TRUE(br.Close());
72 }
73
// Round-trips a stream that contains no tokens at all.
TEST(ANSTest, EmptyRoundtrip) {
  const std::vector<Token> no_tokens;
  RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, no_tokens);
}
77
// Round-trips streams made of a single distinct value in context 0: first as
// a one-token stream, then as 1024 repetitions of that token, for every value
// below ANS_MAX_ALPHABET_SIZE.
TEST(ANSTest, SingleSymbolRoundtrip) {
  for (uint32_t value = 0; value < ANS_MAX_ALPHABET_SIZE; ++value) {
    const std::vector<Token> single(1, Token(0, value));
    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, single);
  }
  for (uint32_t value = 0; value < ANS_MAX_ALPHABET_SIZE; ++value) {
    const std::vector<Token> repeated(1024, Token(0, value));
    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, repeated);
  }
}
87
// Number of repetitions used by the random-stream tests: 10 by default, 3 when
// any of the sanitizer macros is defined (presumably to keep the runtime of
// instrumented builds manageable).
#if !defined(ADDRESS_SANITIZER) && !defined(MEMORY_SANITIZER) && \
    !defined(THREAD_SANITIZER)
constexpr size_t kReps = 10;
#else
constexpr size_t kReps = 3;
#endif
94
RoundtripRandomStream(int alphabet_size,size_t reps=kReps,size_t num=1<<18)95 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps,
96 size_t num = 1 << 18) {
97 constexpr int kNumHistograms = 3;
98 Rng rng(0);
99 for (size_t i = 0; i < reps; i++) {
100 std::vector<Token> symbols;
101 for (size_t j = 0; j < num; j++) {
102 int context = rng.UniformI(0, kNumHistograms);
103 int value = rng.UniformU(0, alphabet_size);
104 symbols.emplace_back(context, value);
105 }
106 RoundtripTestcase(kNumHistograms, alphabet_size, symbols);
107 }
108 }
109
RoundtripRandomUnbalancedStream(int alphabet_size)110 void RoundtripRandomUnbalancedStream(int alphabet_size) {
111 constexpr int kNumHistograms = 3;
112 constexpr int kPrecision = 1 << 10;
113 Rng rng(0);
114 for (size_t i = 0; i < kReps; i++) {
115 std::vector<int> distributions[kNumHistograms];
116 for (int j = 0; j < kNumHistograms; j++) {
117 distributions[j].resize(kPrecision);
118 int symbol = 0;
119 int remaining = 1;
120 for (int k = 0; k < kPrecision; k++) {
121 if (remaining == 0) {
122 if (symbol < alphabet_size - 1) symbol++;
123 // There is no meaning behind this distribution: it's anything that
124 // will create a nonuniform distribution and won't have too few
125 // symbols usually. Also we want different distributions we get to be
126 // sufficiently dissimilar.
127 remaining = rng.UniformU(0, kPrecision - k + 1);
128 }
129 distributions[j][k] = symbol;
130 remaining--;
131 }
132 }
133 std::vector<Token> symbols;
134 for (int j = 0; j < 1 << 18; j++) {
135 int context = rng.UniformI(0, kNumHistograms);
136 int value = rng.UniformU(0, kPrecision);
137 symbols.emplace_back(context, value);
138 }
139 RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols);
140 }
141 }
142
// A single repetition of a short (16-token) random stream with alphabet size 3.
TEST(ANSTest, RandomStreamRoundtrip3Small) {
  RoundtripRandomStream(/*alphabet_size=*/3, /*reps=*/1, /*num=*/16);
}
144
// Random streams with alphabet size 3, using the default repetition count and
// stream length.
TEST(ANSTest, RandomStreamRoundtrip3) {
  RoundtripRandomStream(/*alphabet_size=*/3);
}
146
// Random streams with alphabet size ANS_MAX_ALPHABET_SIZE, using the default
// repetition count and stream length.
TEST(ANSTest, RandomStreamRoundtripBig) {
  const int alphabet_size = ANS_MAX_ALPHABET_SIZE;
  RoundtripRandomStream(alphabet_size);
}
150
// Runs the unbalanced random-stream round-trip with alphabet size 3.
TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) {
  RoundtripRandomUnbalancedStream(/*alphabet_size=*/3);
}
154
// Runs the unbalanced random-stream round-trip with alphabet size
// ANS_MAX_ALPHABET_SIZE.
TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
  const int alphabet_size = ANS_MAX_ALPHABET_SIZE;
  RoundtripRandomUnbalancedStream(alphabet_size);
}
158
// For each log_alpha_size in [5, 8]: builds every (split, msb, lsb) triple with
// split < log_alpha_size and msb + lsb <= split, appends
// (log_alpha_size, 0, 0), encodes the list, decodes it again and compares the
// three fields of every entry.
TEST(ANSTest, UintConfigRoundtrip) {
  for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) {
    std::vector<HybridUintConfig> expected;
    for (size_t split = 0; split < log_alpha_size; split++) {
      for (size_t msb = 0; msb <= split; msb++) {
        const size_t max_lsb = split - msb;
        for (size_t lsb = 0; lsb <= max_lsb; lsb++) {
          expected.emplace_back(split, msb, lsb);
        }
      }
    }
    expected.emplace_back(log_alpha_size, 0, 0);

    // Destination for the decoder, sized to match the encoded list.
    std::vector<HybridUintConfig> decoded;
    decoded.resize(expected.size());

    BitWriter writer;
    BitWriter::Allotment allotment(&writer, 10 * expected.size());
    EncodeUintConfigs(expected, &writer, log_alpha_size);
    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
    writer.ZeroPadToByte();

    BitReader br(writer.GetSpan());
    EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &decoded, &br));
    EXPECT_TRUE(br.Close());

    for (size_t idx = 0; idx < expected.size(); idx++) {
      const HybridUintConfig& want = expected[idx];
      const HybridUintConfig& got = decoded[idx];
      EXPECT_EQ(want.split_token, got.split_token);
      EXPECT_EQ(want.msb_in_token, got.msb_in_token);
      EXPECT_EQ(want.lsb_in_token, got.lsb_in_token);
    }
  }
}
186
TestCheckpointing(bool ans,bool lz77)187 void TestCheckpointing(bool ans, bool lz77) {
188 std::vector<std::vector<Token>> input_values(1);
189 for (size_t i = 0; i < 1024; i++) {
190 input_values[0].push_back(Token(0, i % 4));
191 }
192 // up to lz77 window size.
193 for (size_t i = 0; i < (1 << 20) - 1022; i++) {
194 input_values[0].push_back(Token(0, (i % 5) + 4));
195 }
196 // Ensure that when the window wraps around, new values are different.
197 input_values[0].push_back(Token(0, 0));
198 for (size_t i = 0; i < 1024; i++) {
199 input_values[0].push_back(Token(0, i % 4));
200 }
201
202 std::vector<uint8_t> context_map;
203 EntropyEncodingData codes;
204 HistogramParams params;
205 params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77
206 : HistogramParams::LZ77Method::kNone;
207 params.force_huffman = !ans;
208
209 BitWriter writer;
210 {
211 auto input_values_copy = input_values;
212 BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map,
213 &writer, 0, nullptr);
214 WriteTokens(input_values_copy[0], codes, context_map, &writer, 0, nullptr);
215 writer.ZeroPadToByte();
216 }
217
218 // We do not truncate the output. Reading past the end reads out zeroes
219 // anyway.
220 BitReader br(writer.GetSpan());
221 Status status = true;
222 {
223 BitReaderScopedCloser bc(&br, &status);
224
225 std::vector<uint8_t> dec_context_map;
226 ANSCode decoded_codes;
227 ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map));
228 ASSERT_EQ(dec_context_map, context_map);
229 ANSSymbolReader reader(&decoded_codes, &br);
230
231 ANSSymbolReader::Checkpoint checkpoint;
232 size_t br_pos = 0;
233 constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2;
234 for (size_t i = 0; i < input_values[0].size(); i++) {
235 if (i % kInterval == 0 && i > 0) {
236 reader.Restore(checkpoint);
237 ASSERT_TRUE(br.Close());
238 br = BitReader(writer.GetSpan());
239 br.SkipBits(br_pos);
240 for (size_t j = i - kInterval; j < i; j++) {
241 Token symbol = input_values[0][j];
242 uint32_t read_symbol =
243 reader.ReadHybridUint(symbol.context, &br, dec_context_map);
244 ASSERT_EQ(read_symbol, symbol.value) << "j = " << j;
245 }
246 }
247 if (i % kInterval == 0) {
248 reader.Save(&checkpoint);
249 br_pos = br.TotalBitsConsumed();
250 }
251 Token symbol = input_values[0][i];
252 uint32_t read_symbol =
253 reader.ReadHybridUint(symbol.context, &br, dec_context_map);
254 ASSERT_EQ(read_symbol, symbol.value) << "i = " << i;
255 }
256 ASSERT_TRUE(reader.CheckANSFinalState());
257 }
258 EXPECT_TRUE(status);
259 }
260
// Checkpointing with force_huffman off and LZ77 disabled.
TEST(ANSTest, TestCheckpointingANS) {
  constexpr bool kUseAns = true;
  constexpr bool kUseLz77 = false;
  TestCheckpointing(kUseAns, kUseLz77);
}
264
// Checkpointing with force_huffman on and LZ77 disabled.
TEST(ANSTest, TestCheckpointingPrefix) {
  constexpr bool kUseAns = false;
  constexpr bool kUseLz77 = false;
  TestCheckpointing(kUseAns, kUseLz77);
}
268
// Checkpointing with force_huffman off and LZ77 enabled.
TEST(ANSTest, TestCheckpointingANSLZ77) {
  constexpr bool kUseAns = true;
  constexpr bool kUseLz77 = true;
  TestCheckpointing(kUseAns, kUseLz77);
}
272
// Checkpointing with force_huffman on and LZ77 enabled.
TEST(ANSTest, TestCheckpointingPrefixLZ77) {
  constexpr bool kUseAns = false;
  constexpr bool kUseLz77 = true;
  TestCheckpointing(kUseAns, kUseLz77);
}
276
277 } // namespace
278 } // namespace jxl
279