1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
8 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

#include "lib/jxl/fields.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/options.h"
15 
16 namespace jxl {
17 
18 namespace weighted {
// Number of sub-predictors blended by the weighted predictor.
static constexpr size_t kNumPredictors = 4;
// Extra fractional bits carried by the weighted predictor's internal values
// (see State::AddBits).
static constexpr int64_t kPredExtraBits = 3;
// Rounding bias added before dropping the extra bits:
// (1 << kPredExtraBits) / 2 - 1.
static constexpr int64_t kPredictionRound =
    (int64_t{1} << (kPredExtraBits - 1)) - 1;
// Number of MA-tree properties produced by the weighted predictor.
static constexpr size_t kNumProperties = 1;
23 
24 struct Header : public Fields {
NameHeader25   const char *Name() const override { return "WeightedPredictorHeader"; }
26   // TODO(janwas): move to cc file, avoid including fields.h.
HeaderHeader27   Header() { Bundle::Init(this); }
28 
VisitFieldsHeader29   Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
30     if (visitor->AllDefault(*this, &all_default)) {
31       // Overwrite all serialized fields, but not any nonserialized_*.
32       visitor->SetDefault(this);
33       return true;
34     }
35     auto visit_p = [visitor](pixel_type val, pixel_type *p) {
36       uint32_t up = *p;
37       JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
38       *p = up;
39       return Status(true);
40     };
41     JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
42     JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
43     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
44     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
45     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
46     JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
47     JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
48     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
49     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
50     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
51     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
52     return true;
53   }
54 
55   bool all_default;
56   pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
57   uint32_t w[kNumPredictors] = {};
58 };
59 
60 struct State {
61   pixel_type_w prediction[kNumPredictors] = {};
62   pixel_type_w pred = 0;  // *before* removing the added bits.
63   std::vector<uint32_t> pred_errors[kNumPredictors];
64   std::vector<int32_t> error;
65   Header header;
66 
67   // Allows to approximate division by a number from 1 to 64.
68   uint32_t divlookup[64];
69 
AddBitsState70   constexpr static pixel_type_w AddBits(pixel_type_w x) {
71     return uint64_t(x) << kPredExtraBits;
72   }
73 
StateState74   State(Header header, size_t xsize, size_t ysize) : header(header) {
75     // Extra margin to avoid out-of-bounds writes.
76     // All have space for two rows of data.
77     for (size_t i = 0; i < 4; i++) {
78       pred_errors[i].resize((xsize + 2) * 2);
79     }
80     error.resize((xsize + 2) * 2);
81     // Initialize division lookup table.
82     for (int i = 0; i < 64; i++) {
83       divlookup[i] = (1 << 24) / (i + 1);
84     }
85   }
86 
87   // Approximates 4+(maxweight<<24)/(x+1), avoiding division
ErrorWeightState88   JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
89     int shift = FloorLog2Nonzero(x + 1) - 5;
90     if (shift < 0) shift = 0;
91     return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
92   }
93 
94   // Approximates the weighted average of the input values with the given
95   // weights, avoiding division. Weights must sum to at least 16.
96   JXL_INLINE pixel_type_w
WeightedAverageState97   WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
98                   std::array<uint32_t, kNumPredictors> w) const {
99     uint32_t weight_sum = 0;
100     for (size_t i = 0; i < kNumPredictors; i++) {
101       weight_sum += w[i];
102     }
103     JXL_DASSERT(weight_sum > 15);
104     uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
105     weight_sum = 0;
106     for (size_t i = 0; i < kNumPredictors; i++) {
107       w[i] >>= log_weight - 4;
108       weight_sum += w[i];
109     }
110     // for rounding.
111     pixel_type_w sum = (weight_sum >> 1) - 1;
112     for (size_t i = 0; i < kNumPredictors; i++) {
113       sum += p[i] * w[i];
114     }
115     return (sum * divlookup[weight_sum - 1]) >> 24;
116   }
117 
118   template <bool compute_properties>
PredictState119   JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
120                                   pixel_type_w N, pixel_type_w W,
121                                   pixel_type_w NE, pixel_type_w NW,
122                                   pixel_type_w NN, Properties *properties,
123                                   size_t offset) {
124     size_t cur_row = y & 1 ? 0 : (xsize + 2);
125     size_t prev_row = y & 1 ? (xsize + 2) : 0;
126     size_t pos_N = prev_row + x;
127     size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
128     size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
129     std::array<uint32_t, kNumPredictors> weights;
130     for (size_t i = 0; i < kNumPredictors; i++) {
131       // pred_errors[pos_N] also contains the error of pixel W.
132       // pred_errors[pos_NW] also contains the error of pixel WW.
133       weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
134                    pred_errors[i][pos_NW];
135       weights[i] = ErrorWeight(weights[i], header.w[i]);
136     }
137 
138     N = AddBits(N);
139     W = AddBits(W);
140     NE = AddBits(NE);
141     NW = AddBits(NW);
142     NN = AddBits(NN);
143 
144     pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
145     pixel_type_w teN = error[pos_N];
146     pixel_type_w teNW = error[pos_NW];
147     pixel_type_w sumWN = teN + teW;
148     pixel_type_w teNE = error[pos_NE];
149 
150     if (compute_properties) {
151       pixel_type_w p = teW;
152       if (std::abs(teN) > std::abs(p)) p = teN;
153       if (std::abs(teNW) > std::abs(p)) p = teNW;
154       if (std::abs(teNE) > std::abs(p)) p = teNE;
155       (*properties)[offset++] = p;
156     }
157 
158     prediction[0] = W + NE - N;
159     prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
160     prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
161     prediction[3] =
162         N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
163               (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
164              5);
165 
166     pred = WeightedAverage(prediction, weights);
167 
168     // If all three have the same sign, skip clamping.
169     if (((teN ^ teW) | (teN ^ teNW)) > 0) {
170       return (pred + kPredictionRound) >> kPredExtraBits;
171     }
172 
173     // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
174     pixel_type_w mx = std::max(W, std::max(NE, N));
175     pixel_type_w mn = std::min(W, std::min(NE, N));
176     pred = std::max(mn, std::min(mx, pred));
177     return (pred + kPredictionRound) >> kPredExtraBits;
178   }
179 
UpdateErrorsState180   JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
181                                size_t xsize) {
182     size_t cur_row = y & 1 ? 0 : (xsize + 2);
183     size_t prev_row = y & 1 ? (xsize + 2) : 0;
184     val = AddBits(val);
185     error[cur_row + x] = ClampToRange<pixel_type>(pred - val);
186     for (size_t i = 0; i < kNumPredictors; i++) {
187       pixel_type_w err =
188           (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
189       // For predicting in the next row.
190       pred_errors[i][cur_row + x] = err;
191       // Add the error on this pixel to the error on the NE pixel. This has the
192       // effect of adding the error on this pixel to the E and EE pixels.
193       pred_errors[i][prev_row + x + 1] += err;
194     }
195   }
196 };
197 
198 // Encoder helper function to set the parameters to some presets.
PredictorMode(int i,Header * header)199 inline void PredictorMode(int i, Header *header) {
200   switch (i) {
201     case 0:
202       // ~ lossless16 predictor
203       header->w[0] = 0xd;
204       header->w[1] = 0xc;
205       header->w[2] = 0xc;
206       header->w[3] = 0xc;
207       header->p1C = 16;
208       header->p2C = 10;
209       header->p3Ca = 7;
210       header->p3Cb = 7;
211       header->p3Cc = 7;
212       header->p3Cd = 0;
213       header->p3Ce = 0;
214       break;
215     case 1:
216       // ~ default lossless8 predictor
217       header->w[0] = 0xd;
218       header->w[1] = 0xc;
219       header->w[2] = 0xc;
220       header->w[3] = 0xb;
221       header->p1C = 8;
222       header->p2C = 8;
223       header->p3Ca = 4;
224       header->p3Cb = 0;
225       header->p3Cc = 3;
226       header->p3Cd = 23;
227       header->p3Ce = 2;
228       break;
229     case 2:
230       // ~ west lossless8 predictor
231       header->w[0] = 0xd;
232       header->w[1] = 0xc;
233       header->w[2] = 0xd;
234       header->w[3] = 0xc;
235       header->p1C = 10;
236       header->p2C = 9;
237       header->p3Ca = 7;
238       header->p3Cb = 0;
239       header->p3Cc = 0;
240       header->p3Cd = 16;
241       header->p3Ce = 9;
242       break;
243     case 3:
244       // ~ north lossless8 predictor
245       header->w[0] = 0xd;
246       header->w[1] = 0xd;
247       header->w[2] = 0xc;
248       header->w[3] = 0xc;
249       header->p1C = 16;
250       header->p2C = 8;
251       header->p3Ca = 0;
252       header->p3Cb = 16;
253       header->p3Cc = 0;
254       header->p3Cd = 23;
255       header->p3Ce = 0;
256       break;
257     case 4:
258     default:
259       // something else, because why not
260       header->w[0] = 0xd;
261       header->w[1] = 0xc;
262       header->w[2] = 0xc;
263       header->w[3] = 0xc;
264       header->p1C = 10;
265       header->p2C = 10;
266       header->p3Ca = 5;
267       header->p3Cb = 5;
268       header->p3Cc = 5;
269       header->p3Cd = 12;
270       header->p3Ce = 4;
271       break;
272   }
273 }
274 }  // namespace weighted
275 
// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
//
// As read by MATreeLookup::Lookup: an inner node (property0 >= 0) has its four
// grandchildren at nodes[childID + 0..3]. If property0's value is greater than
// splitval0, the offset is +0/+1, chosen by properties[0]/splitvals[0].
// Otherwise it is +2/+3, chosen by properties[1]/splitvals[1]. In both cases
// the odd offset is taken when the value is <= the split value.
// For a leaf (property0 < 0) the unions hold the leaf payload instead.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;  // Inner node: split value for property0.
    Predictor predictor;    // Leaf: predictor returned by Lookup.
  };
  uint32_t childID;  // childID is ctx id if leaf.
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];  // Inner node: split values of the children.
    int32_t multiplier;        // Leaf: multiplier returned by Lookup.
  };
  union {
    int32_t properties[2];     // Inner node: properties of the children.
    int64_t predictor_offset;  // Leaf: offset that Predict adds to the guess.
  };
};
// Flattened MA tree; node 0 is the root.
using FlatTree = std::vector<FlatDecisionNode>;
297 
298 class MATreeLookup {
299  public:
MATreeLookup(const FlatTree & tree)300   explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
301   struct LookupResult {
302     uint32_t context;
303     Predictor predictor;
304     int64_t offset;
305     int32_t multiplier;
306   };
Lookup(const Properties & properties)307   LookupResult Lookup(const Properties &properties) const {
308     uint32_t pos = 0;
309     while (true) {
310       const FlatDecisionNode &node = nodes_[pos];
311       if (node.property0 < 0) {
312         return {node.childID, node.predictor, node.predictor_offset,
313                 node.multiplier};
314       }
315       bool p0 = properties[node.property0] <= node.splitval0;
316       uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];
317       uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]);
318       pos = node.childID + (p0 ? off1 : off0);
319     }
320   }
321 
322  private:
323   const FlatTree &nodes_;
324 };
325 
326 static constexpr size_t kExtraPropsPerChannel = 4;
327 static constexpr size_t kNumNonrefProperties =
328     kNumStaticProperties + 13 + weighted::kNumProperties;
329 
330 constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
331 constexpr size_t kGradientProp = 9;
332 
333 // Clamps gradient to the min/max of n, w (and l, implicitly).
ClampedGradient(const int32_t n,const int32_t w,const int32_t l)334 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
335                                           const int32_t l) {
336   const int32_t m = std::min(n, w);
337   const int32_t M = std::max(n, w);
338   // The end result of this operation doesn't overflow or underflow if the
339   // result is between m and M, but the intermediate value may overflow, so we
340   // do the intermediate operations in uint32_t and check later if we had an
341   // overflow or underflow condition comparing m, M and l directly.
342   // grad = M + m - l = n + w - l
343   const int32_t grad =
344       static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
345                            static_cast<uint32_t>(l));
346   // We use two sets of ternary operators to force the evaluation of them in
347   // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
348   // x86.
349   const int32_t grad_clamp_M = (l < m) ? M : grad;
350   return (l > M) ? m : grad_clamp_M;
351 }
352 
Select(pixel_type_w a,pixel_type_w b,pixel_type_w c)353 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
354   pixel_type_w p = a + b - c;
355   pixel_type_w pa = std::abs(p - a);
356   pixel_type_w pb = std::abs(p - b);
357   return pa < pb ? a : b;
358 }
359 
PrecomputeReferences(const Channel & ch,size_t y,const Image & image,uint32_t i,Channel * references)360 inline void PrecomputeReferences(const Channel &ch, size_t y,
361                                  const Image &image, uint32_t i,
362                                  Channel *references) {
363   ZeroFillImage(&references->plane);
364   uint32_t offset = 0;
365   size_t num_extra_props = references->w;
366   intptr_t onerow = references->plane.PixelsPerRow();
367   for (int32_t j = static_cast<int32_t>(i) - 1;
368        j >= 0 && offset < num_extra_props; j--) {
369     if (image.channel[j].w != image.channel[i].w ||
370         image.channel[j].h != image.channel[i].h) {
371       continue;
372     }
373     if (image.channel[j].hshift != image.channel[i].hshift) continue;
374     if (image.channel[j].vshift != image.channel[i].vshift) continue;
375     pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
376     const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
377     const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
378     for (size_t x = 0; x < ch.w; x++, rp += onerow) {
379       pixel_type_w v = rpp[x];
380       rp[0] = std::abs(v);
381       rp[1] = v;
382       pixel_type_w vleft = (x ? rpp[x - 1] : 0);
383       pixel_type_w vtop = (y ? rpprev[x] : vleft);
384       pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
385       pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
386       rp[2] = std::abs(v - vpredicted);
387       rp[3] = v - vpredicted;
388     }
389 
390     offset += kExtraPropsPerChannel;
391   }
392 }
393 
394 struct PredictionResult {
395   int context = 0;
396   pixel_type_w guess = 0;
397   Predictor predictor;
398   int32_t multiplier;
399 };
400 
PropertyName(size_t i)401 inline std::string PropertyName(size_t i) {
402   static_assert(kNumNonrefProperties == 16, "Update this function");
403   switch (i) {
404     case 0:
405       return "c";
406     case 1:
407       return "g";
408     case 2:
409       return "y";
410     case 3:
411       return "x";
412     case 4:
413       return "|N|";
414     case 5:
415       return "|W|";
416     case 6:
417       return "N";
418     case 7:
419       return "W";
420     case 8:
421       return "W-WW-NW+NWW";
422     case 9:
423       return "W+N-NW";
424     case 10:
425       return "W-NW";
426     case 11:
427       return "NW-N";
428     case 12:
429       return "N-NE";
430     case 13:
431       return "N-NN";
432     case 14:
433       return "W-WW";
434     case 15:
435       return "WGH";
436     default:
437       return "ch[" + ToString(15 - (int)i) + "]";
438   }
439 }
440 
InitPropsRow(Properties * p,const std::array<pixel_type,kNumStaticProperties> & static_props,const int y)441 inline void InitPropsRow(
442     Properties *p,
443     const std::array<pixel_type, kNumStaticProperties> &static_props,
444     const int y) {
445   for (size_t i = 0; i < kNumStaticProperties; i++) {
446     (*p)[i] = static_props[i];
447   }
448   (*p)[2] = y;
449   (*p)[9] = 0;  // local gradient.
450 }
451 
452 namespace detail {
// Bit flags selecting what detail::Predict computes; combine with `|`.
enum PredictorMode {
  // Compute properties and take context/predictor/offset from the MA tree.
  kUseTree = 1 << 0,
  // Run the weighted predictor.
  kUseWP = 1 << 1,
  // Compute properties even when no MA tree is used.
  kForceComputeProperties = 1 << 2,
  // Also evaluate every predictor into the `predictions` array.
  kAllPredictions = 1 << 3,
};
459 
PredictOne(Predictor p,pixel_type_w left,pixel_type_w top,pixel_type_w toptop,pixel_type_w topleft,pixel_type_w topright,pixel_type_w leftleft,pixel_type_w toprightright,pixel_type_w wp_pred)460 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
461                                    pixel_type_w top, pixel_type_w toptop,
462                                    pixel_type_w topleft, pixel_type_w topright,
463                                    pixel_type_w leftleft,
464                                    pixel_type_w toprightright,
465                                    pixel_type_w wp_pred) {
466   switch (p) {
467     case Predictor::Zero:
468       return pixel_type_w{0};
469     case Predictor::Left:
470       return left;
471     case Predictor::Top:
472       return top;
473     case Predictor::Select:
474       return Select(left, top, topleft);
475     case Predictor::Weighted:
476       return wp_pred;
477     case Predictor::Gradient:
478       return pixel_type_w{ClampedGradient(left, top, topleft)};
479     case Predictor::TopLeft:
480       return topleft;
481     case Predictor::TopRight:
482       return topright;
483     case Predictor::LeftLeft:
484       return leftleft;
485     case Predictor::Average0:
486       return (left + top) / 2;
487     case Predictor::Average1:
488       return (left + topleft) / 2;
489     case Predictor::Average2:
490       return (topleft + top) / 2;
491     case Predictor::Average3:
492       return (top + topright) / 2;
493     case Predictor::Average4:
494       return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
495               1 * toprightright + 3 * topright + 8) /
496              16;
497     default:
498       return pixel_type_w{0};
499   }
500 }
501 
502 template <int mode>
Predict(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const size_t x,const size_t y,Predictor predictor,const MATreeLookup * lookup,const Channel * references,weighted::State * wp_state,pixel_type_w * predictions)503 inline PredictionResult Predict(
504     Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
505     const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
506     const MATreeLookup *lookup, const Channel *references,
507     weighted::State *wp_state, pixel_type_w *predictions) {
508   // We start in position 3 because of 2 static properties + y.
509   size_t offset = 3;
510   constexpr bool compute_properties =
511       mode & kUseTree || mode & kForceComputeProperties;
512   pixel_type_w left = (x ? pp[-1] : (y ? pp[-onerow] : 0));
513   pixel_type_w top = (y ? pp[-onerow] : left);
514   pixel_type_w topleft = (x && y ? pp[-1 - onerow] : left);
515   pixel_type_w topright = (x + 1 < w && y ? pp[1 - onerow] : top);
516   pixel_type_w leftleft = (x > 1 ? pp[-2] : left);
517   pixel_type_w toptop = (y > 1 ? pp[-onerow - onerow] : top);
518   pixel_type_w toprightright = (x + 2 < w && y ? pp[2 - onerow] : topright);
519 
520   if (compute_properties) {
521     // location
522     (*p)[offset++] = x;
523     // neighbors
524     (*p)[offset++] = std::abs(top);
525     (*p)[offset++] = std::abs(left);
526     (*p)[offset++] = top;
527     (*p)[offset++] = left;
528 
529     // local gradient
530     (*p)[offset] = left - (*p)[offset + 1];
531     offset++;
532     // local gradient
533     (*p)[offset++] = left + top - topleft;
534 
535     // FFV1 context properties
536     (*p)[offset++] = left - topleft;
537     (*p)[offset++] = topleft - top;
538     (*p)[offset++] = top - topright;
539     (*p)[offset++] = top - toptop;
540     (*p)[offset++] = left - leftleft;
541   }
542 
543   pixel_type_w wp_pred = 0;
544   if (mode & kUseWP) {
545     wp_pred = wp_state->Predict<compute_properties>(
546         x, y, w, top, left, topright, topleft, toptop, p, offset);
547   }
548   if (compute_properties) {
549     offset += weighted::kNumProperties;
550     // Extra properties.
551     const pixel_type *JXL_RESTRICT rp = references->Row(x);
552     for (size_t i = 0; i < references->w; i++) {
553       (*p)[offset++] = rp[i];
554     }
555   }
556   PredictionResult result;
557   if (mode & kUseTree) {
558     MATreeLookup::LookupResult lr = lookup->Lookup(*p);
559     result.context = lr.context;
560     result.guess = lr.offset;
561     result.multiplier = lr.multiplier;
562     predictor = lr.predictor;
563   }
564   if (mode & kAllPredictions) {
565     for (size_t i = 0; i < kNumModularPredictors; i++) {
566       predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft,
567                                   topright, leftleft, toprightright, wp_pred);
568     }
569   }
570   result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
571                              leftleft, toprightright, wp_pred);
572   result.predictor = predictor;
573 
574   return result;
575 }
576 }  // namespace detail
577 
PredictNoTreeNoWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor)578 inline PredictionResult PredictNoTreeNoWP(size_t w,
579                                           const pixel_type *JXL_RESTRICT pp,
580                                           const intptr_t onerow, const int x,
581                                           const int y, Predictor predictor) {
582   return detail::Predict</*mode=*/0>(
583       /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
584       /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
585 }
586 
PredictNoTreeWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor,weighted::State * wp_state)587 inline PredictionResult PredictNoTreeWP(size_t w,
588                                         const pixel_type *JXL_RESTRICT pp,
589                                         const intptr_t onerow, const int x,
590                                         const int y, Predictor predictor,
591                                         weighted::State *wp_state) {
592   return detail::Predict<detail::kUseWP>(
593       /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
594       /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
595 }
596 
PredictTreeNoWP(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const MATreeLookup & tree_lookup,const Channel & references)597 inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
598                                         const pixel_type *JXL_RESTRICT pp,
599                                         const intptr_t onerow, const int x,
600                                         const int y,
601                                         const MATreeLookup &tree_lookup,
602                                         const Channel &references) {
603   return detail::Predict<detail::kUseTree>(
604       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
605       /*wp_state=*/nullptr, /*predictions=*/nullptr);
606 }
607 
PredictTreeWP(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const MATreeLookup & tree_lookup,const Channel & references,weighted::State * wp_state)608 inline PredictionResult PredictTreeWP(Properties *p, size_t w,
609                                       const pixel_type *JXL_RESTRICT pp,
610                                       const intptr_t onerow, const int x,
611                                       const int y,
612                                       const MATreeLookup &tree_lookup,
613                                       const Channel &references,
614                                       weighted::State *wp_state) {
615   return detail::Predict<detail::kUseTree | detail::kUseWP>(
616       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
617       wp_state, /*predictions=*/nullptr);
618 }
619 
PredictLearn(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor,const Channel & references,weighted::State * wp_state)620 inline PredictionResult PredictLearn(Properties *p, size_t w,
621                                      const pixel_type *JXL_RESTRICT pp,
622                                      const intptr_t onerow, const int x,
623                                      const int y, Predictor predictor,
624                                      const Channel &references,
625                                      weighted::State *wp_state) {
626   return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
627       p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
628       wp_state, /*predictions=*/nullptr);
629 }
630 
PredictLearnAll(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const Channel & references,weighted::State * wp_state,pixel_type_w * predictions)631 inline void PredictLearnAll(Properties *p, size_t w,
632                             const pixel_type *JXL_RESTRICT pp,
633                             const intptr_t onerow, const int x, const int y,
634                             const Channel &references,
635                             weighted::State *wp_state,
636                             pixel_type_w *predictions) {
637   detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
638                   detail::kAllPredictions>(
639       p, w, pp, onerow, x, y, Predictor::Zero,
640       /*lookup=*/nullptr, &references, wp_state, predictions);
641 }
642 
PredictAllNoWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,pixel_type_w * predictions)643 inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
644                            const intptr_t onerow, const int x, const int y,
645                            pixel_type_w *predictions) {
646   detail::Predict<detail::kAllPredictions>(
647       /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
648       /*lookup=*/nullptr,
649       /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
650 }
651 }  // namespace jxl
652 
653 #endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
654