1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
8
9 #include <utility>
10 #include <vector>
11
12 #include "lib/jxl/fields.h"
13 #include "lib/jxl/modular/modular_image.h"
14 #include "lib/jxl/modular/options.h"
15
16 namespace jxl {
17
18 namespace weighted {
// Number of sub-predictors blended by the weighted predictor.
static constexpr size_t kNumPredictors = 4;
// Extra fractional bits carried by the sub-predictions.
static constexpr int64_t kPredExtraBits = 3;
// Rounding bias applied before dropping kPredExtraBits: half a unit, minus one.
static constexpr int64_t kPredictionRound =
    (int64_t{1} << (kPredExtraBits - 1)) - 1;
// Number of context properties contributed by the weighted predictor.
static constexpr size_t kNumProperties = 1;
23
// Serialized parameters of the weighted predictor:
// - the coefficients used by sub-predictors 1-3 (p1C, p2C, p3Ca..p3Ce);
// - the per-sub-predictor maximum weights `w`, which State passes to
//   ErrorWeight().
struct Header : public Fields {
  JXL_FIELDS_NAME(WeightedPredictorHeader)
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  // Visits every serialized field. The visitor calls run in a fixed order,
  // which presumably determines the serialized layout; do not reorder them.
  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    // Visits one coefficient as a 5-bit field with default `val`. It goes
    // through a uint32_t temporary because Visitor::Bits operates on uint32_t.
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    // Coefficients of sub-predictors 1, 2 and 3 (see State::Predict).
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    // Maximum weights of the four sub-predictors, 4 bits each.
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  // Written through visitor->AllDefault() in VisitFields().
  bool all_default;
  // Sub-predictor coefficients, consumed by State::Predict.
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
  // Maximum weight of each sub-predictor, consumed via State::ErrorWeight().
  uint32_t w[kNumPredictors] = {};
};
59
60 struct State {
61 pixel_type_w prediction[kNumPredictors] = {};
62 pixel_type_w pred = 0; // *before* removing the added bits.
63 std::vector<uint32_t> pred_errors[kNumPredictors];
64 std::vector<int32_t> error;
65 Header header;
66
67 // Allows to approximate division by a number from 1 to 64.
68 uint32_t divlookup[64];
69
AddBitsState70 constexpr static pixel_type_w AddBits(pixel_type_w x) {
71 return uint64_t(x) << kPredExtraBits;
72 }
73
StateState74 State(Header header, size_t xsize, size_t ysize) : header(header) {
75 // Extra margin to avoid out-of-bounds writes.
76 // All have space for two rows of data.
77 for (size_t i = 0; i < 4; i++) {
78 pred_errors[i].resize((xsize + 2) * 2);
79 }
80 error.resize((xsize + 2) * 2);
81 // Initialize division lookup table.
82 for (int i = 0; i < 64; i++) {
83 divlookup[i] = (1 << 24) / (i + 1);
84 }
85 }
86
87 // Approximates 4+(maxweight<<24)/(x+1), avoiding division
ErrorWeightState88 JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
89 int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
90 if (shift < 0) shift = 0;
91 return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
92 }
93
94 // Approximates the weighted average of the input values with the given
95 // weights, avoiding division. Weights must sum to at least 16.
96 JXL_INLINE pixel_type_w
WeightedAverageState97 WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
98 std::array<uint32_t, kNumPredictors> w) const {
99 uint32_t weight_sum = 0;
100 for (size_t i = 0; i < kNumPredictors; i++) {
101 weight_sum += w[i];
102 }
103 JXL_DASSERT(weight_sum > 15);
104 uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4.
105 weight_sum = 0;
106 for (size_t i = 0; i < kNumPredictors; i++) {
107 w[i] >>= log_weight - 4;
108 weight_sum += w[i];
109 }
110 // for rounding.
111 pixel_type_w sum = (weight_sum >> 1) - 1;
112 for (size_t i = 0; i < kNumPredictors; i++) {
113 sum += p[i] * w[i];
114 }
115 return (sum * divlookup[weight_sum - 1]) >> 24;
116 }
117
118 template <bool compute_properties>
PredictState119 JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
120 pixel_type_w N, pixel_type_w W,
121 pixel_type_w NE, pixel_type_w NW,
122 pixel_type_w NN, Properties *properties,
123 size_t offset) {
124 size_t cur_row = y & 1 ? 0 : (xsize + 2);
125 size_t prev_row = y & 1 ? (xsize + 2) : 0;
126 size_t pos_N = prev_row + x;
127 size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
128 size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
129 std::array<uint32_t, kNumPredictors> weights;
130 for (size_t i = 0; i < kNumPredictors; i++) {
131 // pred_errors[pos_N] also contains the error of pixel W.
132 // pred_errors[pos_NW] also contains the error of pixel WW.
133 weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
134 pred_errors[i][pos_NW];
135 weights[i] = ErrorWeight(weights[i], header.w[i]);
136 }
137
138 N = AddBits(N);
139 W = AddBits(W);
140 NE = AddBits(NE);
141 NW = AddBits(NW);
142 NN = AddBits(NN);
143
144 pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
145 pixel_type_w teN = error[pos_N];
146 pixel_type_w teNW = error[pos_NW];
147 pixel_type_w sumWN = teN + teW;
148 pixel_type_w teNE = error[pos_NE];
149
150 if (compute_properties) {
151 pixel_type_w p = teW;
152 if (std::abs(teN) > std::abs(p)) p = teN;
153 if (std::abs(teNW) > std::abs(p)) p = teNW;
154 if (std::abs(teNE) > std::abs(p)) p = teNE;
155 (*properties)[offset++] = p;
156 }
157
158 prediction[0] = W + NE - N;
159 prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
160 prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
161 prediction[3] =
162 N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
163 (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
164 5);
165
166 pred = WeightedAverage(prediction, weights);
167
168 // If all three have the same sign, skip clamping.
169 if (((teN ^ teW) | (teN ^ teNW)) > 0) {
170 return (pred + kPredictionRound) >> kPredExtraBits;
171 }
172
173 // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
174 pixel_type_w mx = std::max(W, std::max(NE, N));
175 pixel_type_w mn = std::min(W, std::min(NE, N));
176 pred = std::max(mn, std::min(mx, pred));
177 return (pred + kPredictionRound) >> kPredExtraBits;
178 }
179
UpdateErrorsState180 JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
181 size_t xsize) {
182 size_t cur_row = y & 1 ? 0 : (xsize + 2);
183 size_t prev_row = y & 1 ? (xsize + 2) : 0;
184 val = AddBits(val);
185 error[cur_row + x] = pred - val;
186 for (size_t i = 0; i < kNumPredictors; i++) {
187 pixel_type_w err =
188 (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
189 // For predicting in the next row.
190 pred_errors[i][cur_row + x] = err;
191 // Add the error on this pixel to the error on the NE pixel. This has the
192 // effect of adding the error on this pixel to the E and EE pixels.
193 pred_errors[i][prev_row + x + 1] += err;
194 }
195 }
196 };
197
198 // Encoder helper function to set the parameters to some presets.
PredictorMode(int i,Header * header)199 inline void PredictorMode(int i, Header *header) {
200 switch (i) {
201 case 0:
202 // ~ lossless16 predictor
203 header->w[0] = 0xd;
204 header->w[1] = 0xc;
205 header->w[2] = 0xc;
206 header->w[3] = 0xc;
207 header->p1C = 16;
208 header->p2C = 10;
209 header->p3Ca = 7;
210 header->p3Cb = 7;
211 header->p3Cc = 7;
212 header->p3Cd = 0;
213 header->p3Ce = 0;
214 break;
215 case 1:
216 // ~ default lossless8 predictor
217 header->w[0] = 0xd;
218 header->w[1] = 0xc;
219 header->w[2] = 0xc;
220 header->w[3] = 0xb;
221 header->p1C = 8;
222 header->p2C = 8;
223 header->p3Ca = 4;
224 header->p3Cb = 0;
225 header->p3Cc = 3;
226 header->p3Cd = 23;
227 header->p3Ce = 2;
228 break;
229 case 2:
230 // ~ west lossless8 predictor
231 header->w[0] = 0xd;
232 header->w[1] = 0xc;
233 header->w[2] = 0xd;
234 header->w[3] = 0xc;
235 header->p1C = 10;
236 header->p2C = 9;
237 header->p3Ca = 7;
238 header->p3Cb = 0;
239 header->p3Cc = 0;
240 header->p3Cd = 16;
241 header->p3Ce = 9;
242 break;
243 case 3:
244 // ~ north lossless8 predictor
245 header->w[0] = 0xd;
246 header->w[1] = 0xd;
247 header->w[2] = 0xc;
248 header->w[3] = 0xc;
249 header->p1C = 16;
250 header->p2C = 8;
251 header->p3Ca = 0;
252 header->p3Cb = 16;
253 header->p3Cc = 0;
254 header->p3Cd = 23;
255 header->p3Ce = 0;
256 break;
257 case 4:
258 default:
259 // something else, because why not
260 header->w[0] = 0xd;
261 header->w[1] = 0xc;
262 header->w[2] = 0xc;
263 header->w[3] = 0xc;
264 header->p1C = 10;
265 header->p2C = 10;
266 header->p3Ca = 5;
267 header->p3Cb = 5;
268 header->p3Cc = 5;
269 header->p3Cd = 12;
270 header->p3Ce = 4;
271 break;
272 }
273 }
274 } // namespace weighted
275
276 // Stores a node and its two children at the same time. This significantly
277 // reduces the number of branches needed during decoding.
// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
// A leaf is marked by property0 < 0. Leaves reuse the union members to store
// their payload, as read by MATreeLookup::Lookup: predictor, context id in
// childID, multiplier and predictor_offset.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;  // Inner node: threshold compared with property0.
    Predictor predictor;    // Leaf: predictor to use.
  };
  uint32_t childID;  // childID is ctx id if leaf.
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];  // Inner node: thresholds of the two children.
    int32_t multiplier;        // Leaf: multiplier reported by the lookup.
  };
  union {
    int32_t properties[2];     // Inner node: properties of the two children.
    int64_t predictor_offset;  // Leaf: offset added to the prediction.
  };
};
// Tree in flat form. Index 0 is the root. For an inner node, the next four
// candidate nodes start at index childID.
using FlatTree = std::vector<FlatDecisionNode>;
297
298 class MATreeLookup {
299 public:
MATreeLookup(const FlatTree & tree)300 explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
301 struct LookupResult {
302 uint32_t context;
303 Predictor predictor;
304 int64_t offset;
305 int32_t multiplier;
306 };
Lookup(const Properties & properties)307 LookupResult Lookup(const Properties &properties) const {
308 uint32_t pos = 0;
309 while (true) {
310 const FlatDecisionNode &node = nodes_[pos];
311 if (node.property0 < 0) {
312 return {node.childID, node.predictor, node.predictor_offset,
313 node.multiplier};
314 }
315 bool p0 = properties[node.property0] <= node.splitval0;
316 uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];
317 uint32_t off1 =
318 2 | (properties[node.properties[1]] <= node.splitvals[1] ? 1 : 0);
319 pos = node.childID + (p0 ? off1 : off0);
320 }
321 }
322
323 private:
324 const FlatTree &nodes_;
325 };
326
// Properties derived from each earlier reference channel (see
// PrecomputeReferences): |v|, v, |v - gradient| and v - gradient.
static constexpr size_t kExtraPropsPerChannel = 4;
// Properties that do not depend on reference channels, in this order:
// - the static ones;
// - the 13 written by InitPropsRow/detail::Predict (y at index 2, then
//   indices 3..14);
// - the weighted-predictor ones, which come last.
static constexpr size_t kNumNonrefProperties =
    kNumStaticProperties + 13 + weighted::kNumProperties;

// Index of the first weighted-predictor property.
constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
// Index of the "left + top - topleft" property written by detail::Predict.
constexpr size_t kGradientProp = 9;
333
334 // Clamps gradient to the min/max of n, w (and l, implicitly).
ClampedGradient(const int32_t n,const int32_t w,const int32_t l)335 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
336 const int32_t l) {
337 const int32_t m = std::min(n, w);
338 const int32_t M = std::max(n, w);
339 // The end result of this operation doesn't overflow or underflow if the
340 // result is between m and M, but the intermediate value may overflow, so we
341 // do the intermediate operations in uint32_t and check later if we had an
342 // overflow or underflow condition comparing m, M and l directly.
343 // grad = M + m - l = n + w - l
344 const int32_t grad =
345 static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
346 static_cast<uint32_t>(l));
347 // We use two sets of ternary operators to force the evaluation of them in
348 // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
349 // x86.
350 const int32_t grad_clamp_M = (l < m) ? M : grad;
351 return (l > M) ? m : grad_clamp_M;
352 }
353
Select(pixel_type_w a,pixel_type_w b,pixel_type_w c)354 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
355 pixel_type_w p = a + b - c;
356 pixel_type_w pa = std::abs(p - a);
357 pixel_type_w pb = std::abs(p - b);
358 return pa < pb ? a : b;
359 }
360
// Fills `references` with properties derived from earlier channels, for row
// `y` of channel `i`.
// - Row x of `references` holds the values for pixel x (rp advances by one
//   row per x).
// - Channels j = i - 1, i - 2, ..., 0 are scanned. Only those with the same
//   w, h, hshift and vshift as channel i contribute, until references->w
//   slots are filled.
// - Each contributing channel adds kExtraPropsPerChannel values: |v|, v,
//   |v - g| and v - g. Here v is channel j's pixel and g is its
//   ClampedGradient prediction.
// - Unused slots stay zero.
// NOTE(review): `ch` is presumably image.channel[i]; only ch.w is read.
inline void PrecomputeReferences(const Channel &ch, size_t y,
                                 const Image &image, uint32_t i,
                                 Channel *references) {
  ZeroFillImage(&references->plane);
  uint32_t offset = 0;
  size_t num_extra_props = references->w;
  intptr_t onerow = references->plane.PixelsPerRow();
  for (int32_t j = static_cast<int32_t>(i) - 1;
       j >= 0 && offset < num_extra_props; j--) {
    if (image.channel[j].w != image.channel[i].w ||
        image.channel[j].h != image.channel[i].h) {
      continue;
    }
    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    // Column `offset` of row 0; each x step moves down one row.
    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
      pixel_type_w v = rpp[x];
      rp[0] = std::abs(v);
      rp[1] = v;
      // NOTE(review): unlike detail::Predict, the left neighbour at x == 0 is
      // 0 rather than the top pixel. This presumably mirrors the format
      // definition; check the spec before changing it.
      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
      pixel_type_w vtop = (y ? rpprev[x] : vleft);
      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
      rp[2] = std::abs(v - vpredicted);
      rp[3] = v - vpredicted;
    }

    offset += kExtraPropsPerChannel;
  }
}
394
395 struct PredictionResult {
396 int context = 0;
397 pixel_type_w guess = 0;
398 Predictor predictor;
399 int32_t multiplier;
400 };
401
InitPropsRow(Properties * p,const std::array<pixel_type,kNumStaticProperties> & static_props,const int y)402 inline void InitPropsRow(
403 Properties *p,
404 const std::array<pixel_type, kNumStaticProperties> &static_props,
405 const int y) {
406 for (size_t i = 0; i < kNumStaticProperties; i++) {
407 (*p)[i] = static_props[i];
408 }
409 (*p)[2] = y;
410 (*p)[9] = 0; // local gradient.
411 }
412
413 namespace detail {
// Bit flags selecting what detail::Predict computes. They are OR-ed together
// into its `mode` template argument.
enum PredictorMode {
  kUseTree = 1 << 0,                 // Look up context/predictor in a tree.
  kUseWP = 1 << 1,                   // Run the weighted predictor.
  kForceComputeProperties = 1 << 2,  // Fill properties even without a tree.
  kAllPredictions = 1 << 3,          // Evaluate every predictor.
};
420
PredictOne(Predictor p,pixel_type_w left,pixel_type_w top,pixel_type_w toptop,pixel_type_w topleft,pixel_type_w topright,pixel_type_w leftleft,pixel_type_w toprightright,pixel_type_w wp_pred)421 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
422 pixel_type_w top, pixel_type_w toptop,
423 pixel_type_w topleft, pixel_type_w topright,
424 pixel_type_w leftleft,
425 pixel_type_w toprightright,
426 pixel_type_w wp_pred) {
427 switch (p) {
428 case Predictor::Zero:
429 return pixel_type_w{0};
430 case Predictor::Left:
431 return left;
432 case Predictor::Top:
433 return top;
434 case Predictor::Select:
435 return Select(left, top, topleft);
436 case Predictor::Weighted:
437 return wp_pred;
438 case Predictor::Gradient:
439 return pixel_type_w{ClampedGradient(left, top, topleft)};
440 case Predictor::TopLeft:
441 return topleft;
442 case Predictor::TopRight:
443 return topright;
444 case Predictor::LeftLeft:
445 return leftleft;
446 case Predictor::Average0:
447 return (left + top) / 2;
448 case Predictor::Average1:
449 return (left + topleft) / 2;
450 case Predictor::Average2:
451 return (topleft + top) / 2;
452 case Predictor::Average3:
453 return (top + topright) / 2;
454 case Predictor::Average4:
455 return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
456 1 * toprightright + 3 * topright + 8) /
457 16;
458 default:
459 return pixel_type_w{0};
460 }
461 }
462
// Core per-pixel prediction routine shared by the Predict* wrappers.
// `mode` is an OR of PredictorMode flags:
//  - kUseTree: compute properties, then take context, predictor, offset and
//    multiplier from `lookup`. The `predictor` argument is then ignored.
//  - kUseWP: run the weighted predictor in `*wp_state`. Otherwise wp_pred is
//    0, so Predictor::Weighted predicts 0.
//  - kForceComputeProperties: fill `*p` even without a tree.
//  - kAllPredictions: write every predictor's output into `predictions`
//    (kNumModularPredictors entries).
// Parameters:
// - `pp` points at the current pixel.
// - `onerow` is the row stride in pixels.
// - `w` is the row width.
// - `references` must be non-null whenever properties are computed.
// Neighbours outside the image fall back to nearby available ones.
template <int mode>
inline PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  // Neighbourhood with edge fallbacks. `left` at x == 0 uses the pixel above
  // (or 0 on the first row); missing top/right neighbours reuse nearer ones.
  pixel_type_w left = (x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (y ? pp[-onerow] : left);
  pixel_type_w topleft = (x && y ? pp[-1 - onerow] : left);
  pixel_type_w topright = (x + 1 < w && y ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright = (x + 2 < w && y ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = std::abs(top);
    (*p)[offset++] = std::abs(left);
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient
    // Here offset == 8. Property 9 still holds the previous pixel's gradient
    // (InitPropsRow zeroes it at the start of a row) and is overwritten next.
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    // Arguments map to N, W, NE, NW, NN. When properties are computed, the
    // weighted predictor writes its property at `offset`.
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (compute_properties) {
    offset += weighted::kNumProperties;
    // Extra properties.
    // Row x of `references` holds this pixel's reference-channel properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    // The leaf offset seeds the guess; the prediction is added below.
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft,
                                  topright, leftleft, toprightright, wp_pred);
    }
  }
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
537 } // namespace detail
538
PredictNoTreeNoWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor)539 inline PredictionResult PredictNoTreeNoWP(size_t w,
540 const pixel_type *JXL_RESTRICT pp,
541 const intptr_t onerow, const int x,
542 const int y, Predictor predictor) {
543 return detail::Predict</*mode=*/0>(
544 /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
545 /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
546 }
547
PredictNoTreeWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor,weighted::State * wp_state)548 inline PredictionResult PredictNoTreeWP(size_t w,
549 const pixel_type *JXL_RESTRICT pp,
550 const intptr_t onerow, const int x,
551 const int y, Predictor predictor,
552 weighted::State *wp_state) {
553 return detail::Predict<detail::kUseWP>(
554 /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
555 /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
556 }
557
PredictTreeNoWP(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const MATreeLookup & tree_lookup,const Channel & references)558 inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
559 const pixel_type *JXL_RESTRICT pp,
560 const intptr_t onerow, const int x,
561 const int y,
562 const MATreeLookup &tree_lookup,
563 const Channel &references) {
564 return detail::Predict<detail::kUseTree>(
565 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
566 /*wp_state=*/nullptr, /*predictions=*/nullptr);
567 }
568
PredictTreeWP(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const MATreeLookup & tree_lookup,const Channel & references,weighted::State * wp_state)569 inline PredictionResult PredictTreeWP(Properties *p, size_t w,
570 const pixel_type *JXL_RESTRICT pp,
571 const intptr_t onerow, const int x,
572 const int y,
573 const MATreeLookup &tree_lookup,
574 const Channel &references,
575 weighted::State *wp_state) {
576 return detail::Predict<detail::kUseTree | detail::kUseWP>(
577 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
578 wp_state, /*predictions=*/nullptr);
579 }
580
PredictLearn(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,Predictor predictor,const Channel & references,weighted::State * wp_state)581 inline PredictionResult PredictLearn(Properties *p, size_t w,
582 const pixel_type *JXL_RESTRICT pp,
583 const intptr_t onerow, const int x,
584 const int y, Predictor predictor,
585 const Channel &references,
586 weighted::State *wp_state) {
587 return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
588 p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
589 wp_state, /*predictions=*/nullptr);
590 }
591
PredictLearnAll(Properties * p,size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,const Channel & references,weighted::State * wp_state,pixel_type_w * predictions)592 inline void PredictLearnAll(Properties *p, size_t w,
593 const pixel_type *JXL_RESTRICT pp,
594 const intptr_t onerow, const int x, const int y,
595 const Channel &references,
596 weighted::State *wp_state,
597 pixel_type_w *predictions) {
598 detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
599 detail::kAllPredictions>(
600 p, w, pp, onerow, x, y, Predictor::Zero,
601 /*lookup=*/nullptr, &references, wp_state, predictions);
602 }
603
PredictAllNoWP(size_t w,const pixel_type * JXL_RESTRICT pp,const intptr_t onerow,const int x,const int y,pixel_type_w * predictions)604 inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
605 const intptr_t onerow, const int x, const int y,
606 pixel_type_w *predictions) {
607 detail::Predict<detail::kAllPredictions>(
608 /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
609 /*lookup=*/nullptr,
610 /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
611 }
612 } // namespace jxl
613
614 #endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
615