1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
8 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <utility>
#include <vector>

#include "lib/jxl/fields.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/options.h"
15 
16 namespace jxl {
17 
18 namespace weighted {
// Number of sub-predictors mixed by the weighted (self-correcting) predictor.
constexpr static size_t kNumPredictors = 4;
// Predictions are kept with kPredExtraBits fractional bits of fixed-point
// precision; kPredictionRound is the bias added before shifting them away.
constexpr static int64_t kPredExtraBits = 3;
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
// Number of context properties the weighted predictor contributes.
constexpr static size_t kNumProperties = 1;
23 
// Parameters of the weighted predictor: per-sub-predictor mixing weights and
// the coefficients that bias sub-predictors 1-3 by neighbouring prediction
// errors (see State::Predict).
struct Header : public Fields {
  JXL_FIELDS_NAME(WeightedPredictorHeader)
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    // Visits one 5-bit coefficient with default value `val`.
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    // 4-bit mixing weights, one per sub-predictor.
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  bool all_default;  // set by AllDefault() while visiting.
  // Error-feedback coefficients for sub-predictors 1-3 (applied with >> 5).
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
  // Per-sub-predictor weights, consumed by State::ErrorWeight.
  uint32_t w[kNumPredictors] = {};
};
59 
60 struct State {
61   pixel_type_w prediction[kNumPredictors] = {};
62   pixel_type_w pred = 0;  // *before* removing the added bits.
63   std::vector<uint32_t> pred_errors[kNumPredictors];
64   std::vector<int32_t> error;
65   Header header;
66 
67   // Allows to approximate division by a number from 1 to 64.
68   uint32_t divlookup[64];
69 
AddBitsState70   constexpr static pixel_type_w AddBits(pixel_type_w x) {
71     return uint64_t(x) << kPredExtraBits;
72   }
73 
StateState74   State(Header header, size_t xsize, size_t ysize) : header(header) {
75     // Extra margin to avoid out-of-bounds writes.
76     // All have space for two rows of data.
77     for (size_t i = 0; i < 4; i++) {
78       pred_errors[i].resize((xsize + 2) * 2);
79     }
80     error.resize((xsize + 2) * 2);
81     // Initialize division lookup table.
82     for (int i = 0; i < 64; i++) {
83       divlookup[i] = (1 << 24) / (i + 1);
84     }
85   }
86 
87   // Approximates 4+(maxweight<<24)/(x+1), avoiding division
ErrorWeightState88   JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
89     int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
90     if (shift < 0) shift = 0;
91     return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
92   }
93 
94   // Approximates the weighted average of the input values with the given
95   // weights, avoiding division. Weights must sum to at least 16.
96   JXL_INLINE pixel_type_w
WeightedAverageState97   WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
98                   std::array<uint32_t, kNumPredictors> w) const {
99     uint32_t weight_sum = 0;
100     for (size_t i = 0; i < kNumPredictors; i++) {
101       weight_sum += w[i];
102     }
103     JXL_DASSERT(weight_sum > 15);
104     uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
105     weight_sum = 0;
106     for (size_t i = 0; i < kNumPredictors; i++) {
107       w[i] >>= log_weight - 4;
108       weight_sum += w[i];
109     }
110     // for rounding.
111     pixel_type_w sum = (weight_sum >> 1) - 1;
112     for (size_t i = 0; i < kNumPredictors; i++) {
113       sum += p[i] * w[i];
114     }
115     return (sum * divlookup[weight_sum - 1]) >> 24;
116   }
117 
118   template <bool compute_properties>
PredictState119   JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
120                                   pixel_type_w N, pixel_type_w W,
121                                   pixel_type_w NE, pixel_type_w NW,
122                                   pixel_type_w NN, Properties *properties,
123                                   size_t offset) {
124     size_t cur_row = y & 1 ? 0 : (xsize + 2);
125     size_t prev_row = y & 1 ? (xsize + 2) : 0;
126     size_t pos_N = prev_row + x;
127     size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
128     size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
129     std::array<uint32_t, kNumPredictors> weights;
130     for (size_t i = 0; i < kNumPredictors; i++) {
131       // pred_errors[pos_N] also contains the error of pixel W.
132       // pred_errors[pos_NW] also contains the error of pixel WW.
133       weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
134                    pred_errors[i][pos_NW];
135       weights[i] = ErrorWeight(weights[i], header.w[i]);
136     }
137 
138     N = AddBits(N);
139     W = AddBits(W);
140     NE = AddBits(NE);
141     NW = AddBits(NW);
142     NN = AddBits(NN);
143 
144     pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
145     pixel_type_w teN = error[pos_N];
146     pixel_type_w teNW = error[pos_NW];
147     pixel_type_w sumWN = teN + teW;
148     pixel_type_w teNE = error[pos_NE];
149 
150     if (compute_properties) {
151       pixel_type_w p = teW;
152       if (std::abs(teN) > std::abs(p)) p = teN;
153       if (std::abs(teNW) > std::abs(p)) p = teNW;
154       if (std::abs(teNE) > std::abs(p)) p = teNE;
155       (*properties)[offset++] = p;
156     }
157 
158     prediction[0] = W + NE - N;
159     prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
160     prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
161     prediction[3] =
162         N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
163               (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
164              5);
165 
166     pred = WeightedAverage(prediction, weights);
167 
168     // If all three have the same sign, skip clamping.
169     if (((teN ^ teW) | (teN ^ teNW)) > 0) {
170       return (pred + kPredictionRound) >> kPredExtraBits;
171     }
172 
173     // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
174     pixel_type_w mx = std::max(W, std::max(NE, N));
175     pixel_type_w mn = std::min(W, std::min(NE, N));
176     pred = std::max(mn, std::min(mx, pred));
177     return (pred + kPredictionRound) >> kPredExtraBits;
178   }
179 
UpdateErrorsState180   JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
181                                size_t xsize) {
182     size_t cur_row = y & 1 ? 0 : (xsize + 2);
183     size_t prev_row = y & 1 ? (xsize + 2) : 0;
184     val = AddBits(val);
185     error[cur_row + x] = pred - val;
186     for (size_t i = 0; i < kNumPredictors; i++) {
187       pixel_type_w err =
188           (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
189       // For predicting in the next row.
190       pred_errors[i][cur_row + x] = err;
191       // Add the error on this pixel to the error on the NE pixel. This has the
192       // effect of adding the error on this pixel to the E and EE pixels.
193       pred_errors[i][prev_row + x + 1] += err;
194     }
195   }
196 };
197 
198 // Encoder helper function to set the parameters to some presets.
PredictorMode(int i,Header * header)199 inline void PredictorMode(int i, Header *header) {
200   switch (i) {
201     case 0:
202       // ~ lossless16 predictor
203       header->w[0] = 0xd;
204       header->w[1] = 0xc;
205       header->w[2] = 0xc;
206       header->w[3] = 0xc;
207       header->p1C = 16;
208       header->p2C = 10;
209       header->p3Ca = 7;
210       header->p3Cb = 7;
211       header->p3Cc = 7;
212       header->p3Cd = 0;
213       header->p3Ce = 0;
214       break;
215     case 1:
216       // ~ default lossless8 predictor
217       header->w[0] = 0xd;
218       header->w[1] = 0xc;
219       header->w[2] = 0xc;
220       header->w[3] = 0xb;
221       header->p1C = 8;
222       header->p2C = 8;
223       header->p3Ca = 4;
224       header->p3Cb = 0;
225       header->p3Cc = 3;
226       header->p3Cd = 23;
227       header->p3Ce = 2;
228       break;
229     case 2:
230       // ~ west lossless8 predictor
231       header->w[0] = 0xd;
232       header->w[1] = 0xc;
233       header->w[2] = 0xd;
234       header->w[3] = 0xc;
235       header->p1C = 10;
236       header->p2C = 9;
237       header->p3Ca = 7;
238       header->p3Cb = 0;
239       header->p3Cc = 0;
240       header->p3Cd = 16;
241       header->p3Ce = 9;
242       break;
243     case 3:
244       // ~ north lossless8 predictor
245       header->w[0] = 0xd;
246       header->w[1] = 0xd;
247       header->w[2] = 0xc;
248       header->w[3] = 0xc;
249       header->p1C = 16;
250       header->p2C = 8;
251       header->p3Ca = 0;
252       header->p3Cb = 16;
253       header->p3Cc = 0;
254       header->p3Cd = 23;
255       header->p3Ce = 0;
256       break;
257     case 4:
258     default:
259       // something else, because why not
260       header->w[0] = 0xd;
261       header->w[1] = 0xc;
262       header->w[2] = 0xc;
263       header->w[3] = 0xc;
264       header->p1C = 10;
265       header->p2C = 10;
266       header->p3Ca = 5;
267       header->p3Cb = 5;
268       header->p3Cc = 5;
269       header->p3Cd = 12;
270       header->p3Ce = 4;
271       break;
272   }
273 }
274 }  // namespace weighted
275 
276 // Stores a node and its two children at the same time. This significantly
277 // reduces the number of branches needed during decoding.
// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;
    Predictor predictor;  // leaf only.
  };
  uint32_t childID;  // childID is ctx id if leaf.
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];
    int32_t multiplier;  // leaf only.
  };
  union {
    int32_t properties[2];
    int64_t predictor_offset;  // leaf only.
  };
};
// An MA tree flattened into an array; nodes are addressed by index.
using FlatTree = std::vector<FlatDecisionNode>;
297 
298 class MATreeLookup {
299  public:
MATreeLookup(const FlatTree & tree)300   explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
301   struct LookupResult {
302     uint32_t context;
303     Predictor predictor;
304     int64_t offset;
305     int32_t multiplier;
306   };
Lookup(const Properties & properties)307   LookupResult Lookup(const Properties &properties) const {
308     uint32_t pos = 0;
309     while (true) {
310       const FlatDecisionNode &node = nodes_[pos];
311       if (node.property0 < 0) {
312         return {node.childID, node.predictor, node.predictor_offset,
313                 node.multiplier};
314       }
315       bool p0 = properties[node.property0] <= node.splitval0;
316       uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];
317       uint32_t off1 =
318           2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0);
319       pos = node.childID + (p0 ? off1 : off0);
320     }
321   }
322 
323  private:
324   const FlatTree &nodes_;
325 };
326 
// Number of extra properties contributed per reference channel
// (see PrecomputeReferences).
static constexpr size_t kExtraPropsPerChannel = 4;
// Static properties + the 13 per-pixel properties + the weighted-predictor
// property; everything except the per-reference-channel extras.
static constexpr size_t kNumNonrefProperties =
    kNumStaticProperties + 13 + weighted::kNumProperties;

// Index of the weighted-predictor property in the property vector.
constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
// Index of the "local gradient" (left + top - topleft) property.
constexpr size_t kGradientProp = 9;
333 
334 // Clamps gradient to the min/max of n, w (and l, implicitly).
// Clamps gradient to the min/max of n, w (and l, implicitly).
static inline int32_t ClampedGradient(const int32_t n, const int32_t w,
                                      const int32_t l) {
  const int32_t lo = std::min(n, w);
  const int32_t hi = std::max(n, w);
  // n + w - l can overflow int32_t, so do the arithmetic in uint32_t (where
  // wrap-around is well defined) and detect out-of-range results by comparing
  // l against lo/hi directly: since lo + hi == n + w, the true value of
  // n + w - l lies in [lo, hi] exactly when lo <= l <= hi.
  const uint32_t un = static_cast<uint32_t>(n);
  const uint32_t uw = static_cast<uint32_t>(w);
  const uint32_t ul = static_cast<uint32_t>(l);
  const int32_t grad = static_cast<int32_t>(un + uw - ul);
  // Both ternaries are evaluated unconditionally so the compiler can emit
  // branchless conditional moves (cmovl/cmovg on x86).
  const int32_t clamped_above = (l < lo) ? hi : grad;
  return (l > hi) ? lo : clamped_above;
}
353 
Select(pixel_type_w a,pixel_type_w b,pixel_type_w c)354 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
355   pixel_type_w p = a + b - c;
356   pixel_type_w pa = std::abs(p - a);
357   pixel_type_w pb = std::abs(p - b);
358   return pa < pb ? a : b;
359 }
360 
// Fills `references` with kExtraPropsPerChannel properties per usable
// previously-coded channel (scanning j = i-1, i-2, ... while space remains):
// |value|, value, |residual|, residual, where residual is the reference
// channel's value minus its clamped-gradient prediction.
inline void PrecomputeReferences(const Channel &ch, size_t y,
                                 const Image &image, uint32_t i,
                                 Channel *references) {
  ZeroFillImage(&references->plane);
  uint32_t offset = 0;
  size_t num_extra_props = references->w;
  intptr_t onerow = references->plane.PixelsPerRow();
  for (int32_t j = static_cast<int32_t>(i) - 1;
       j >= 0 && offset < num_extra_props; j--) {
    // Only channels with matching dimensions and shifts can act as
    // references.
    if (image.channel[j].w != image.channel[i].w ||
        image.channel[j].h != image.channel[i].h) {
      continue;
    }
    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    // `references` is laid out transposed: one row per pixel x, so rp
    // advances by a full row per iteration.
    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
      pixel_type_w v = rpp[x];
      rp[0] = std::abs(v);
      rp[1] = v;
      // Clamped-gradient prediction of v from its neighbours in the
      // reference channel; borders fall back to 0 / left.
      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
      pixel_type_w vtop = (y ? rpprev[x] : vleft);
      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
      rp[2] = std::abs(v - vpredicted);
      rp[3] = v - vpredicted;
    }

    offset += kExtraPropsPerChannel;
  }
}
394 
// Output of detail::Predict: the predicted value plus, when an MA tree is
// used, the leaf's context id, chosen predictor, and multiplier.
struct PredictionResult {
  int context = 0;
  pixel_type_w guess = 0;  // prediction, with any leaf offset already added.
  Predictor predictor;
  int32_t multiplier;
};
401 
InitPropsRow(Properties * p,const std::array<pixel_type,kNumStaticProperties> & static_props,const int y)402 inline void InitPropsRow(
403     Properties *p,
404     const std::array<pixel_type, kNumStaticProperties> &static_props,
405     const int y) {
406   for (size_t i = 0; i < kNumStaticProperties; i++) {
407     (*p)[i] = static_props[i];
408   }
409   (*p)[2] = y;
410   (*p)[9] = 0;  // local gradient.
411 }
412 
413 namespace detail {
// Bit flags selecting the work done by detail::Predict.
enum PredictorMode {
  kUseTree = 1,                  // look up context/predictor in an MA tree.
  kUseWP = 2,                    // run the weighted predictor.
  kForceComputeProperties = 4,   // fill properties even without a tree.
  kAllPredictions = 8,           // compute every predictor into `predictions`.
};
420 
PredictOne(Predictor p,pixel_type_w left,pixel_type_w top,pixel_type_w toptop,pixel_type_w topleft,pixel_type_w topright,pixel_type_w leftleft,pixel_type_w toprightright,pixel_type_w wp_pred)421 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
422                                    pixel_type_w top, pixel_type_w toptop,
423                                    pixel_type_w topleft, pixel_type_w topright,
424                                    pixel_type_w leftleft,
425                                    pixel_type_w toprightright,
426                                    pixel_type_w wp_pred) {
427   switch (p) {
428     case Predictor::Zero:
429       return pixel_type_w{0};
430     case Predictor::Left:
431       return left;
432     case Predictor::Top:
433       return top;
434     case Predictor::Select:
435       return Select(left, top, topleft);
436     case Predictor::Weighted:
437       return wp_pred;
438     case Predictor::Gradient:
439       return pixel_type_w{ClampedGradient(left, top, topleft)};
440     case Predictor::TopLeft:
441       return topleft;
442     case Predictor::TopRight:
443       return topright;
444     case Predictor::LeftLeft:
445       return leftleft;
446     case Predictor::Average0:
447       return (left + top) / 2;
448     case Predictor::Average1:
449       return (left + topleft) / 2;
450     case Predictor::Average2:
451       return (topleft + top) / 2;
452     case Predictor::Average3:
453       return (top + topright) / 2;
454     case Predictor::Average4:
455       return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
456               1 * toprightright + 3 * topright + 8) /
457              16;
458     default:
459       return pixel_type_w{0};
460   }
461 }
462 
// Core prediction routine, specialized by a bitmask of detail::PredictorMode
// flags. Gathers the neighbourhood of pixel (x, y) (borders fall back to
// already-available neighbours), optionally fills the property vector, runs
// the weighted predictor and/or the MA tree lookup, and returns the
// prediction (plus context/multiplier when a tree is used). `pp` points at
// the current pixel; `onerow` is the row stride.
template <int mode>
inline PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  pixel_type_w left = (x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (y ? pp[-onerow] : left);
  pixel_type_w topleft = (x && y ? pp[-1 - onerow] : left);
  pixel_type_w topright = (x + 1 < w && y ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright = (x + 2 < w && y ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = std::abs(top);
    (*p)[offset++] = std::abs(left);
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient: difference between this pixel's `left` and the
    // previous pixel's local gradient, still stored in slot offset + 1
    // (0 at the start of a row, see InitPropsRow).
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    // Also appends the weighted-predictor property when compute_properties
    // is set.
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (compute_properties) {
    offset += weighted::kNumProperties;
    // Extra properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    // The tree leaf overrides the caller-supplied predictor and supplies the
    // context, offset and multiplier.
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft,
                                  topright, leftleft, toprightright, wp_pred);
    }
  }
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
537 }  // namespace detail
538 
// Predicts pixel (x, y) with a fixed predictor; no MA tree, no weighted
// predictor, no properties computed.
inline PredictionResult PredictNoTreeNoWP(size_t w,
                                          const pixel_type *JXL_RESTRICT pp,
                                          const intptr_t onerow, const int x,
                                          const int y, Predictor predictor) {
  return detail::Predict</*mode=*/0>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
}
547 
// Predicts pixel (x, y) with a fixed predictor while running the weighted
// predictor state; no MA tree, no properties computed.
inline PredictionResult PredictNoTreeWP(size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y, Predictor predictor,
                                        weighted::State *wp_state) {
  return detail::Predict<detail::kUseWP>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
}
557 
// Predicts pixel (x, y) via an MA tree lookup (which selects the predictor
// and context); properties are computed, the weighted predictor is not run.
inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y,
                                        const MATreeLookup &tree_lookup,
                                        const Channel &references) {
  return detail::Predict<detail::kUseTree>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}
568 
// Predicts pixel (x, y) via an MA tree lookup, with the weighted predictor
// running (needed when the tree uses Predictor::Weighted or the WP property).
inline PredictionResult PredictTreeWP(Properties *p, size_t w,
                                      const pixel_type *JXL_RESTRICT pp,
                                      const intptr_t onerow, const int x,
                                      const int y,
                                      const MATreeLookup &tree_lookup,
                                      const Channel &references,
                                      weighted::State *wp_state) {
  return detail::Predict<detail::kUseTree | detail::kUseWP>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      wp_state, /*predictions=*/nullptr);
}
580 
// Encoder-side helper: predicts with a fixed predictor while forcing the
// full property vector to be computed (e.g. for MA tree learning).
inline PredictionResult PredictLearn(Properties *p, size_t w,
                                     const pixel_type *JXL_RESTRICT pp,
                                     const intptr_t onerow, const int x,
                                     const int y, Predictor predictor,
                                     const Channel &references,
                                     weighted::State *wp_state) {
  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
      wp_state, /*predictions=*/nullptr);
}
591 
// Like PredictLearn, but additionally evaluates every modular predictor,
// writing their values into `predictions`.
inline void PredictLearnAll(Properties *p, size_t w,
                            const pixel_type *JXL_RESTRICT pp,
                            const intptr_t onerow, const int x, const int y,
                            const Channel &references,
                            weighted::State *wp_state,
                            pixel_type_w *predictions) {
  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                  detail::kAllPredictions>(
      p, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, &references, wp_state, predictions);
}
603 
// Evaluates every modular predictor into `predictions` without running the
// weighted predictor (so the Weighted entry is 0); no tree, no properties.
inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
                           const intptr_t onerow, const int x, const int y,
                           pixel_type_w *predictions) {
  detail::Predict<detail::kAllPredictions>(
      /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
}
612 }  // namespace jxl
613 
614 #endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
615