1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/icc_codec.h"
7
8 #include <stdint.h>
9
10 #include <map>
11 #include <string>
12 #include <vector>
13
14 #include "lib/jxl/aux_out.h"
15 #include "lib/jxl/aux_out_fwd.h"
16 #include "lib/jxl/base/byte_order.h"
17 #include "lib/jxl/common.h"
18 #include "lib/jxl/dec_ans.h"
19 #include "lib/jxl/fields.h"
20 #include "lib/jxl/icc_codec_common.h"
21
22 namespace jxl {
23 namespace {
24
// Decodes a little-endian base-128 variable-length integer starting at
// input[*pos]: each byte contributes its low 7 bits, and a set high bit means
// another byte follows. At most 10 bytes are examined, and no byte at or past
// input[inputSize] is ever read.
//
// On return *pos has been advanced by (number of bytes examined before
// stopping) + 1. This is exact when a terminating byte is found. If the input
// runs out, *pos ends one past inputSize. If 10 continuation bytes are seen,
// *pos advances by 11. No error is signalled in either case.
uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, size_t* pos) {
  constexpr size_t kMaxVarIntBytes = 10;
  const size_t start = *pos;
  uint64_t value = 0;
  size_t n = 0;
  while (n < kMaxVarIntBytes && start + n < inputSize) {
    const uint8_t byte = input[start + n];
    value |= static_cast<uint64_t>(byte & 0x7F) << (7 * n);
    // A clear continuation bit marks the last byte of the varint.
    if (!(byte & 0x80)) break;
    ++n;
  }
  // TODO: Return a decoding error if n == kMaxVarIntBytes.
  *pos = start + n + 1;
  return value;
}
37
// Interleaves (transposes) `size` bytes of `data` in place. For example, with
// width 2, "ABCDabcd" becomes "AaBbCcDd".
//
// Let height = ceil(size / width). View the input as a row-major matrix with
// `height` columns, i.e. consecutive runs of `height` bytes; the last run may
// be shorter. The output is that matrix read column by column: first byte of
// every run, then the second byte of every run, and so on. Because only a
// trailing run can be short, the output needs no gaps and ends exactly at
// data[size - 1]. Examples: width 2, "ABCab" -> "AaBbC"; width 4,
// "ABCDEF" -> "ACEBDF".
void Shuffle(uint8_t* data, size_t size, size_t width) {
  // Length of each input run, which is the number of output rows.
  const size_t height = (size + width - 1) / width;
  std::vector<uint8_t> transposed;
  transposed.reserve(size);
  for (size_t column = 0; transposed.size() < size; ++column) {
    for (size_t src = column; src < size; src += height) {
      transposed.push_back(data[src]);
    }
  }
  for (size_t i = 0; i < size; ++i) {
    data[i] = transposed[i];
  }
}
62
// Number of leading bytes of the decoded ICC stream that must be available
// before CheckPreamble() can parse the two leading VarInts.
// TODO(eustas): reduce to 20, or even 18, once DecodeVarInt is improved. It
// currently signals no errors, and it advances the position by 11 bytes when
// it stops after examining 10, although 9 bytes suffice for 63-bit values.
constexpr size_t kPreambleSize = 22;  // enough for reading 2 VarInts
68
69 } // namespace
70
71 // Mimics the beginning of UnpredictICC for quick validity check.
72 // At least kPreambleSize bytes of data should be valid at invocation time.
CheckPreamble(const PaddedBytes & data,size_t enc_size,size_t output_limit)73 Status CheckPreamble(const PaddedBytes& data, size_t enc_size,
74 size_t output_limit) {
75 const uint8_t* enc = data.data();
76 size_t size = data.size();
77 size_t pos = 0;
78 uint64_t osize = DecodeVarInt(enc, size, &pos);
79 JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
80 if (pos >= size) return JXL_FAILURE("Out of bounds");
81 uint64_t csize = DecodeVarInt(enc, size, &pos);
82 JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
83 JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
84 // We expect that UnpredictICC inflates input, not the other way round.
85 if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC");
86 if (output_limit && osize > output_limit) {
87 return JXL_FAILURE("Decoded ICC is too large");
88 }
89 return true;
90 }
91
// Decodes the result of PredictICC back to a valid ICC profile.
//
// `enc`/`size` hold the whole entropy-decoded stream. It starts with two
// VarInts: the output (ICC profile) size `osize` and the commands-stream size
// `csize`. The next `csize` bytes are the commands stream, and everything after
// it is the data stream. Output is produced in three phases:
//   1. Header: each of the first kICCHeaderSize output bytes is a data-stream
//      byte added to a predicted header byte.
//   2. Tag list (optional): one command per tag. The low 6 bits select the
//      tag; kFlagBitOffset and kFlagBitSize select whether the tag offset and
//      size are explicit VarInts in the commands stream or predicted.
//   3. Main content: insert / shuffle / predict / XYZ / type-keyword commands
//      until the commands stream is exhausted.
// `result` must be empty on entry. On success it holds exactly `osize` bytes,
// and both the commands and data streams have been fully consumed.
Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) {
  if (!result->empty()) return JXL_FAILURE("result must be empty initially");
  size_t pos = 0;
  // TODO(lode): technically speaking we need to check that the entire varint
  // decoding never goes out of bounds, not just the first byte. This requires
  // a DecodeVarInt function that returns an error code. It is safe to use
  // DecodeVarInt with out of bounds values, it silently returns, but the
  // specification requires an error. Idem for all DecodeVarInt below.
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t osize = DecodeVarInt(enc, size, &pos);  // Output size
  JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t csize = DecodeVarInt(enc, size, &pos);  // Commands size
  // Every command is translated to at least one byte.
  JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
  size_t cpos = pos;  // pos in commands stream
  JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
  size_t commands_end = cpos + csize;
  pos = commands_end;  // pos in data stream

  // Header
  // NOTE(review): EncodeUint32(0, osize, &header) presumably stores osize as
  // the predicted profile-size field at offset 0; confirm in icc_codec_common.
  PaddedBytes header = ICCInitialHeaderPrediction();
  EncodeUint32(0, osize, &header);
  // Iterates one step past the last header byte so that a profile consisting
  // of only (part of) the header can take the "valid end" exit below.
  for (size_t i = 0; i <= kICCHeaderSize; i++) {
    if (result->size() == osize) {
      if (cpos != commands_end) return JXL_FAILURE("Not all commands used");
      if (pos != size) return JXL_FAILURE("Not all data used");
      return true;  // Valid end
    }
    if (i == kICCHeaderSize) break;  // Done
    // Refines the prediction for header byte i from the bytes decoded so far.
    ICCPredictHeader(result->data(), result->size(), header.data(), i);
    if (pos >= size) return JXL_FAILURE("Out of bounds");
    // Data-stream byte is the residual; uint8_t arithmetic wraps modulo 256.
    result->push_back(enc[pos++] + header[i]);
  }
  if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");

  // Tag list
  // numtags is stored as (tag count + 1); 0 means there is no tag list.
  uint64_t numtags = DecodeVarInt(enc, size, &cpos);

  if (numtags != 0) {
    numtags--;
    JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags));
    AppendUint32(numtags, result);
    // Initial prediction for the first tag's offset (each tag table entry is
    // 12 bytes: keyword, offset, size).
    uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
    uint64_t prevtagsize = 0;
    for (;;) {
      if (result->size() > osize) return JXL_FAILURE("Invalid result size");
      if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
      if (cpos == commands_end) break;  // Valid end
      uint8_t command = enc[cpos++];
      // Low 6 bits select the tag; the higher bits are kFlagBit* flags.
      uint8_t tagcode = command & 63;
      Tag tag;
      if (tagcode == 0) {
        // Tag code 0 terminates the tag list.
        break;
      } else if (tagcode == kCommandTagUnknown) {
        // Uncommon tag: its 4-byte keyword comes verbatim from the data stream.
        JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size));
        tag = DecodeKeyword(enc, size, pos);
        pos += 4;
      } else if (tagcode == kCommandTagTRC) {
        tag = kRtrcTag;
      } else if (tagcode == kCommandTagXYZ) {
        tag = kRxyzTag;
      } else {
        // Remaining codes index the kTagStrings table of common keywords.
        if (tagcode - kCommandTagStringFirst >= kNumTagStrings) {
          return JXL_FAILURE("Unknown tagcode");
        }
        tag = *kTagStrings[tagcode - kCommandTagStringFirst];
      }
      AppendKeyword(tag, result);

      // Predicted size: the previous tag's size, or 20 for the tags below.
      uint64_t tagstart;
      uint64_t tagsize = prevtagsize;
      if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
          tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
          tag == kLumiTag) {
        tagsize = 20;
      }

      // Offset: explicit VarInt if kFlagBitOffset is set, otherwise directly
      // after the previous tag's data.
      if (command & kFlagBitOffset) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagstart = DecodeVarInt(enc, size, &cpos);
      } else {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart));
        tagstart = prevtagstart + prevtagsize;
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart));
      AppendUint32(tagstart, result);
      // Size: explicit VarInt if kFlagBitSize is set, otherwise the prediction.
      if (command & kFlagBitSize) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagsize = DecodeVarInt(enc, size, &cpos);
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize));
      AppendUint32(tagsize, result);
      prevtagstart = tagstart;
      prevtagsize = tagsize;

      // A TRC command also emits gTRC and bTRC entries that share rTRC's
      // offset and size.
      if (tagcode == kCommandTagTRC) {
        AppendKeyword(kGtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
      }

      // An XYZ command also emits gXYZ and bXYZ entries stored back-to-back
      // after rXYZ, each with the same size.
      if (tagcode == kCommandTagXYZ) {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2));
        AppendKeyword(kGxyzTag, result);
        AppendUint32(tagstart + tagsize, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBxyzTag, result);
        AppendUint32(tagstart + tagsize * 2, result);
        AppendUint32(tagsize, result);
      }
    }
  }

  // Main Content
  for (;;) {
    if (result->size() > osize) return JXL_FAILURE("Invalid result size");
    if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
    if (cpos == commands_end) break;  // Valid end
    uint8_t command = enc[cpos++];
    if (command == kCommandInsert) {
      // Copy `num` bytes verbatim from the data stream.
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      for (size_t i = 0; i < num; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command == kCommandShuffle2 || command == kCommandShuffle4) {
      // Copy `num` data-stream bytes, interleaved with Shuffle width 2 or 4.
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (command == kCommandShuffle2) {
        Shuffle(shuffled.data(), num, 2);
      } else if (command == kCommandShuffle4) {
        Shuffle(shuffled.data(), num, 4);
      }
      for (size_t i = 0; i < num; i++) {
        result->push_back(shuffled[i]);
        pos++;
      }
    } else if (command == kCommandPredict) {
      // Needs the flags byte plus at least one byte for the `num` VarInt.
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end));
      uint8_t flags = enc[cpos++];

      // Bits 0-1: width - 1 (width 1, 2 or 4; 3 is invalid).
      size_t width = (flags & 3) + 1;
      if (width == 3) return JXL_FAILURE("Invalid width");

      // Bits 2-3: prediction order (0..2; 3 is invalid).
      int order = (flags & 12) >> 2;
      if (order == 3) return JXL_FAILURE("Invalid order");

      // Bit 4: explicit stride VarInt follows; otherwise stride == width.
      uint64_t stride = width;
      if (flags & 16) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        stride = DecodeVarInt(enc, size, &cpos);
        if (stride < width) {
          return JXL_FAILURE("Invalid stride");
        }
      }
      // If stride * 4 >= result->size(), return failure. The check
      // "size == 0 || ((size - 1) >> 2) < stride" corresponds to
      // "stride * 4 >= size", but does not suffer from integer overflow.
      // This check is more strict than necessary but follows the specification
      // and the encoder should ensure this is followed.
      if (result->empty() || ((result->size() - 1u) >> 2u) < stride) {
        return JXL_FAILURE("Invalid stride");
      }

      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);  // in bytes
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));

      // Residuals come from the data stream, interleaved if width > 1.
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (width > 1) Shuffle(shuffled.data(), num, width);

      // Output byte = linear prediction from earlier output + residual
      // (uint8_t arithmetic, wraps modulo 256). result->data() is re-read on
      // every iteration, so reallocation by push_back is harmless.
      size_t start = result->size();
      for (size_t i = 0; i < num; i++) {
        uint8_t predicted = LinearPredictICCValue(result->data(), start, i,
                                                  stride, width, order);
        result->push_back(predicted + shuffled[i]);
      }
      pos += num;
    } else if (command == kCommandXYZ) {
      // kXyz_Tag keyword, 4 zero bytes, then 12 bytes from the data stream.
      AppendKeyword(kXyz_Tag, result);
      for (int i = 0; i < 4; i++) result->push_back(0);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size));
      for (size_t i = 0; i < 12; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command >= kCommandTypeStartFirst &&
               command < kCommandTypeStartFirst + kNumTypeStrings) {
      // Type keyword from kTypeStrings followed by 4 zero bytes.
      AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result);
      for (size_t i = 0; i < 4; i++) {
        result->push_back(0);
      }
    } else {
      return JXL_FAILURE("Unknown command");
    }
  }

  // Both streams must be fully consumed and the output size must match.
  if (pos != size) return JXL_FAILURE("Not all data used");
  if (result->size() != osize) return JXL_FAILURE("Invalid result size");

  return true;
}
307
Init(BitReader * reader,size_t output_limit)308 Status ICCReader::Init(BitReader* reader, size_t output_limit) {
309 JXL_RETURN_IF_ERROR(CheckEOI(reader));
310 used_bits_base_ = reader->TotalBitsConsumed();
311 if (bits_to_skip_ == 0) {
312 enc_size_ = U64Coder::Read(reader);
313 if (enc_size_ > 268435456) {
314 // Avoid too large memory allocation for invalid file.
315 return JXL_FAILURE("Too large encoded profile");
316 }
317 JXL_RETURN_IF_ERROR(
318 DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_));
319 ans_reader_ = ANSSymbolReader(&code_, reader);
320 i_ = 0;
321 decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
322 for (; i_ < std::min<size_t>(2, enc_size_); i_++) {
323 decompressed_[i_] = ans_reader_.ReadHybridUint(
324 ICCANSContext(i_, i_ > 0 ? decompressed_[i_ - 1] : 0,
325 i_ > 1 ? decompressed_[i_ - 2] : 0),
326 reader, context_map_);
327 }
328 if (enc_size_ > kPreambleSize) {
329 for (; i_ < kPreambleSize; i_++) {
330 decompressed_[i_] = ans_reader_.ReadHybridUint(
331 ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]),
332 reader, context_map_);
333 }
334 JXL_RETURN_IF_ERROR(CheckEOI(reader));
335 JXL_RETURN_IF_ERROR(
336 CheckPreamble(decompressed_, enc_size_, output_limit));
337 }
338 bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
339 } else {
340 reader->SkipBits(bits_to_skip_);
341 }
342 return true;
343 }
344
// Entropy-decodes the remaining bytes of the encoded ICC stream (from i_ up to
// enc_size_) into decompressed_. Once all bytes are decoded and the ANS final
// state checks out, it reconstructs the profile into `icc` via UnpredictICC.
//
// Decoding is checkpointed at every multiple of
// ANSSymbolReader::kMaxCheckpointInterval. If CheckEOI reports that reads
// went past the available input, the ANS state and i_ are rolled back to the
// last checkpoint and CheckEOI's status is returned. bits_to_skip_ then holds
// the number of bits, past used_bits_base_, consumed at that checkpoint;
// Init() skips exactly that many bits when bits_to_skip_ is non-zero.
// NOTE(review): presumably the caller retries Init() + Process() once more
// input is available; confirm at call sites.
Status ICCReader::Process(BitReader* reader, PaddedBytes* icc) {
  ANSSymbolReader::Checkpoint checkpoint;
  size_t saved_i = 0;
  // Records the current decoding position as the rollback point.
  auto save = [&]() {
    ans_reader_.Save(&checkpoint);
    bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
    saved_i = i_;
  };
  save();
  // Returns CheckEOI's status; on failure restores the last save() first.
  auto check_and_restore = [&]() {
    Status status = CheckEOI(reader);
    if (!status) {
      // not enough bytes.
      ans_reader_.Restore(checkpoint);
      i_ = saved_i;
      return status;
    }
    return Status(true);
  };
  for (; i_ < enc_size_; i_++) {
    if (i_ % ANSSymbolReader::kMaxCheckpointInterval == 0 && i_ > 0) {
      JXL_RETURN_IF_ERROR(check_and_restore());
      save();
      // At checkpoints where i_ is a multiple of 0x10000, reject streams that
      // have produced more than 256 decoded bytes per consumed input byte.
      if ((i_ > 0) && (((i_ & 0xFFFF) == 0))) {
        float used_bytes =
            (reader->TotalBitsConsumed() - used_bits_base_) / 8.0f;
        if (i_ > used_bytes * 256) return JXL_FAILURE("Corrupted stream");
      }
      // Grow the output buffer to 0x400 bytes past i_, capped at enc_size_.
      decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
    }
    // Init() always decodes the first min(2, enc_size_) bytes, so the two
    // context bytes below exist whenever this loop body runs.
    JXL_DASSERT(i_ >= 2);
    decompressed_[i_] = ans_reader_.ReadHybridUint(
        ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), reader,
        context_map_);
  }
  JXL_RETURN_IF_ERROR(check_and_restore());
  bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
  if (!ans_reader_.CheckANSFinalState()) {
    return JXL_FAILURE("Corrupted ICC profile");
  }

  icc->clear();
  return UnpredictICC(decompressed_.data(), decompressed_.size(), icc);
}
389
CheckEOI(BitReader * reader)390 Status ICCReader::CheckEOI(BitReader* reader) {
391 if (reader->AllReadsWithinBounds()) return true;
392 return JXL_STATUS(StatusCode::kNotEnoughBytes,
393 "Not enough bytes for reading ICC profile");
394 }
395
ReadICC(BitReader * JXL_RESTRICT reader,PaddedBytes * JXL_RESTRICT icc,size_t output_limit)396 Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc,
397 size_t output_limit) {
398 ICCReader icc_reader;
399 JXL_RETURN_IF_ERROR(icc_reader.Init(reader, output_limit));
400 JXL_RETURN_IF_ERROR(icc_reader.Process(reader, icc));
401 return true;
402 }
403
404 } // namespace jxl
405