1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include <stddef.h>
7 #include <stdint.h>
8 
9 #include <vector>
10 
11 #include "gtest/gtest.h"
12 #include "lib/jxl/ans_params.h"
13 #include "lib/jxl/aux_out_fwd.h"
14 #include "lib/jxl/base/random.h"
15 #include "lib/jxl/base/span.h"
16 #include "lib/jxl/dec_ans.h"
17 #include "lib/jxl/dec_bit_reader.h"
18 #include "lib/jxl/enc_ans.h"
19 #include "lib/jxl/enc_bit_writer.h"
20 
21 namespace jxl {
22 namespace {
23 
RoundtripTestcase(int n_histograms,int alphabet_size,const std::vector<Token> & input_values)24 void RoundtripTestcase(int n_histograms, int alphabet_size,
25                        const std::vector<Token>& input_values) {
26   constexpr uint16_t kMagic1 = 0x9e33;
27   constexpr uint16_t kMagic2 = 0x8b04;
28 
29   BitWriter writer;
30   // Space for magic bytes.
31   BitWriter::Allotment allotment_magic1(&writer, 16);
32   writer.Write(16, kMagic1);
33   ReclaimAndCharge(&writer, &allotment_magic1, 0, nullptr);
34 
35   std::vector<uint8_t> context_map;
36   EntropyEncodingData codes;
37   std::vector<std::vector<Token>> input_values_vec;
38   input_values_vec.push_back(input_values);
39 
40   BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec,
41                            &codes, &context_map, &writer, 0, nullptr);
42   WriteTokens(input_values_vec[0], codes, context_map, &writer, 0, nullptr);
43 
44   // Magic bytes + padding
45   BitWriter::Allotment allotment_magic2(&writer, 24);
46   writer.Write(16, kMagic2);
47   writer.ZeroPadToByte();
48   ReclaimAndCharge(&writer, &allotment_magic2, 0, nullptr);
49 
50   // We do not truncate the output. Reading past the end reads out zeroes
51   // anyway.
52   BitReader br(writer.GetSpan());
53 
54   ASSERT_EQ(br.ReadBits(16), kMagic1);
55 
56   std::vector<uint8_t> dec_context_map;
57   ANSCode decoded_codes;
58   ASSERT_TRUE(
59       DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map));
60   ASSERT_EQ(dec_context_map, context_map);
61   ANSSymbolReader reader(&decoded_codes, &br);
62 
63   for (const Token& symbol : input_values) {
64     uint32_t read_symbol =
65         reader.ReadHybridUint(symbol.context, &br, dec_context_map);
66     ASSERT_EQ(read_symbol, symbol.value);
67   }
68   ASSERT_TRUE(reader.CheckANSFinalState());
69 
70   ASSERT_EQ(br.ReadBits(16), kMagic2);
71   EXPECT_TRUE(br.Close());
72 }
73 
TEST(ANSTest, EmptyRoundtrip) {
  // A stream with no tokens at all must still encode and decode cleanly.
  const std::vector<Token> no_tokens;
  RoundtripTestcase(/*n_histograms=*/2, ANS_MAX_ALPHABET_SIZE, no_tokens);
}
77 
TEST(ANSTest, SingleSymbolRoundtrip) {
  // Every symbol of the alphabet as a lone token in context 0...
  for (uint32_t sym = 0; sym < ANS_MAX_ALPHABET_SIZE; ++sym) {
    const std::vector<Token> lone(1, Token(0, sym));
    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, lone);
  }
  // ...then as a run of 1024 identical tokens in context 0.
  for (uint32_t sym = 0; sym < ANS_MAX_ALPHABET_SIZE; ++sym) {
    const std::vector<Token> run(1024, Token(0, sym));
    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, run);
  }
}
87 
// Number of random streams per test. Sanitizer builds use fewer repetitions,
// presumably because the tests run much slower under them.
constexpr size_t kReps =
#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
    defined(THREAD_SANITIZER)
    3;
#else
    10;
#endif
94 
RoundtripRandomStream(int alphabet_size,size_t reps=kReps,size_t num=1<<18)95 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps,
96                            size_t num = 1 << 18) {
97   constexpr int kNumHistograms = 3;
98   Rng rng(0);
99   for (size_t i = 0; i < reps; i++) {
100     std::vector<Token> symbols;
101     for (size_t j = 0; j < num; j++) {
102       int context = rng.UniformI(0, kNumHistograms);
103       int value = rng.UniformU(0, alphabet_size);
104       symbols.emplace_back(context, value);
105     }
106     RoundtripTestcase(kNumHistograms, alphabet_size, symbols);
107   }
108 }
109 
RoundtripRandomUnbalancedStream(int alphabet_size)110 void RoundtripRandomUnbalancedStream(int alphabet_size) {
111   constexpr int kNumHistograms = 3;
112   constexpr int kPrecision = 1 << 10;
113   Rng rng(0);
114   for (size_t i = 0; i < kReps; i++) {
115     std::vector<int> distributions[kNumHistograms];
116     for (int j = 0; j < kNumHistograms; j++) {
117       distributions[j].resize(kPrecision);
118       int symbol = 0;
119       int remaining = 1;
120       for (int k = 0; k < kPrecision; k++) {
121         if (remaining == 0) {
122           if (symbol < alphabet_size - 1) symbol++;
123           // There is no meaning behind this distribution: it's anything that
124           // will create a nonuniform distribution and won't have too few
125           // symbols usually. Also we want different distributions we get to be
126           // sufficiently dissimilar.
127           remaining = rng.UniformU(0, kPrecision - k + 1);
128         }
129         distributions[j][k] = symbol;
130         remaining--;
131       }
132     }
133     std::vector<Token> symbols;
134     for (int j = 0; j < 1 << 18; j++) {
135       int context = rng.UniformI(0, kNumHistograms);
136       int value = rng.UniformU(0, kPrecision);
137       symbols.emplace_back(context, value);
138     }
139     RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols);
140   }
141 }
142 
TEST(ANSTest, RandomStreamRoundtrip3Small) {
  // One short stream of 16 tokens over a 3-symbol alphabet.
  RoundtripRandomStream(/*alphabet_size=*/3, /*reps=*/1, /*num=*/16);
}
144 
TEST(ANSTest, RandomStreamRoundtrip3) {
  // 3-symbol alphabet with the default repetition count and stream length.
  RoundtripRandomStream(/*alphabet_size=*/3);
}
146 
TEST(ANSTest, RandomStreamRoundtripBig) {
  // Uniform values over an alphabet of ANS_MAX_ALPHABET_SIZE symbols.
  RoundtripRandomStream(/*alphabet_size=*/ANS_MAX_ALPHABET_SIZE);
}
150 
TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) {
  // Skewed per-context distributions over a 3-symbol alphabet.
  RoundtripRandomUnbalancedStream(/*alphabet_size=*/3);
}
154 
TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
  // Skewed per-context distributions over ANS_MAX_ALPHABET_SIZE symbols.
  RoundtripRandomUnbalancedStream(/*alphabet_size=*/ANS_MAX_ALPHABET_SIZE);
}
158 
// For each log alphabet size from 5 to 8, encodes a list of HybridUintConfig
// values and checks that decoding reproduces split_token, msb_in_token and
// lsb_in_token for every entry. The list contains:
//  - every (split, msb, lsb) with split < log_alpha_size and msb + lsb <= split;
//  - one final config (log_alpha_size, 0, 0).
TEST(ANSTest, UintConfigRoundtrip) {
  for (size_t log_alpha_size = 5; log_alpha_size <= 8; ++log_alpha_size) {
    std::vector<HybridUintConfig> configs;
    for (size_t split = 0; split < log_alpha_size; ++split) {
      for (size_t msb = 0; msb <= split; ++msb) {
        for (size_t lsb = 0; msb + lsb <= split; ++lsb) {
          configs.emplace_back(split, msb, lsb);
        }
      }
    }
    configs.emplace_back(log_alpha_size, 0, 0);

    BitWriter writer;
    // Budget of 10 bits per encoded config.
    BitWriter::Allotment allotment(&writer, 10 * configs.size());
    EncodeUintConfigs(configs, &writer, log_alpha_size);
    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
    writer.ZeroPadToByte();

    std::vector<HybridUintConfig> decoded(configs.size());
    BitReader br(writer.GetSpan());
    EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &decoded, &br));
    EXPECT_TRUE(br.Close());

    for (size_t idx = 0; idx < configs.size(); ++idx) {
      const HybridUintConfig& want = configs[idx];
      const HybridUintConfig& got = decoded[idx];
      EXPECT_EQ(want.split_token, got.split_token);
      EXPECT_EQ(want.msb_in_token, got.msb_in_token);
      EXPECT_EQ(want.lsb_in_token, got.lsb_in_token);
    }
  }
}
186 
// Checks that ANSSymbolReader::Save/Restore checkpoints let decoding resume
// from an earlier position.
//
// A single-context stream of more than 2^20 tokens is encoded with either ANS
// (`ans`) or prefix codes (`!ans`, via force_huffman), optionally with LZ77
// (`lz77`). During decoding:
//  - the reader state and bit position are saved every kInterval symbols;
//  - at the next multiple of kInterval, decoding is rolled back to that
//    checkpoint and the last kInterval symbols are re-read and compared;
//  - decoding then continues from where it was.
void TestCheckpointing(bool ans, bool lz77) {
  std::vector<std::vector<Token>> input_values(1);
  // 1024 tokens cycling through values 0..3.
  for (size_t i = 0; i < 1024; i++) {
    input_values[0].push_back(Token(0, i % 4));
  }
  // up to lz77 window size.
  // Filler tokens cycling through values 4..8, disjoint from 0..3.
  for (size_t i = 0; i < (1 << 20) - 1022; i++) {
    input_values[0].push_back(Token(0, (i % 5) + 4));
  }
  // Ensure that when the window wraps around, new values are different.
  input_values[0].push_back(Token(0, 0));
  // Repeat the opening 0..3 pattern after the filler.
  for (size_t i = 0; i < 1024; i++) {
    input_values[0].push_back(Token(0, i % 4));
  }

  std::vector<uint8_t> context_map;
  EntropyEncodingData codes;
  HistogramParams params;
  params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77
                            : HistogramParams::LZ77Method::kNone;
  // Prefix (Huffman) codes are used whenever ANS is not requested.
  params.force_huffman = !ans;

  BitWriter writer;
  {
    // Encode from a copy. The original tokens are still needed below as the
    // expected decoder output. NOTE(review): presumably the encoder may rewrite
    // its input (e.g. for LZ77) — confirm.
    auto input_values_copy = input_values;
    BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map,
                             &writer, 0, nullptr);
    WriteTokens(input_values_copy[0], codes, context_map, &writer, 0, nullptr);
    writer.ZeroPadToByte();
  }

  // We do not truncate the output. Reading past the end reads out zeroes
  // anyway.
  BitReader br(writer.GetSpan());
  Status status = true;
  {
    // NOTE(review): presumably closes `br` when this scope exits and stores the
    // result in `status`, which is checked at the end — confirm.
    BitReaderScopedCloser bc(&br, &status);

    std::vector<uint8_t> dec_context_map;
    ANSCode decoded_codes;
    ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map));
    ASSERT_EQ(dec_context_map, context_map);
    ANSSymbolReader reader(&decoded_codes, &br);

    ANSSymbolReader::Checkpoint checkpoint;
    // Bit position of `br` at the moment `checkpoint` was saved.
    size_t br_pos = 0;
    // NOTE(review): presumably kept below kMaxCheckpointInterval so a replay
    // never spans more symbols than a checkpoint supports — confirm.
    constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2;
    for (size_t i = 0; i < input_values[0].size(); i++) {
      if (i % kInterval == 0 && i > 0) {
        // Roll back to the previous checkpoint:
        //  1. restore the symbol reader;
        //  2. rebuild the bit reader over the same span;
        //  3. skip to the bit position recorded with the checkpoint.
        reader.Restore(checkpoint);
        ASSERT_TRUE(br.Close());
        br = BitReader(writer.GetSpan());
        br.SkipBits(br_pos);
        // Re-decode the kInterval symbols read since that checkpoint.
        for (size_t j = i - kInterval; j < i; j++) {
          Token symbol = input_values[0][j];
          uint32_t read_symbol =
              reader.ReadHybridUint(symbol.context, &br, dec_context_map);
          ASSERT_EQ(read_symbol, symbol.value) << "j = " << j;
        }
      }
      if (i % kInterval == 0) {
        // Save a fresh checkpoint, and its bit position, before symbol i.
        reader.Save(&checkpoint);
        br_pos = br.TotalBitsConsumed();
      }
      Token symbol = input_values[0][i];
      uint32_t read_symbol =
          reader.ReadHybridUint(symbol.context, &br, dec_context_map);
      ASSERT_EQ(read_symbol, symbol.value) << "i = " << i;
    }
    ASSERT_TRUE(reader.CheckANSFinalState());
  }
  EXPECT_TRUE(status);
}
260 
TEST(ANSTest, TestCheckpointingANS) {
  // ANS entropy coding, no LZ77.
  constexpr bool kUseAns = true;
  constexpr bool kUseLz77 = false;
  TestCheckpointing(kUseAns, kUseLz77);
}
264 
TEST(ANSTest, TestCheckpointingPrefix) {
  // Prefix (Huffman) codes, no LZ77.
  constexpr bool kUseAns = false;
  constexpr bool kUseLz77 = false;
  TestCheckpointing(kUseAns, kUseLz77);
}
268 
TEST(ANSTest, TestCheckpointingANSLZ77) {
  // ANS entropy coding combined with LZ77.
  constexpr bool kUseAns = true;
  constexpr bool kUseLz77 = true;
  TestCheckpointing(kUseAns, kUseLz77);
}
272 
TEST(ANSTest, TestCheckpointingPrefixLZ77) {
  // Prefix (Huffman) codes combined with LZ77.
  constexpr bool kUseAns = false;
  constexpr bool kUseLz77 = true;
  TestCheckpointing(kUseAns, kUseLz77);
}
276 
277 }  // namespace
278 }  // namespace jxl
279