1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include <stdint.h>
7 #include <stdlib.h>
8 
9 #include <cinttypes>
10 #include <limits>
11 #include <numeric>
12 #include <queue>
13 #include <set>
14 #include <unordered_map>
15 #include <unordered_set>
16 
17 #include "lib/jxl/base/printf_macros.h"
18 #include "lib/jxl/base/status.h"
19 #include "lib/jxl/common.h"
20 #include "lib/jxl/dec_ans.h"
21 #include "lib/jxl/dec_bit_reader.h"
22 #include "lib/jxl/enc_ans.h"
23 #include "lib/jxl/enc_bit_writer.h"
24 #include "lib/jxl/entropy_coder.h"
25 #include "lib/jxl/fields.h"
26 #include "lib/jxl/image_ops.h"
27 #include "lib/jxl/modular/encoding/context_predict.h"
28 #include "lib/jxl/modular/encoding/enc_debug_tree.h"
29 #include "lib/jxl/modular/encoding/enc_ma.h"
30 #include "lib/jxl/modular/encoding/encoding.h"
31 #include "lib/jxl/modular/encoding/ma_common.h"
32 #include "lib/jxl/modular/options.h"
33 #include "lib/jxl/modular/transform/transform.h"
34 #include "lib/jxl/toc.h"
35 
36 namespace jxl {
37 
38 namespace {
// Compile-time debugging switches; both are disabled here.
//
// kWantDebug: when true, EncodeModularChannelMAANS allocates a colour-coded
// "predictor usage" image for the channel and dumps it through `aux_out`. It
// is also a precondition for the tree printing controlled by kPrintTree.
constexpr bool kWantDebug = false;
// kPrintTree: when true (and kWantDebug is true as well), ModularEncode calls
// PrintTree() on the tree it has just learned.
constexpr bool kPrintTree = false;
42 
PredictorColor(Predictor p)43 inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
44   switch (p) {
45     case Predictor::Zero:
46       return {{0, 0, 0}};
47     case Predictor::Left:
48       return {{255, 0, 0}};
49     case Predictor::Top:
50       return {{0, 255, 0}};
51     case Predictor::Average0:
52       return {{0, 0, 255}};
53     case Predictor::Average4:
54       return {{192, 128, 128}};
55     case Predictor::Select:
56       return {{255, 255, 0}};
57     case Predictor::Gradient:
58       return {{255, 0, 255}};
59     case Predictor::Weighted:
60       return {{0, 255, 255}};
61       // TODO
62     default:
63       return {{255, 255, 255}};
64   };
65 }
66 
67 }  // namespace
68 
GatherTreeData(const Image & image,pixel_type chan,size_t group_id,const weighted::Header & wp_header,const ModularOptions & options,TreeSamples & tree_samples,size_t * total_pixels)69 void GatherTreeData(const Image &image, pixel_type chan, size_t group_id,
70                     const weighted::Header &wp_header,
71                     const ModularOptions &options, TreeSamples &tree_samples,
72                     size_t *total_pixels) {
73   const Channel &channel = image.channel[chan];
74 
75   JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w,
76               channel.h, chan);
77 
78   std::array<pixel_type, kNumStaticProperties> static_props = {
79       {chan, (int)group_id}};
80   Properties properties(kNumNonrefProperties +
81                         kExtraPropsPerChannel * options.max_properties);
82   double pixel_fraction = std::min(1.0f, options.nb_repeats);
83   // a fraction of 0 is used to disable learning entirely.
84   if (pixel_fraction > 0) {
85     pixel_fraction = std::max(pixel_fraction,
86                               std::min(1.0, 1024.0 / (channel.w * channel.h)));
87   }
88   uint64_t threshold =
89       (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction;
90   uint64_t s[2] = {0x94D049BB133111EBull, 0xBF58476D1CE4E5B9ull};
91   // Xorshift128+ adapted from xorshift128+-inl.h
92   auto use_sample = [&]() {
93     auto s1 = s[0];
94     const auto s0 = s[1];
95     const auto bits = s1 + s0;  // b, c
96     s[0] = s0;
97     s1 ^= s1 << 23;
98     s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
99     s[1] = s1;
100     return (bits >> 32) <= threshold;
101   };
102 
103   const intptr_t onerow = channel.plane.PixelsPerRow();
104   Channel references(properties.size() - kNumNonrefProperties, channel.w);
105   weighted::State wp_state(wp_header, channel.w, channel.h);
106   tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64);
107   for (size_t y = 0; y < channel.h; y++) {
108     const pixel_type *JXL_RESTRICT p = channel.Row(y);
109     PrecomputeReferences(channel, y, image, chan, &references);
110     InitPropsRow(&properties, static_props, y);
111     // TODO(veluca): avoid computing WP if we don't use its property or
112     // predictions.
113     for (size_t x = 0; x < channel.w; x++) {
114       pixel_type_w pred[kNumModularPredictors];
115       if (tree_samples.NumPredictors() != 1) {
116         PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references,
117                         &wp_state, pred);
118       } else {
119         pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
120             PredictLearn(&properties, channel.w, p + x, onerow, x, y,
121                          tree_samples.PredictorFromIndex(0), references,
122                          &wp_state)
123                 .guess;
124       }
125       (*total_pixels)++;
126       if (use_sample()) {
127         tree_samples.AddSample(p[x], properties, pred);
128       }
129       wp_state.UpdateErrors(p[x], x, y, channel.w);
130     }
131   }
132 }
133 
LearnTree(TreeSamples && tree_samples,size_t total_pixels,const ModularOptions & options,const std::vector<ModularMultiplierInfo> & multiplier_info={},StaticPropRange static_prop_range={})134 Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels,
135                const ModularOptions &options,
136                const std::vector<ModularMultiplierInfo> &multiplier_info = {},
137                StaticPropRange static_prop_range = {}) {
138   for (size_t i = 0; i < kNumStaticProperties; i++) {
139     if (static_prop_range[i][1] == 0) {
140       static_prop_range[i][1] = std::numeric_limits<uint32_t>::max();
141     }
142   }
143   if (!tree_samples.HasSamples()) {
144     Tree tree;
145     tree.emplace_back();
146     tree.back().predictor = tree_samples.PredictorFromIndex(0);
147     tree.back().property = -1;
148     tree.back().predictor_offset = 0;
149     tree.back().multiplier = 1;
150     return tree;
151   }
152   float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels;
153   float required_cost = pixel_fraction * 0.9 + 0.1;
154   tree_samples.AllSamplesDone();
155   Tree tree;
156   ComputeBestTree(tree_samples,
157                   options.splitting_heuristics_node_threshold * required_cost,
158                   multiplier_info, static_prop_range,
159                   options.fast_decode_multiplier, &tree);
160   return tree;
161 }
162 
EncodeModularChannelMAANS(const Image & image,pixel_type chan,const weighted::Header & wp_header,const Tree & global_tree,Token ** tokenpp,AuxOut * aux_out,size_t group_id,bool skip_encoder_fast_path)163 Status EncodeModularChannelMAANS(const Image &image, pixel_type chan,
164                                  const weighted::Header &wp_header,
165                                  const Tree &global_tree, Token **tokenpp,
166                                  AuxOut *aux_out, size_t group_id,
167                                  bool skip_encoder_fast_path) {
168   const Channel &channel = image.channel[chan];
169   Token *tokenp = *tokenpp;
170   JXL_ASSERT(channel.w != 0 && channel.h != 0);
171 
172   Image3F predictor_img;
173   if (kWantDebug) predictor_img = Image3F(channel.w, channel.h);
174 
175   JXL_DEBUG_V(6,
176               "Encoding %" PRIuS "x%" PRIuS
177               " channel %d, "
178               "(shift=%i,%i)",
179               channel.w, channel.h, chan, channel.hshift, channel.vshift);
180 
181   std::array<pixel_type, kNumStaticProperties> static_props = {
182       {chan, (int)group_id}};
183   bool use_wp, is_wp_only;
184   bool is_gradient_only;
185   size_t num_props;
186   FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp,
187                              &is_wp_only, &is_gradient_only);
188   Properties properties(num_props);
189   MATreeLookup tree_lookup(tree);
190   JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size());
191 
192   // Check if this tree is a WP-only tree with a small enough property value
193   // range.
194   // Initialized to avoid clang-tidy complaining.
195   uint16_t context_lookup[2 * kPropRangeFast] = {};
196   int8_t offsets[2 * kPropRangeFast] = {};
197   if (is_wp_only) {
198     is_wp_only = TreeToLookupTable(tree, context_lookup, offsets);
199   }
200   if (is_gradient_only) {
201     is_gradient_only = TreeToLookupTable(tree, context_lookup, offsets);
202   }
203 
204   if (is_wp_only && !skip_encoder_fast_path) {
205     for (size_t c = 0; c < 3; c++) {
206       FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]),
207                 &predictor_img.Plane(c));
208     }
209     const intptr_t onerow = channel.plane.PixelsPerRow();
210     weighted::State wp_state(wp_header, channel.w, channel.h);
211     Properties properties(1);
212     for (size_t y = 0; y < channel.h; y++) {
213       const pixel_type *JXL_RESTRICT r = channel.Row(y);
214       for (size_t x = 0; x < channel.w; x++) {
215         size_t offset = 0;
216         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
217         pixel_type_w top = (y ? *(r + x - onerow) : left);
218         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
219         pixel_type_w topright =
220             (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
221         pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
222         int32_t guess = wp_state.Predict</*compute_properties=*/true>(
223             x, y, channel.w, top, left, topright, topleft, toptop, &properties,
224             offset);
225         uint32_t pos =
226             kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
227                                       kPropRangeFast - 1);
228         uint32_t ctx_id = context_lookup[pos];
229         int32_t residual = r[x] - guess - offsets[pos];
230         *tokenp++ = Token(ctx_id, PackSigned(residual));
231         wp_state.UpdateErrors(r[x], x, y, channel.w);
232       }
233     }
234   } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient &&
235              tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
236              !skip_encoder_fast_path) {
237     for (size_t c = 0; c < 3; c++) {
238       FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
239                 &predictor_img.Plane(c));
240     }
241     const intptr_t onerow = channel.plane.PixelsPerRow();
242     for (size_t y = 0; y < channel.h; y++) {
243       const pixel_type *JXL_RESTRICT r = channel.Row(y);
244       for (size_t x = 0; x < channel.w; x++) {
245         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
246         pixel_type_w top = (y ? *(r + x - onerow) : left);
247         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
248         int32_t guess = ClampedGradient(top, left, topleft);
249         int32_t residual = r[x] - guess;
250         *tokenp++ = Token(tree[0].childID, PackSigned(residual));
251       }
252     }
253   } else if (is_gradient_only && !skip_encoder_fast_path) {
254     for (size_t c = 0; c < 3; c++) {
255       FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
256                 &predictor_img.Plane(c));
257     }
258     const intptr_t onerow = channel.plane.PixelsPerRow();
259     for (size_t y = 0; y < channel.h; y++) {
260       const pixel_type *JXL_RESTRICT r = channel.Row(y);
261       for (size_t x = 0; x < channel.w; x++) {
262         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
263         pixel_type_w top = (y ? *(r + x - onerow) : left);
264         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
265         int32_t guess = ClampedGradient(top, left, topleft);
266         uint32_t pos =
267             kPropRangeFast +
268             std::min<pixel_type_w>(
269                 std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
270                 kPropRangeFast - 1);
271         uint32_t ctx_id = context_lookup[pos];
272         int32_t residual = r[x] - guess - offsets[pos];
273         *tokenp++ = Token(ctx_id, PackSigned(residual));
274       }
275     }
276   } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero &&
277              tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
278              !skip_encoder_fast_path) {
279     for (size_t c = 0; c < 3; c++) {
280       FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]),
281                 &predictor_img.Plane(c));
282     }
283     for (size_t y = 0; y < channel.h; y++) {
284       const pixel_type *JXL_RESTRICT p = channel.Row(y);
285       for (size_t x = 0; x < channel.w; x++) {
286         *tokenp++ = Token(tree[0].childID, PackSigned(p[x]));
287       }
288     }
289   } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted &&
290              (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 &&
291              tree[0].predictor_offset == 0 && !skip_encoder_fast_path) {
292     // multiplier is a power of 2.
293     for (size_t c = 0; c < 3; c++) {
294       FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]),
295                 &predictor_img.Plane(c));
296     }
297     uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier);
298     const intptr_t onerow = channel.plane.PixelsPerRow();
299     for (size_t y = 0; y < channel.h; y++) {
300       const pixel_type *JXL_RESTRICT r = channel.Row(y);
301       for (size_t x = 0; x < channel.w; x++) {
302         PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x,
303                                                   y, tree[0].predictor);
304         pixel_type_w residual = r[x] - pred.guess;
305         JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual);
306         *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift));
307       }
308     }
309 
310   } else if (!use_wp && !skip_encoder_fast_path) {
311     const intptr_t onerow = channel.plane.PixelsPerRow();
312     Channel references(properties.size() - kNumNonrefProperties, channel.w);
313     for (size_t y = 0; y < channel.h; y++) {
314       const pixel_type *JXL_RESTRICT p = channel.Row(y);
315       PrecomputeReferences(channel, y, image, chan, &references);
316       float *pred_img_row[3];
317       if (kWantDebug) {
318         for (size_t c = 0; c < 3; c++) {
319           pred_img_row[c] = predictor_img.PlaneRow(c, y);
320         }
321       }
322       InitPropsRow(&properties, static_props, y);
323       for (size_t x = 0; x < channel.w; x++) {
324         PredictionResult res =
325             PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
326                             tree_lookup, references);
327         if (kWantDebug) {
328           for (size_t i = 0; i < 3; i++) {
329             pred_img_row[i][x] = PredictorColor(res.predictor)[i];
330           }
331         }
332         pixel_type_w residual = p[x] - res.guess;
333         JXL_ASSERT(residual % res.multiplier == 0);
334         *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
335       }
336     }
337   } else {
338     const intptr_t onerow = channel.plane.PixelsPerRow();
339     Channel references(properties.size() - kNumNonrefProperties, channel.w);
340     weighted::State wp_state(wp_header, channel.w, channel.h);
341     for (size_t y = 0; y < channel.h; y++) {
342       const pixel_type *JXL_RESTRICT p = channel.Row(y);
343       PrecomputeReferences(channel, y, image, chan, &references);
344       float *pred_img_row[3];
345       if (kWantDebug) {
346         for (size_t c = 0; c < 3; c++) {
347           pred_img_row[c] = predictor_img.PlaneRow(c, y);
348         }
349       }
350       InitPropsRow(&properties, static_props, y);
351       for (size_t x = 0; x < channel.w; x++) {
352         PredictionResult res =
353             PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
354                           tree_lookup, references, &wp_state);
355         if (kWantDebug) {
356           for (size_t i = 0; i < 3; i++) {
357             pred_img_row[i][x] = PredictorColor(res.predictor)[i];
358           }
359         }
360         pixel_type_w residual = p[x] - res.guess;
361         JXL_ASSERT(residual % res.multiplier == 0);
362         *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
363         wp_state.UpdateErrors(p[x], x, y, channel.w);
364       }
365     }
366   }
367   if (kWantDebug && WantDebugOutput(aux_out)) {
368     aux_out->DumpImage(
369         ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(),
370         predictor_img);
371   }
372   *tokenpp = tokenp;
373   return true;
374 }
375 
ModularEncode(const Image & image,const ModularOptions & options,BitWriter * writer,AuxOut * aux_out,size_t layer,size_t group_id,TreeSamples * tree_samples,size_t * total_pixels,const Tree * tree,GroupHeader * header,std::vector<Token> * tokens,size_t * width)376 Status ModularEncode(const Image &image, const ModularOptions &options,
377                      BitWriter *writer, AuxOut *aux_out, size_t layer,
378                      size_t group_id, TreeSamples *tree_samples,
379                      size_t *total_pixels, const Tree *tree,
380                      GroupHeader *header, std::vector<Token> *tokens,
381                      size_t *width) {
382   if (image.error) return JXL_FAILURE("Invalid image");
383   size_t nb_channels = image.channel.size();
384   JXL_DEBUG_V(
385       2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
386       nb_channels, image.bitdepth, image.w, image.h);
387 
388   if (nb_channels < 1) {
389     return true;  // is there any use for a zero-channel image?
390   }
391 
392   // encode transforms
393   GroupHeader header_storage;
394   if (header == nullptr) header = &header_storage;
395   Bundle::Init(header);
396   if (options.predictor == Predictor::Weighted) {
397     weighted::PredictorMode(options.wp_mode, &header->wp_header);
398   }
399   header->transforms = image.transform;
400   // This doesn't actually work
401   if (tree != nullptr) {
402     header->use_global_tree = true;
403   }
404   if (tree_samples == nullptr && tree == nullptr) {
405     JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out));
406   }
407 
408   TreeSamples tree_samples_storage;
409   size_t total_pixels_storage = 0;
410   if (!total_pixels) total_pixels = &total_pixels_storage;
411   // If there's no tree, compute one (or gather data to).
412   if (tree == nullptr) {
413     bool gather_data = tree_samples != nullptr;
414     if (tree_samples == nullptr) {
415       JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor(
416           options.predictor, options.wp_tree_mode));
417       JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties(
418           options.splitting_heuristics_properties, options.wp_tree_mode));
419       std::vector<pixel_type> pixel_samples;
420       std::vector<pixel_type> diff_samples;
421       std::vector<uint32_t> group_pixel_count;
422       std::vector<uint32_t> channel_pixel_count;
423       CollectPixelSamples(image, options, 0, group_pixel_count,
424                           channel_pixel_count, pixel_samples, diff_samples);
425       std::vector<ModularMultiplierInfo> dummy_multiplier_info;
426       StaticPropRange range;
427       tree_samples_storage.PreQuantizeProperties(
428           range, dummy_multiplier_info, group_pixel_count, channel_pixel_count,
429           pixel_samples, diff_samples, options.max_property_values);
430     }
431     for (size_t i = 0; i < nb_channels; i++) {
432       if (!image.channel[i].w || !image.channel[i].h) {
433         continue;  // skip empty channels
434       }
435       if (i >= image.nb_meta_channels &&
436           (image.channel[i].w > options.max_chan_size ||
437            image.channel[i].h > options.max_chan_size)) {
438         break;
439       }
440       GatherTreeData(image, i, group_id, header->wp_header, options,
441                      gather_data ? *tree_samples : tree_samples_storage,
442                      total_pixels);
443     }
444     if (gather_data) return true;
445   }
446 
447   JXL_ASSERT((tree == nullptr) == (tokens == nullptr));
448 
449   Tree tree_storage;
450   std::vector<std::vector<Token>> tokens_storage(1);
451   // Compute tree.
452   if (tree == nullptr) {
453     EntropyEncodingData code;
454     std::vector<uint8_t> context_map;
455 
456     std::vector<std::vector<Token>> tree_tokens(1);
457     tree_storage =
458         LearnTree(std::move(tree_samples_storage), *total_pixels, options);
459     tree = &tree_storage;
460     tokens = &tokens_storage[0];
461 
462     Tree decoded_tree;
463     TokenizeTree(*tree, &tree_tokens[0], &decoded_tree);
464     JXL_ASSERT(tree->size() == decoded_tree.size());
465     tree_storage = std::move(decoded_tree);
466 
467     if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) {
468       PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id));
469     }
470     // Write tree
471     BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens,
472                              &code, &context_map, writer, kLayerModularTree,
473                              aux_out);
474     WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree,
475                 aux_out);
476   }
477 
478   size_t image_width = 0;
479   size_t total_tokens = 0;
480   for (size_t i = 0; i < nb_channels; i++) {
481     if (i >= image.nb_meta_channels &&
482         (image.channel[i].w > options.max_chan_size ||
483          image.channel[i].h > options.max_chan_size)) {
484       break;
485     }
486     if (image.channel[i].w > image_width) image_width = image.channel[i].w;
487     total_tokens += image.channel[i].w * image.channel[i].h;
488   }
489   if (options.zero_tokens) {
490     tokens->resize(tokens->size() + total_tokens, {0, 0});
491   } else {
492     // Do one big allocation for all the tokens we'll need,
493     // to avoid reallocs that might require copying.
494     size_t pos = tokens->size();
495     tokens->resize(pos + total_tokens);
496     Token *tokenp = tokens->data() + pos;
497     for (size_t i = 0; i < nb_channels; i++) {
498       if (!image.channel[i].w || !image.channel[i].h) {
499         continue;  // skip empty channels
500       }
501       if (i >= image.nb_meta_channels &&
502           (image.channel[i].w > options.max_chan_size ||
503            image.channel[i].h > options.max_chan_size)) {
504         break;
505       }
506       JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS(
507           image, i, header->wp_header, *tree, &tokenp, aux_out, group_id,
508           options.skip_encoder_fast_path));
509     }
510     // Make sure we actually wrote all tokens
511     JXL_CHECK(tokenp == tokens->data() + tokens->size());
512   }
513 
514   // Write data if not using a global tree/ANS stream.
515   if (!header->use_global_tree) {
516     EntropyEncodingData code;
517     std::vector<uint8_t> context_map;
518     HistogramParams histo_params;
519     histo_params.image_widths.push_back(image_width);
520     BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2,
521                              tokens_storage, &code, &context_map, writer, layer,
522                              aux_out);
523     WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out);
524   } else {
525     *width = image_width;
526   }
527   return true;
528 }
529 
ModularGenericCompress(Image & image,const ModularOptions & opts,BitWriter * writer,AuxOut * aux_out,size_t layer,size_t group_id,TreeSamples * tree_samples,size_t * total_pixels,const Tree * tree,GroupHeader * header,std::vector<Token> * tokens,size_t * width)530 Status ModularGenericCompress(Image &image, const ModularOptions &opts,
531                               BitWriter *writer, AuxOut *aux_out, size_t layer,
532                               size_t group_id, TreeSamples *tree_samples,
533                               size_t *total_pixels, const Tree *tree,
534                               GroupHeader *header, std::vector<Token> *tokens,
535                               size_t *width) {
536   if (image.w == 0 || image.h == 0) return true;
537   ModularOptions options = opts;  // Make a copy to modify it.
538 
539   if (options.predictor == static_cast<Predictor>(-1)) {
540     options.predictor = Predictor::Gradient;
541   }
542 
543   size_t bits = writer ? writer->BitsWritten() : 0;
544   JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer,
545                                     group_id, tree_samples, total_pixels, tree,
546                                     header, tokens, width));
547   bits = writer ? writer->BitsWritten() - bits : 0;
548   if (writer) {
549     JXL_DEBUG_V(4,
550                 "Modular-encoded a %" PRIuS "x%" PRIuS
551                 " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes",
552                 image.w, image.h, image.bitdepth, image.channel.size(),
553                 bits / 8);
554   }
555   (void)bits;
556   return true;
557 }
558 
559 }  // namespace jxl
560