1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_icc_codec.h"
7 
8 #include <stdint.h>
9 
10 #include <map>
11 #include <string>
12 #include <vector>
13 
14 #include "lib/jxl/aux_out.h"
15 #include "lib/jxl/aux_out_fwd.h"
16 #include "lib/jxl/base/byte_order.h"
17 #include "lib/jxl/common.h"
18 #include "lib/jxl/enc_ans.h"
19 #include "lib/jxl/fields.h"
20 #include "lib/jxl/icc_codec_common.h"
21 
22 namespace jxl {
23 namespace {
24 
// Writes `value` into output[*output_pos...] as a base-128 varint: 7 payload
// bits per byte, least significant group first, with the high bit set on every
// byte except the last. A 64-bit value takes at most 10 bytes.
// output_size is the total size of `output`; no byte at or beyond it is
// written. Returns false if the encoding does not fit. In that case some
// leading bytes may already have been written and *output_pos advanced past
// them. On success, *output_pos points just past the last written byte.
bool EncodeVarInt(uint64_t value, size_t output_size, size_t* output_pos,
                  uint8_t* output) {
  // While more than 7 bits of data are left,
  // store 7 bits and set the next byte flag
  while (value > 127) {
    // output[output_size] is already out of bounds, hence >= rather than >.
    if (*output_pos >= output_size) return false;
    // |128: Set the next byte flag
    output[(*output_pos)++] = static_cast<uint8_t>((value & 127) | 128);
    // Remove the seven bits we just wrote
    value >>= 7;
  }
  if (*output_pos >= output_size) return false;
  output[(*output_pos)++] = static_cast<uint8_t>(value & 127);
  return true;
}
40 
EncodeVarInt(uint64_t value,PaddedBytes * data)41 void EncodeVarInt(uint64_t value, PaddedBytes* data) {
42   size_t pos = data->size();
43   data->resize(data->size() + 9);
44   JXL_CHECK(EncodeVarInt(value, data->size(), &pos, data->data()));
45   data->resize(pos);
46 }
47 
48 // Unshuffles or de-interleaves bytes, for example with width 2, turns
49 // "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into
50 // first all the high order bytes, then all the low order bytes.
51 // Transposes a matrix of width columns and ceil(size / width) rows. There are
52 // size elements, size may be < width * height, if so the
53 // last elements of the bottom row are missing, the missing spots are
54 // transposed along with the filled spots, and the result has the missing
55 // elements at the bottom of the rightmost column. The input is the input matrix
56 // in scanline order, the output is the result matrix in scanline order, with
57 // missing elements skipped over (this may occur at multiple positions).
Unshuffle(uint8_t * data,size_t size,size_t width)58 void Unshuffle(uint8_t* data, size_t size, size_t width) {
59   size_t height = (size + width - 1) / width;  // amount of rows of input
60   PaddedBytes result(size);
61   // i = input index, j output index
62   size_t s = 0, j = 0;
63   for (size_t i = 0; i < size; i++) {
64     result[j] = data[i];
65     j += height;
66     if (j >= size) j = ++s;
67   }
68 
69   for (size_t i = 0; i < size; i++) {
70     data[i] = result[i];
71   }
72 }
73 
74 // This is performed by the encoder, the encoder must be able to encode any
75 // random byte stream (not just byte streams that are a valid ICC profile), so
76 // an error returned by this function is an implementation error.
PredictAndShuffle(size_t stride,size_t width,int order,size_t num,const uint8_t * data,size_t size,size_t * pos,PaddedBytes * result)77 Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num,
78                          const uint8_t* data, size_t size, size_t* pos,
79                          PaddedBytes* result) {
80   JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size));
81   // Required by the specification, see decoder. stride * 4 must be < *pos.
82   if (!*pos || ((*pos - 1u) >> 2u) < stride) {
83     return JXL_FAILURE("Invalid stride");
84   }
85   if (*pos < stride * 4) return JXL_FAILURE("Too large stride");
86   size_t start = result->size();
87   for (size_t i = 0; i < num; i++) {
88     uint8_t predicted =
89         LinearPredictICCValue(data, *pos, i, stride, width, order);
90     result->push_back(data[*pos + i] - predicted);
91   }
92   *pos += num;
93   if (width > 1) Unshuffle(result->data() + start, num, width);
94   return true;
95 }
96 }  // namespace
97 
98 // Outputs a transformed form of the given icc profile. The result itself is
99 // not particularly smaller than the input data in bytes, but it will be in a
100 // form that is easier to compress (more zeroes, ...) and will compress better
101 // with brotli.
PredictICC(const uint8_t * icc,size_t size,PaddedBytes * result)102 Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) {
103   PaddedBytes commands;
104   PaddedBytes data;
105 
106   EncodeVarInt(size, result);
107 
108   // Header
109   PaddedBytes header = ICCInitialHeaderPrediction();
110   EncodeUint32(0, size, &header);
111   for (size_t i = 0; i < kICCHeaderSize && i < size; i++) {
112     ICCPredictHeader(icc, size, header.data(), i);
113     data.push_back(icc[i] - header[i]);
114   }
115   if (size <= kICCHeaderSize) {
116     EncodeVarInt(0, result);  // 0 commands
117     for (size_t i = 0; i < data.size(); i++) {
118       result->push_back(data[i]);
119     }
120     return true;
121   }
122 
123   std::vector<Tag> tags;
124   std::vector<size_t> tagstarts;
125   std::vector<size_t> tagsizes;
126   std::map<size_t, size_t> tagmap;
127 
128   // Tag list
129   size_t pos = kICCHeaderSize;
130   if (pos + 4 <= size) {
131     uint64_t numtags = DecodeUint32(icc, size, pos);
132     pos += 4;
133     EncodeVarInt(numtags + 1, &commands);
134     uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
135     uint32_t prevtagsize = 0;
136     for (size_t i = 0; i < numtags; i++) {
137       if (pos + 12 > size) break;
138 
139       Tag tag = DecodeKeyword(icc, size, pos + 0);
140       uint32_t tagstart = DecodeUint32(icc, size, pos + 4);
141       uint32_t tagsize = DecodeUint32(icc, size, pos + 8);
142       pos += 12;
143 
144       tags.push_back(tag);
145       tagstarts.push_back(tagstart);
146       tagsizes.push_back(tagsize);
147       tagmap[tagstart] = tags.size() - 1;
148 
149       uint8_t tagcode = kCommandTagUnknown;
150       for (size_t j = 0; j < kNumTagStrings; j++) {
151         if (tag == *kTagStrings[j]) {
152           tagcode = j + kCommandTagStringFirst;
153           break;
154         }
155       }
156 
157       if (tag == kRtrcTag && pos + 24 < size) {
158         bool ok = true;
159         ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag;
160         ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag;
161         if (ok) {
162           for (size_t i = 0; i < 8; i++) {
163             if (icc[pos - 8 + i] != icc[pos + 4 + i]) ok = false;
164             if (icc[pos - 8 + i] != icc[pos + 16 + i]) ok = false;
165           }
166         }
167         if (ok) {
168           tagcode = kCommandTagTRC;
169           pos += 24;
170           i += 2;
171         }
172       }
173 
174       if (tag == kRxyzTag && pos + 24 < size) {
175         bool ok = true;
176         ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag;
177         ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag;
178         uint32_t offsetr = tagstart;
179         uint32_t offsetg = DecodeUint32(icc, size, pos + 4);
180         uint32_t offsetb = DecodeUint32(icc, size, pos + 16);
181         uint32_t sizer = tagsize;
182         uint32_t sizeg = DecodeUint32(icc, size, pos + 8);
183         uint32_t sizeb = DecodeUint32(icc, size, pos + 20);
184         ok &= sizer == 20;
185         ok &= sizeg == 20;
186         ok &= sizeb == 20;
187         ok &= (offsetg == offsetr + 20);
188         ok &= (offsetb == offsetr + 40);
189         if (ok) {
190           tagcode = kCommandTagXYZ;
191           pos += 24;
192           i += 2;
193         }
194       }
195 
196       uint8_t command = tagcode;
197       uint64_t predicted_tagstart = prevtagstart + prevtagsize;
198       if (predicted_tagstart != tagstart) command |= kFlagBitOffset;
199       size_t predicted_tagsize = prevtagsize;
200       if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
201           tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
202           tag == kLumiTag) {
203         predicted_tagsize = 20;
204       }
205       if (predicted_tagsize != tagsize) command |= kFlagBitSize;
206       commands.push_back(command);
207       if (tagcode == 1) {
208         AppendKeyword(tag, &data);
209       }
210       if (command & kFlagBitOffset) EncodeVarInt(tagstart, &commands);
211       if (command & kFlagBitSize) EncodeVarInt(tagsize, &commands);
212 
213       prevtagstart = tagstart;
214       prevtagsize = tagsize;
215     }
216   }
217   // Indicate end of tag list or varint indicating there's none
218   commands.push_back(0);
219 
220   // Main content
221   // The main content in a valid ICC profile contains tagged elements, with the
222   // tag types (4 letter names) given by the tag list above, and the tag list
223   // pointing to the start and indicating the size of each tagged element. It is
224   // allowed for tagged elements to overlap, e.g. the curve for R, G and B could
225   // all point to the same one.
226   Tag tag;
227   size_t tagstart = 0, tagsize = 0, clutstart = 0;
228 
229   size_t last0 = pos;
230   // This loop appends commands to the output, processing some sub-section of a
231   // current tagged element each time. We need to keep track of the tagtype of
232   // the current element, and update it when we encounter the boundary of a
233   // next one.
234   // It is not required that the input data is a valid ICC profile, if the
235   // encoder does not recognize the data it will still be able to output bytes
236   // but will not predict as well.
237   while (pos <= size) {
238     size_t last1 = pos;
239     PaddedBytes commands_add;
240     PaddedBytes data_add;
241 
242     // This means the loop brought the position beyond the tag end.
243     if (pos > tagstart + tagsize) {
244       tag = {0, 0, 0, 0};  // nonsensical value
245     }
246 
247     if (commands_add.empty() && data_add.empty() && tagmap.count(pos) &&
248         pos + 4 <= size) {
249       size_t index = tagmap[pos];
250       tag = DecodeKeyword(icc, size, pos);
251       tagstart = tagstarts[index];
252       tagsize = tagsizes[index];
253 
254       if (tag == kMlucTag && pos + tagsize <= size && tagsize > 8 &&
255           icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
256           icc[pos + 7] == 0) {
257         size_t num = tagsize - 8;
258         commands_add.push_back(kCommandTypeStartFirst + 3);
259         pos += 8;
260         commands_add.push_back(kCommandShuffle2);
261         EncodeVarInt(num, &commands_add);
262         size_t start = data_add.size();
263         for (size_t i = 0; i < num; i++) {
264           data_add.push_back(icc[pos]);
265           pos++;
266         }
267         Unshuffle(data_add.data() + start, num, 2);
268       }
269 
270       if (tag == kCurvTag && pos + tagsize <= size && tagsize > 8 &&
271           icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
272           icc[pos + 7] == 0) {
273         size_t num = tagsize - 8;
274         if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) {
275           commands_add.push_back(kCommandTypeStartFirst + 5);
276           pos += 8;
277           commands_add.push_back(kCommandPredict);
278           int order = 1, width = 2, stride = width;
279           commands_add.push_back((order << 2) | (width - 1));
280           EncodeVarInt(num, &commands_add);
281           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
282                                                 size, &pos, &data_add));
283         }
284       }
285     }
286 
287     if (tag == kMab_Tag || tag == kMba_Tag) {
288       Tag subTag = DecodeKeyword(icc, size, pos);
289       if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) &&
290           DecodeUint32(icc, size, pos + 4) == 0) {
291         uint32_t num = DecodeUint32(icc, size, pos + 8) * 2;
292         if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) {
293           pos += 12;
294           last1 = pos;
295           commands_add.push_back(kCommandPredict);
296           int order = 1, width = 2, stride = width;
297           commands_add.push_back((order << 2) | (width - 1));
298           EncodeVarInt(num, &commands_add);
299           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
300                                                 size, &pos, &data_add));
301         }
302       }
303 
304       if (pos == tagstart + 24 && pos + 4 < size) {
305         // Note that this value can be remembered for next iterations of the
306         // loop, so the "pos == clutstart" if below can trigger during a later
307         // iteration.
308         clutstart = tagstart + DecodeUint32(icc, size, pos);
309       }
310 
311       if (pos == clutstart && clutstart + 16 < size) {
312         size_t numi = icc[tagstart + 8];
313         size_t numo = icc[tagstart + 9];
314         size_t width = icc[clutstart + 16];
315         size_t stride = width * numo;
316         size_t num = width * numo;
317         for (size_t i = 0; i < numi && clutstart + i < size; i++) {
318           num *= icc[clutstart + i];
319         }
320         if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) &&
321             pos + num <= size && pos > stride * 4) {
322           commands_add.push_back(kCommandPredict);
323           int order = 1;
324           uint8_t flags =
325               (order << 2) | (width - 1) | (stride == width ? 0 : 16);
326           commands_add.push_back(flags);
327           if (flags & 16) EncodeVarInt(stride, &commands_add);
328           EncodeVarInt(num, &commands_add);
329           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
330                                                 size, &pos, &data_add));
331         }
332       }
333     }
334 
335     if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag &&
336         pos == tagstart + 8 && pos + tagsize - 8 <= size && pos > 16 &&
337         tagsize > 8) {
338       size_t width = 4, order = 0, stride = width;
339       size_t num = tagsize - 8;
340       uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16);
341       commands_add.push_back(kCommandPredict);
342       commands_add.push_back(flags);
343       if (flags & 16) EncodeVarInt(stride, &commands_add);
344       EncodeVarInt(num, &commands_add);
345       JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
346                                             size, &pos, &data_add));
347     }
348 
349     if (commands_add.empty() && data_add.empty() && pos + 20 <= size) {
350       Tag subTag = DecodeKeyword(icc, size, pos);
351       if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) {
352         commands_add.push_back(kCommandXYZ);
353         pos += 8;
354         for (size_t j = 0; j < 12; j++) data_add.push_back(icc[pos++]);
355       }
356     }
357 
358     if (commands_add.empty() && data_add.empty() && pos + 8 <= size) {
359       if (DecodeUint32(icc, size, pos + 4) == 0) {
360         Tag subTag = DecodeKeyword(icc, size, pos);
361         for (size_t i = 0; i < kNumTypeStrings; i++) {
362           if (subTag == *kTypeStrings[i]) {
363             commands_add.push_back(kCommandTypeStartFirst + i);
364             pos += 8;
365             break;
366           }
367         }
368       }
369     }
370 
371     if (!(commands_add.empty() && data_add.empty()) || pos == size) {
372       if (last0 < last1) {
373         commands.push_back(kCommandInsert);
374         EncodeVarInt(last1 - last0, &commands);
375         while (last0 < last1) {
376           data.push_back(icc[last0++]);
377         }
378       }
379       for (size_t i = 0; i < commands_add.size(); i++) {
380         commands.push_back(commands_add[i]);
381       }
382       for (size_t i = 0; i < data_add.size(); i++) {
383         data.push_back(data_add[i]);
384       }
385       last0 = pos;
386     }
387     if (commands_add.empty() && data_add.empty()) {
388       pos++;
389     }
390   }
391 
392   EncodeVarInt(commands.size(), result);
393   for (size_t i = 0; i < commands.size(); i++) {
394     result->push_back(commands[i]);
395   }
396   for (size_t i = 0; i < data.size(); i++) {
397     result->push_back(data[i]);
398   }
399 
400   return true;
401 }
402 
WriteICC(const PaddedBytes & icc,BitWriter * JXL_RESTRICT writer,size_t layer,AuxOut * JXL_RESTRICT aux_out)403 Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer,
404                 size_t layer, AuxOut* JXL_RESTRICT aux_out) {
405   if (icc.empty()) return JXL_FAILURE("ICC must be non-empty");
406   PaddedBytes enc;
407   JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc));
408   std::vector<std::vector<Token>> tokens(1);
409   BitWriter::Allotment allotment(writer, 128);
410   JXL_RETURN_IF_ERROR(U64Coder::Write(enc.size(), writer));
411   ReclaimAndCharge(writer, &allotment, layer, aux_out);
412 
413   for (size_t i = 0; i < enc.size(); i++) {
414     tokens[0].emplace_back(
415         ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0),
416         enc[i]);
417   }
418   HistogramParams params;
419   params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal
420                                          : HistogramParams::LZ77Method::kLZ77;
421   EntropyEncodingData code;
422   std::vector<uint8_t> context_map;
423   params.force_huffman = true;
424   BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map,
425                            writer, layer, aux_out);
426   WriteTokens(tokens[0], code, context_map, writer, layer, aux_out);
427   return true;
428 }
429 
430 }  // namespace jxl
431