1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/icc_codec.h"
7 
8 #include <stdint.h>
9 
10 #include <map>
11 #include <string>
12 #include <vector>
13 
14 #include "lib/jxl/aux_out.h"
15 #include "lib/jxl/aux_out_fwd.h"
16 #include "lib/jxl/base/byte_order.h"
17 #include "lib/jxl/common.h"
18 #include "lib/jxl/dec_ans.h"
19 #include "lib/jxl/fields.h"
20 #include "lib/jxl/icc_codec_common.h"
21 
22 namespace jxl {
23 namespace {
24 
// Reads an unsigned LEB128-style VarInt from `input`, starting at `*pos`.
// Each byte contributes its low 7 bits (least significant group first); a set
// high bit means another byte follows. At most 10 bytes are examined, and
// never at or beyond `inputSize`; bits that do not fit in 64 bits are dropped.
//
// On return `*pos` is just past the terminating byte. If no terminating byte
// is found (truncated input, or 10 bytes that all carry the continuation
// bit), `*pos` is advanced one byte further than the bytes actually read, so
// it may exceed `inputSize`; callers are responsible for bounds-checking it.
uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, size_t* pos) {
  constexpr size_t kMaxBytes = 10;
  uint64_t value = 0;
  size_t n = 0;  // offset, relative to *pos, of the byte being examined
  while (n < kMaxBytes && *pos + n < inputSize) {
    const uint8_t byte = input[*pos + n];
    value |= uint64_t(byte & 0x7F) << (7 * n);
    if ((byte & 0x80) == 0) break;  // no continuation flag: this is the last
    ++n;
  }
  // TODO: Return a decoding error if n == kMaxBytes.
  *pos += n + 1;
  return value;
}
37 
// Shuffles or interleaves bytes in place by transposing them.
//
// Let h = ceil(size / width). The input is viewed as `width` rows of h bytes
// each (the final row may be short), and the output reads it column by
// column: input indices 0, h, 2h, ..., then 1, 1 + h, 1 + 2h, ..., and so on,
// skipping indices >= size. For example, with width 2, "ABCDabcd" becomes
// "AaBbCcDd" and "ABCab" becomes "AaBbC". `width` must be non-zero.
void Shuffle(uint8_t* data, size_t size, size_t width) {
  // Number of columns of the input == number of rows of the output.
  const size_t height = (size + width - 1) / width;
  const std::vector<uint8_t> source(data, data + size);
  size_t out = 0;
  for (size_t column = 0; column < height; ++column) {
    for (size_t in = column; in < size; in += height) {
      data[out++] = source[in];
    }
  }
}
62 
// Number of leading entropy-decoded bytes that ICCReader::Init decodes before
// calling CheckPreamble, which parses two VarInts (output size and command
// stream size) from them. DecodeVarInt may advance its position by up to 11
// bytes per VarInt, hence 2 * 11.
// TODO(eustas): should be 20, or even 18, once DecodeVarInt is improved;
//               currently DecodeVarInt does not signal the errors, and marks
//               11 bytes as used even if only 10 are used (and 9 is enough for
//               63-bit values).
constexpr size_t kPreambleSize = 22;  // enough for reading 2 VarInts
68 
69 }  // namespace
70 
71 // Mimics the beginning of UnpredictICC for quick validity check.
72 // At least kPreambleSize bytes of data should be valid at invocation time.
CheckPreamble(const PaddedBytes & data,size_t enc_size,size_t output_limit)73 Status CheckPreamble(const PaddedBytes& data, size_t enc_size,
74                      size_t output_limit) {
75   const uint8_t* enc = data.data();
76   size_t size = data.size();
77   size_t pos = 0;
78   uint64_t osize = DecodeVarInt(enc, size, &pos);
79   JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
80   if (pos >= size) return JXL_FAILURE("Out of bounds");
81   uint64_t csize = DecodeVarInt(enc, size, &pos);
82   JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
83   JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
84   // We expect that UnpredictICC inflates input, not the other way round.
85   if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC");
86   if (output_limit && osize > output_limit) {
87     return JXL_FAILURE("Decoded ICC is too large");
88   }
89   return true;
90 }
91 
// Decodes the result of PredictICC back to a valid ICC profile.
//
// Layout of `enc` (`size` bytes) as consumed below:
//   VarInt osize  - size of the decoded profile, in bytes.
//   VarInt csize  - size of the command stream, in bytes.
//   csize bytes   - the command stream, read through `cpos`.
//   the rest      - the data stream, read through `pos` up to `size`.
// The profile is rebuilt in three phases, each driven by the command stream:
// the header (data bytes added to a predicted header), an optional tag table,
// and the main content.
//
// `result` must be empty on entry and receives the decoded profile. Returns a
// failure Status in any of these cases:
//   - an out-of-bounds read;
//   - an unknown tag code or command;
//   - invalid predictor parameters;
//   - a value that does not fit in 32 bits;
//   - unused command or data bytes;
//   - a decoded size different from osize.
Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) {
  if (!result->empty()) return JXL_FAILURE("result must be empty initially");
  size_t pos = 0;
  // TODO(lode): technically speaking we need to check that the entire varint
  // decoding never goes out of bounds, not just the first byte. This requires
  // a DecodeVarInt function that returns an error code. It is safe to use
  // DecodeVarInt with out of bounds values, it silently returns, but the
  // specification requires an error. Idem for all DecodeVarInt below.
  // NOTE: VarInts inside the command stream are decoded with `size` (not
  // `commands_end`) as their limit. The `cpos > commands_end` checks at the
  // top of the loops below reject a VarInt that ran past the command stream.
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t osize = DecodeVarInt(enc, size, &pos);  // Output size
  JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t csize = DecodeVarInt(enc, size, &pos);  // Commands size
  // Every command is translated to at least one byte.
  JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
  size_t cpos = pos;  // pos in commands stream
  JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
  size_t commands_end = cpos + csize;
  pos = commands_end;  // pos in data stream

  // Header
  // Each header byte is a data-stream byte added to a predicted byte.
  // The prediction starts from ICCInitialHeaderPrediction(), with osize
  // encoded at offset 0. ICCPredictHeader presumably refines prediction i
  // from the bytes decoded so far (TODO confirm). The profile may end inside
  // the header, in which case both streams must be fully consumed.
  PaddedBytes header = ICCInitialHeaderPrediction();
  EncodeUint32(0, osize, &header);
  for (size_t i = 0; i <= kICCHeaderSize; i++) {
    if (result->size() == osize) {
      if (cpos != commands_end) return JXL_FAILURE("Not all commands used");
      if (pos != size) return JXL_FAILURE("Not all data used");
      return true;  // Valid end
    }
    if (i == kICCHeaderSize) break;  // Done
    ICCPredictHeader(result->data(), result->size(), header.data(), i);
    if (pos >= size) return JXL_FAILURE("Out of bounds");
    result->push_back(enc[pos++] + header[i]);
  }
  // A profile longer than the header needs at least the tag-count VarInt.
  if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");

  // Tag list
  // The tag count is stored plus one; 0 means no tag table is emitted.
  uint64_t numtags = DecodeVarInt(enc, size, &cpos);

  if (numtags != 0) {
    numtags--;
    JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags));
    AppendUint32(numtags, result);
    // Predicted data offset of the first tag: kICCHeaderSize plus 12 bytes
    // per tag table entry (each entry below is keyword, offset, size).
    uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
    uint64_t prevtagsize = 0;
    for (;;) {
      if (result->size() > osize) return JXL_FAILURE("Invalid result size");
      if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
      if (cpos == commands_end) break;  // Valid end
      // The low 6 bits select the tag. The other bits carry flags
      // (kFlagBitOffset / kFlagBitSize) that request explicit VarInts.
      uint8_t command = enc[cpos++];
      uint8_t tagcode = command & 63;
      Tag tag;
      if (tagcode == 0) {
        // End of the tag table.
        break;
      } else if (tagcode == kCommandTagUnknown) {
        // Arbitrary 4-byte keyword, taken from the data stream.
        JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size));
        tag = DecodeKeyword(enc, size, pos);
        pos += 4;
      } else if (tagcode == kCommandTagTRC) {
        // kGtrcTag and kBtrcTag entries sharing this offset/size follow below.
        tag = kRtrcTag;
      } else if (tagcode == kCommandTagXYZ) {
        // kGxyzTag and kBxyzTag entries at consecutive offsets follow below.
        tag = kRxyzTag;
      } else {
        // Remaining codes index the kTagStrings keyword table.
        if (tagcode - kCommandTagStringFirst >= kNumTagStrings) {
          return JXL_FAILURE("Unknown tagcode");
        }
        tag = *kTagStrings[tagcode - kCommandTagStringFirst];
      }
      AppendKeyword(tag, result);

      // Size prediction: reuse the previous tag's size, except for the
      // XYZ-type tags listed here, which are predicted to be 20 bytes.
      uint64_t tagstart;
      uint64_t tagsize = prevtagsize;
      if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
          tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
          tag == kLumiTag) {
        tagsize = 20;
      }

      // Offset: explicit VarInt, or predicted to follow the previous tag.
      if (command & kFlagBitOffset) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagstart = DecodeVarInt(enc, size, &cpos);
      } else {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart));
        tagstart = prevtagstart + prevtagsize;
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart));
      AppendUint32(tagstart, result);
      // Size: explicit VarInt overrides the prediction above.
      if (command & kFlagBitSize) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagsize = DecodeVarInt(enc, size, &cpos);
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize));
      AppendUint32(tagsize, result);
      prevtagstart = tagstart;
      prevtagsize = tagsize;

      // One TRC command expands to three table entries with identical
      // offset and size.
      if (tagcode == kCommandTagTRC) {
        AppendKeyword(kGtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
      }

      // One XYZ command expands to three table entries laid out back to back.
      if (tagcode == kCommandTagXYZ) {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2));
        AppendKeyword(kGxyzTag, result);
        AppendUint32(tagstart + tagsize, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBxyzTag, result);
        AppendUint32(tagstart + tagsize * 2, result);
        AppendUint32(tagsize, result);
      }
    }
  }

  // Main Content
  // Each remaining command appends bytes to `result` until the command
  // stream is exhausted.
  for (;;) {
    if (result->size() > osize) return JXL_FAILURE("Invalid result size");
    if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
    if (cpos == commands_end) break;  // Valid end
    uint8_t command = enc[cpos++];
    if (command == kCommandInsert) {
      // Copy `num` bytes verbatim from the data stream.
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      for (size_t i = 0; i < num; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command == kCommandShuffle2 || command == kCommandShuffle4) {
      // Copy `num` data bytes after transposing them with Shuffle() at
      // width 2 or 4.
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (command == kCommandShuffle2) {
        Shuffle(shuffled.data(), num, 2);
      } else if (command == kCommandShuffle4) {
        Shuffle(shuffled.data(), num, 4);
      }
      for (size_t i = 0; i < num; i++) {
        result->push_back(shuffled[i]);
        pos++;
      }
    } else if (command == kCommandPredict) {
      // Linearly predicted bytes. Requires the flags byte plus at least one
      // more command byte (the `num` VarInt).
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end));
      uint8_t flags = enc[cpos++];

      // Bits 0-1: element width minus one (1, 2 or 4 bytes; 3 is invalid).
      size_t width = (flags & 3) + 1;
      if (width == 3) return JXL_FAILURE("Invalid width");

      // Bits 2-3: predictor order (0, 1 or 2; 3 is invalid).
      int order = (flags & 12) >> 2;
      if (order == 3) return JXL_FAILURE("Invalid order");

      // Bit 4: an explicit stride VarInt follows; otherwise stride == width.
      uint64_t stride = width;
      if (flags & 16) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        stride = DecodeVarInt(enc, size, &cpos);
        if (stride < width) {
          return JXL_FAILURE("Invalid stride");
        }
      }
      // If stride * 4 >= result->size(), return failure. The check
      // "size == 0 || ((size - 1) >> 2) < stride" corresponds to
      // "stride * 4 >= size", but does not suffer from integer overflow.
      // This check is more strict than necessary but follows the specification
      // and the encoder should ensure this is followed.
      if (result->empty() || ((result->size() - 1u) >> 2u) < stride) {
        return JXL_FAILURE("Invalid stride");
      }

      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);  // in bytes
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));

      // Residuals come from the data stream, transposed by element width.
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (width > 1) Shuffle(shuffled.data(), num, width);

      // Each output byte is its prediction (from already-decoded output)
      // plus the residual.
      size_t start = result->size();
      for (size_t i = 0; i < num; i++) {
        uint8_t predicted = LinearPredictICCValue(result->data(), start, i,
                                                  stride, width, order);
        result->push_back(predicted + shuffled[i]);
      }
      pos += num;
    } else if (command == kCommandXYZ) {
      // kXyz_Tag keyword, 4 zero bytes, then 12 bytes from the data stream.
      AppendKeyword(kXyz_Tag, result);
      for (int i = 0; i < 4; i++) result->push_back(0);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size));
      for (size_t i = 0; i < 12; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command >= kCommandTypeStartFirst &&
               command < kCommandTypeStartFirst + kNumTypeStrings) {
      // Type keyword from kTypeStrings followed by 4 zero bytes.
      AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result);
      for (size_t i = 0; i < 4; i++) {
        result->push_back(0);
      }
    } else {
      return JXL_FAILURE("Unknown command");
    }
  }

  // Both streams must be fully consumed, and the output must be exactly
  // osize bytes.
  if (pos != size) return JXL_FAILURE("Not all data used");
  if (result->size() != osize) return JXL_FAILURE("Invalid result size");

  return true;
}
307 
Init(BitReader * reader,size_t output_limit)308 Status ICCReader::Init(BitReader* reader, size_t output_limit) {
309   JXL_RETURN_IF_ERROR(CheckEOI(reader));
310   used_bits_base_ = reader->TotalBitsConsumed();
311   if (bits_to_skip_ == 0) {
312     enc_size_ = U64Coder::Read(reader);
313     if (enc_size_ > 268435456) {
314       // Avoid too large memory allocation for invalid file.
315       return JXL_FAILURE("Too large encoded profile");
316     }
317     JXL_RETURN_IF_ERROR(
318         DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_));
319     ans_reader_ = ANSSymbolReader(&code_, reader);
320     i_ = 0;
321     decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
322     for (; i_ < std::min<size_t>(2, enc_size_); i_++) {
323       decompressed_[i_] = ans_reader_.ReadHybridUint(
324           ICCANSContext(i_, i_ > 0 ? decompressed_[i_ - 1] : 0,
325                         i_ > 1 ? decompressed_[i_ - 2] : 0),
326           reader, context_map_);
327     }
328     if (enc_size_ > kPreambleSize) {
329       for (; i_ < kPreambleSize; i_++) {
330         decompressed_[i_] = ans_reader_.ReadHybridUint(
331             ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]),
332             reader, context_map_);
333       }
334       JXL_RETURN_IF_ERROR(CheckEOI(reader));
335       JXL_RETURN_IF_ERROR(
336           CheckPreamble(decompressed_, enc_size_, output_limit));
337     }
338     bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
339   } else {
340     reader->SkipBits(bits_to_skip_);
341   }
342   return true;
343 }
344 
// Decodes the remaining entropy-coded symbols (from i_ up to enc_size_) into
// decompressed_. It then reconstructs the profile with UnpredictICC into
// `icc`, which is cleared first.
//
// Supports incremental input. Every ANSSymbolReader::kMaxCheckpointInterval
// symbols, the ANS reader state, i_ and bits_to_skip_ are checkpointed. If
// the reader is found to have read past its input, the last checkpoint is
// restored and CheckEOI's kNotEnoughBytes status is returned. A later call
// (presumably after ICCReader::Init skips bits_to_skip_ bits on a refilled
// reader; TODO confirm with callers) can then resume from that checkpoint.
Status ICCReader::Process(BitReader* reader, PaddedBytes* icc) {
  ANSSymbolReader::Checkpoint checkpoint;
  size_t saved_i = 0;
  // Records a resumption point: ANS state, bits consumed relative to
  // used_bits_base_, and the current symbol index.
  auto save = [&]() {
    ans_reader_.Save(&checkpoint);
    bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
    saved_i = i_;
  };
  save();
  // If reads went out of bounds, rolls the ANS state and i_ back to the last
  // saved point and returns the failing status; otherwise returns true.
  auto check_and_restore = [&]() {
    Status status = CheckEOI(reader);
    if (!status) {
      // not enough bytes.
      ans_reader_.Restore(checkpoint);
      i_ = saved_i;
      return status;
    }
    return Status(true);
  };
  for (; i_ < enc_size_; i_++) {
    if (i_ % ANSSymbolReader::kMaxCheckpointInterval == 0 && i_ > 0) {
      JXL_RETURN_IF_ERROR(check_and_restore());
      save();
      // Every 0x10000 symbols, treat more than 256 decoded symbols per
      // consumed input byte as a corrupted stream.
      if ((i_ > 0) && (((i_ & 0xFFFF) == 0))) {
        float used_bytes =
            (reader->TotalBitsConsumed() - used_bits_base_) / 8.0f;
        if (i_ > used_bytes * 256) return JXL_FAILURE("Corrupted stream");
      }
      // Grow the buffer in 0x400-byte steps, capped at enc_size_. Indexing
      // stays in bounds until the next resize only if kMaxCheckpointInterval
      // <= 0x400 (assumed here; TODO confirm).
      decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
    }
    // Init decoded the first min(2, enc_size_) symbols, so i_ >= 2 here.
    JXL_DASSERT(i_ >= 2);
    decompressed_[i_] = ans_reader_.ReadHybridUint(
        ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), reader,
        context_map_);
  }
  JXL_RETURN_IF_ERROR(check_and_restore());
  bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
  if (!ans_reader_.CheckANSFinalState()) {
    return JXL_FAILURE("Corrupted ICC profile");
  }

  icc->clear();
  return UnpredictICC(decompressed_.data(), decompressed_.size(), icc);
}
389 
CheckEOI(BitReader * reader)390 Status ICCReader::CheckEOI(BitReader* reader) {
391   if (reader->AllReadsWithinBounds()) return true;
392   return JXL_STATUS(StatusCode::kNotEnoughBytes,
393                     "Not enough bytes for reading ICC profile");
394 }
395 
ReadICC(BitReader * JXL_RESTRICT reader,PaddedBytes * JXL_RESTRICT icc,size_t output_limit)396 Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc,
397                size_t output_limit) {
398   ICCReader icc_reader;
399   JXL_RETURN_IF_ERROR(icc_reader.Init(reader, output_limit));
400   JXL_RETURN_IF_ERROR(icc_reader.Process(reader, icc));
401   return true;
402 }
403 
404 }  // namespace jxl
405