1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
8
9 #include <utility>
10 #include <vector>
11
12 #include "lib/jxl/fields.h"
13 #include "lib/jxl/modular/modular_image.h"
14 #include "lib/jxl/modular/options.h"
15
16 namespace jxl {
17
18 namespace weighted {
// Constants of the self-correcting ("weighted") predictor.
constexpr static size_t kNumPredictors = 4;  // number of sub-predictors.
// Predictions are computed in fixed point with this many fractional bits.
constexpr static int64_t kPredExtraBits = 3;
// Rounding term added before shifting the extra bits back out.
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
// Number of context-model properties this predictor contributes.
constexpr static size_t kNumProperties = 1;
23
// Serialized parameters of the self-correcting predictor: the seven
// sub-predictor coefficients (p1C..p3Ce, used as fixed-point factors with a
// >>5 in State::Predict) and the four per-predictor maximum error weights
// (w[], used by State::ErrorWeight). The visit order below defines the
// bitstream layout; do not reorder.
struct Header : public Fields {
  const char *Name() const override { return "WeightedPredictorHeader"; }
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    // Visits one coefficient as a raw 5-bit field with default `val`
    // (round-trips through uint32_t because Visitor::Bits works on unsigned).
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  bool all_default;  // set by AllDefault(); true => every field has its default.
  // Sub-predictor coefficients (fixed point, divided by 32 when applied).
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
  // Maximum error weight per sub-predictor (see State::ErrorWeight).
  uint32_t w[kNumPredictors] = {};
};
59
60 struct State {
61 pixel_type_w prediction[kNumPredictors] = {};
62 pixel_type_w pred = 0; // *before* removing the added bits.
63 std::vector<uint32_t> pred_errors[kNumPredictors];
64 std::vector<int32_t> error;
65 Header header;
66
67 // Allows to approximate division by a number from 1 to 64.
68 uint32_t divlookup[64];
69
AddBitsState70 constexpr static pixel_type_w AddBits(pixel_type_w x) {
71 return uint64_t(x) << kPredExtraBits;
72 }
73
StateState74 State(Header header, size_t xsize, size_t ysize) : header(header) {
75 // Extra margin to avoid out-of-bounds writes.
76 // All have space for two rows of data.
77 for (size_t i = 0; i < 4; i++) {
78 pred_errors[i].resize((xsize + 2) * 2);
79 }
80 error.resize((xsize + 2) * 2);
81 // Initialize division lookup table.
82 for (int i = 0; i < 64; i++) {
83 divlookup[i] = (1 << 24) / (i + 1);
84 }
85 }
86
87 // Approximates 4+(maxweight<<24)/(x+1), avoiding division
ErrorWeightState88 JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
89 int shift = FloorLog2Nonzero(x + 1) - 5;
90 if (shift < 0) shift = 0;
91 return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
92 }
93
94 // Approximates the weighted average of the input values with the given
95 // weights, avoiding division. Weights must sum to at least 16.
96 JXL_INLINE pixel_type_w
WeightedAverageState97 WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
98 std::array<uint32_t, kNumPredictors> w) const {
99 uint32_t weight_sum = 0;
100 for (size_t i = 0; i < kNumPredictors; i++) {
101 weight_sum += w[i];
102 }
103 JXL_DASSERT(weight_sum > 15);
104 uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4.
105 weight_sum = 0;
106 for (size_t i = 0; i < kNumPredictors; i++) {
107 w[i] >>= log_weight - 4;
108 weight_sum += w[i];
109 }
110 // for rounding.
111 pixel_type_w sum = (weight_sum >> 1) - 1;
112 for (size_t i = 0; i < kNumPredictors; i++) {
113 sum += p[i] * w[i];
114 }
115 return (sum * divlookup[weight_sum - 1]) >> 24;
116 }
117
118 template <bool compute_properties>
PredictState119 JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
120 pixel_type_w N, pixel_type_w W,
121 pixel_type_w NE, pixel_type_w NW,
122 pixel_type_w NN, Properties *properties,
123 size_t offset) {
124 size_t cur_row = y & 1 ? 0 : (xsize + 2);
125 size_t prev_row = y & 1 ? (xsize + 2) : 0;
126 size_t pos_N = prev_row + x;
127 size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
128 size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
129 std::array<uint32_t, kNumPredictors> weights;
130 for (size_t i = 0; i < kNumPredictors; i++) {
131 // pred_errors[pos_N] also contains the error of pixel W.
132 // pred_errors[pos_NW] also contains the error of pixel WW.
133 weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
134 pred_errors[i][pos_NW];
135 weights[i] = ErrorWeight(weights[i], header.w[i]);
136 }
137
138 N = AddBits(N);
139 W = AddBits(W);
140 NE = AddBits(NE);
141 NW = AddBits(NW);
142 NN = AddBits(NN);
143
144 pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
145 pixel_type_w teN = error[pos_N];
146 pixel_type_w teNW = error[pos_NW];
147 pixel_type_w sumWN = teN + teW;
148 pixel_type_w teNE = error[pos_NE];
149
150 if (compute_properties) {
151 pixel_type_w p = teW;
152 if (std::abs(teN) > std::abs(p)) p = teN;
153 if (std::abs(teNW) > std::abs(p)) p = teNW;
154 if (std::abs(teNE) > std::abs(p)) p = teNE;
155 (*properties)[offset++] = p;
156 }
157
158 prediction[0] = W + NE - N;
159 prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
160 prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
161 prediction[3] =
162 N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
163 (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
164 5);
165
166 pred = WeightedAverage(prediction, weights);
167
168 // If all three have the same sign, skip clamping.
169 if (((teN ^ teW) | (teN ^ teNW)) > 0) {
170 return (pred + kPredictionRound) >> kPredExtraBits;
171 }
172
173 // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
174 pixel_type_w mx = std::max(W, std::max(NE, N));
175 pixel_type_w mn = std::min(W, std::min(NE, N));
176 pred = std::max(mn, std::min(mx, pred));
177 return (pred + kPredictionRound) >> kPredExtraBits;
178 }
179
UpdateErrorsState180 JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
181 size_t xsize) {
182 size_t cur_row = y & 1 ? 0 : (xsize + 2);
183 size_t prev_row = y & 1 ? (xsize + 2) : 0;
184 val = AddBits(val);
185 error[cur_row + x] = ClampToRange<pixel_type>(pred - val);
186 for (size_t i = 0; i < kNumPredictors; i++) {
187 pixel_type_w err =
188 (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
189 // For predicting in the next row.
190 pred_errors[i][cur_row + x] = err;
191 // Add the error on this pixel to the error on the NE pixel. This has the
192 // effect of adding the error on this pixel to the E and EE pixels.
193 pred_errors[i][prev_row + x + 1] += err;
194 }
195 }
196 };
197
198 // Encoder helper function to set the parameters to some presets.
PredictorMode(int i,Header * header)199 inline void PredictorMode(int i, Header *header) {
200 switch (i) {
201 case 0:
202 // ~ lossless16 predictor
203 header->w[0] = 0xd;
204 header->w[1] = 0xc;
205 header->w[2] = 0xc;
206 header->w[3] = 0xc;
207 header->p1C = 16;
208 header->p2C = 10;
209 header->p3Ca = 7;
210 header->p3Cb = 7;
211 header->p3Cc = 7;
212 header->p3Cd = 0;
213 header->p3Ce = 0;
214 break;
215 case 1:
216 // ~ default lossless8 predictor
217 header->w[0] = 0xd;
218 header->w[1] = 0xc;
219 header->w[2] = 0xc;
220 header->w[3] = 0xb;
221 header->p1C = 8;
222 header->p2C = 8;
223 header->p3Ca = 4;
224 header->p3Cb = 0;
225 header->p3Cc = 3;
226 header->p3Cd = 23;
227 header->p3Ce = 2;
228 break;
229 case 2:
230 // ~ west lossless8 predictor
231 header->w[0] = 0xd;
232 header->w[1] = 0xc;
233 header->w[2] = 0xd;
234 header->w[3] = 0xc;
235 header->p1C = 10;
236 header->p2C = 9;
237 header->p3Ca = 7;
238 header->p3Cb = 0;
239 header->p3Cc = 0;
240 header->p3Cd = 16;
241 header->p3Ce = 9;
242 break;
243 case 3:
244 // ~ north lossless8 predictor
245 header->w[0] = 0xd;
246 header->w[1] = 0xd;
247 header->w[2] = 0xc;
248 header->w[3] = 0xc;
249 header->p1C = 16;
250 header->p2C = 8;
251 header->p3Ca = 0;
252 header->p3Cb = 16;
253 header->p3Cc = 0;
254 header->p3Cd = 23;
255 header->p3Ce = 0;
256 break;
257 case 4:
258 default:
259 // something else, because why not
260 header->w[0] = 0xd;
261 header->w[1] = 0xc;
262 header->w[2] = 0xc;
263 header->w[3] = 0xc;
264 header->p1C = 10;
265 header->p2C = 10;
266 header->p3Ca = 5;
267 header->p3Cb = 5;
268 header->p3Cc = 5;
269 header->p3Cd = 12;
270 header->p3Ce = 4;
271 break;
272 }
273 }
274 } // namespace weighted
275
// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
struct FlatDecisionNode {
  // Property + splitval of the top node.
  int32_t property0;  // -1 if leaf.
  union {
    PropertyVal splitval0;  // inner node: split value for property0.
    Predictor predictor;    // leaf: predictor to use.
  };
  uint32_t childID;  // childID is ctx id if leaf.
  // Property+splitval of the two child nodes.
  union {
    PropertyVal splitvals[2];  // inner node: split values of the children.
    int32_t multiplier;        // leaf: value multiplier.
  };
  union {
    int32_t properties[2];     // inner node: properties of the children.
    int64_t predictor_offset;  // leaf: offset added to the prediction.
  };
};
// A decision tree flattened into an array of two-level nodes (see
// MATreeLookup::Lookup for the traversal).
using FlatTree = std::vector<FlatDecisionNode>;
297
298 class MATreeLookup {
299 public:
MATreeLookup(const FlatTree & tree)300 explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
301 struct LookupResult {
302 uint32_t context;
303 Predictor predictor;
304 int64_t offset;
305 int32_t multiplier;
306 };
Lookup(const Properties & properties)307 LookupResult Lookup(const Properties &properties) const {
308 uint32_t pos = 0;
309 while (true) {
310 const FlatDecisionNode &node = nodes_[pos];
311 if (node.property0 < 0) {
312 return {node.childID, node.predictor, node.predictor_offset,
313 node.multiplier};
314 }
315 bool p0 = properties[node.property0] <= node.splitval0;
316 uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];
317 uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]);
318 pos = node.childID + (p0 ? off1 : off0);
319 }
320 }
321
322 private:
323 const FlatTree &nodes_;
324 };
325
// Extra properties contributed per usable previous channel: |v|, v,
// |v - pred|, v - pred (see PrecomputeReferences).
static constexpr size_t kExtraPropsPerChannel = 4;
// Properties that do not come from reference channels: the static
// properties, 13 local-neighbourhood properties, and the weighted
// predictor's property.
static constexpr size_t kNumNonrefProperties =
    kNumStaticProperties + 13 + weighted::kNumProperties;

// Index of the weighted predictor's property (last non-ref property).
constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
// Index of the "local gradient" property zeroed at row start (InitPropsRow).
constexpr size_t kGradientProp = 9;
332
333 // Clamps gradient to the min/max of n, w (and l, implicitly).
ClampedGradient(const int32_t n,const int32_t w,const int32_t l)334 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
335 const int32_t l) {
336 const int32_t m = std::min(n, w);
337 const int32_t M = std::max(n, w);
338 // The end result of this operation doesn't overflow or underflow if the
339 // result is between m and M, but the intermediate value may overflow, so we
340 // do the intermediate operations in uint32_t and check later if we had an
341 // overflow or underflow condition comparing m, M and l directly.
342 // grad = M + m - l = n + w - l
343 const int32_t grad =
344 static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
345 static_cast<uint32_t>(l));
346 // We use two sets of ternary operators to force the evaluation of them in
347 // any case, allowing the compiler to avoid branches and use cmovl/cmovg in
348 // x86.
349 const int32_t grad_clamp_M = (l < m) ? M : grad;
350 return (l > M) ? m : grad_clamp_M;
351 }
352
Select(pixel_type_w a,pixel_type_w b,pixel_type_w c)353 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
354 pixel_type_w p = a + b - c;
355 pixel_type_w pa = std::abs(p - a);
356 pixel_type_w pb = std::abs(p - b);
357 return pa < pb ? a : b;
358 }
359
// Fills `references` with the extra per-pixel properties for row `y` of
// channel `i`, derived from previously-decoded channels of identical
// dimensions and shifts. Each usable previous channel contributes
// kExtraPropsPerChannel values per pixel; unused slots stay zero.
inline void PrecomputeReferences(const Channel &ch, size_t y,
                                 const Image &image, uint32_t i,
                                 Channel *references) {
  ZeroFillImage(&references->plane);
  uint32_t offset = 0;
  size_t num_extra_props = references->w;
  intptr_t onerow = references->plane.PixelsPerRow();
  // Walk earlier channels, nearest first, until property slots run out.
  for (int32_t j = static_cast<int32_t>(i) - 1;
       j >= 0 && offset < num_extra_props; j--) {
    // Only channels with matching size and shifts are usable as references.
    if (image.channel[j].w != image.channel[i].w ||
        image.channel[j].h != image.channel[i].h) {
      continue;
    }
    if (image.channel[j].hshift != image.channel[i].hshift) continue;
    if (image.channel[j].vshift != image.channel[i].vshift) continue;
    pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    // `rp` advances one references-row per pixel x: pixel x's properties
    // live at row x, columns [offset, offset + 4).
    for (size_t x = 0; x < ch.w; x++, rp += onerow) {
      pixel_type_w v = rpp[x];
      rp[0] = std::abs(v);  // magnitude of the co-located sample.
      rp[1] = v;            // the co-located sample itself.
      // Neighbours in the reference channel, with edge replication.
      pixel_type_w vleft = (x ? rpp[x - 1] : 0);
      pixel_type_w vtop = (y ? rpprev[x] : vleft);
      pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
      pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
      rp[2] = std::abs(v - vpredicted);  // magnitude of gradient residual.
      rp[3] = v - vpredicted;            // signed gradient residual.
    }

    offset += kExtraPropsPerChannel;
  }
}
393
// Result of one prediction (see detail::Predict).
struct PredictionResult {
  int context = 0;         // MA-tree leaf context id (0 when no tree is used).
  pixel_type_w guess = 0;  // predicted value, including any tree leaf offset.
  Predictor predictor;     // predictor that produced the guess.
  int32_t multiplier;      // leaf multiplier (assigned only when a tree is used).
};
400
PropertyName(size_t i)401 inline std::string PropertyName(size_t i) {
402 static_assert(kNumNonrefProperties == 16, "Update this function");
403 switch (i) {
404 case 0:
405 return "c";
406 case 1:
407 return "g";
408 case 2:
409 return "y";
410 case 3:
411 return "x";
412 case 4:
413 return "|N|";
414 case 5:
415 return "|W|";
416 case 6:
417 return "N";
418 case 7:
419 return "W";
420 case 8:
421 return "W-WW-NW+NWW";
422 case 9:
423 return "W+N-NW";
424 case 10:
425 return "W-NW";
426 case 11:
427 return "NW-N";
428 case 12:
429 return "N-NE";
430 case 13:
431 return "N-NN";
432 case 14:
433 return "W-WW";
434 case 15:
435 return "WGH";
436 default:
437 return "ch[" + ToString(15 - (int)i) + "]";
438 }
439 }
440
InitPropsRow(Properties * p,const std::array<pixel_type,kNumStaticProperties> & static_props,const int y)441 inline void InitPropsRow(
442 Properties *p,
443 const std::array<pixel_type, kNumStaticProperties> &static_props,
444 const int y) {
445 for (size_t i = 0; i < kNumStaticProperties; i++) {
446 (*p)[i] = static_props[i];
447 }
448 (*p)[2] = y;
449 (*p)[9] = 0; // local gradient.
450 }
451
452 namespace detail {
// Bit flags selecting what detail::Predict computes (template parameter,
// so unused paths compile out).
enum PredictorMode {
  kUseTree = 1,                 // look up context/predictor in the MA tree.
  kUseWP = 2,                   // run the weighted (self-correcting) predictor.
  kForceComputeProperties = 4,  // fill `properties` even without a tree.
  kAllPredictions = 8,          // evaluate every predictor into `predictions`.
};
459
// Evaluates a single fixed predictor `p` from the already-fetched
// neighbouring pixel values. `wp_pred` is the precomputed output of the
// weighted predictor (0 when it is not running).
JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
                                   pixel_type_w top, pixel_type_w toptop,
                                   pixel_type_w topleft, pixel_type_w topright,
                                   pixel_type_w leftleft,
                                   pixel_type_w toprightright,
                                   pixel_type_w wp_pred) {
  switch (p) {
    case Predictor::Zero:
      return pixel_type_w{0};
    case Predictor::Left:
      return left;
    case Predictor::Top:
      return top;
    case Predictor::Select:
      return Select(left, top, topleft);
    case Predictor::Weighted:
      return wp_pred;
    case Predictor::Gradient:
      return pixel_type_w{ClampedGradient(left, top, topleft)};
    case Predictor::TopLeft:
      return topleft;
    case Predictor::TopRight:
      return topright;
    case Predictor::LeftLeft:
      return leftleft;
    case Predictor::Average0:
      return (left + top) / 2;
    case Predictor::Average1:
      return (left + topleft) / 2;
    case Predictor::Average2:
      return (topleft + top) / 2;
    case Predictor::Average3:
      return (top + topright) / 2;
    case Predictor::Average4:
      // Weighted blend of six neighbours; +8 rounds the /16.
      return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
              1 * toprightright + 3 * topright + 8) /
             16;
    default:
      return pixel_type_w{0};
  }
}
501
// Core prediction routine; `mode` is a combination of PredictorMode flags.
// Fetches the neighbourhood of pixel (x, y), optionally fills the property
// vector, optionally runs the MA tree and/or the weighted predictor, and
// returns the chosen prediction. `pp` points at the pixel, `onerow` is the
// row stride, `w` the channel width.
template <int mode>
inline PredictionResult Predict(
    Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    const MATreeLookup *lookup, const Channel *references,
    weighted::State *wp_state, pixel_type_w *predictions) {
  // We start in position 3 because of 2 static properties + y.
  size_t offset = 3;
  constexpr bool compute_properties =
      mode & kUseTree || mode & kForceComputeProperties;
  // Neighbouring pixels, with edge replication at the borders.
  pixel_type_w left = (x ? pp[-1] : (y ? pp[-onerow] : 0));
  pixel_type_w top = (y ? pp[-onerow] : left);
  pixel_type_w topleft = (x && y ? pp[-1 - onerow] : left);
  pixel_type_w topright = (x + 1 < w && y ? pp[1 - onerow] : top);
  pixel_type_w leftleft = (x > 1 ? pp[-2] : left);
  pixel_type_w toptop = (y > 1 ? pp[-onerow - onerow] : top);
  pixel_type_w toprightright = (x + 2 < w && y ? pp[2 - onerow] : topright);

  if (compute_properties) {
    // location
    (*p)[offset++] = x;
    // neighbors
    (*p)[offset++] = std::abs(top);
    (*p)[offset++] = std::abs(left);
    (*p)[offset++] = top;
    (*p)[offset++] = left;

    // local gradient
    // NOTE: (*p)[offset + 1] (slot 9, kGradientProp) still holds the
    // W+N-NW value written for the previous (west) pixel — zeroed at row
    // start by InitPropsRow — so this is W-WW-NW+NWW (see PropertyName).
    (*p)[offset] = left - (*p)[offset + 1];
    offset++;
    // local gradient
    (*p)[offset++] = left + top - topleft;

    // FFV1 context properties
    (*p)[offset++] = left - topleft;
    (*p)[offset++] = topleft - top;
    (*p)[offset++] = top - topright;
    (*p)[offset++] = top - toptop;
    (*p)[offset++] = left - leftleft;
  }

  pixel_type_w wp_pred = 0;
  if (mode & kUseWP) {
    // Also writes the WP property at `offset` when compute_properties.
    wp_pred = wp_state->Predict<compute_properties>(
        x, y, w, top, left, topright, topleft, toptop, p, offset);
  }
  if (compute_properties) {
    offset += weighted::kNumProperties;
    // Extra properties.
    const pixel_type *JXL_RESTRICT rp = references->Row(x);
    for (size_t i = 0; i < references->w; i++) {
      (*p)[offset++] = rp[i];
    }
  }
  PredictionResult result;
  if (mode & kUseTree) {
    // The tree leaf overrides the caller-supplied predictor and provides
    // the context, offset and multiplier.
    MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    result.context = lr.context;
    result.guess = lr.offset;
    result.multiplier = lr.multiplier;
    predictor = lr.predictor;
  }
  if (mode & kAllPredictions) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft,
                                  topright, leftleft, toprightright, wp_pred);
    }
  }
  result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
                             leftleft, toprightright, wp_pred);
  result.predictor = predictor;

  return result;
}
576 } // namespace detail
577
// Fixed-predictor prediction: no MA tree, no weighted predictor, no
// properties computed.
inline PredictionResult PredictNoTreeNoWP(size_t w,
                                          const pixel_type *JXL_RESTRICT pp,
                                          const intptr_t onerow, const int x,
                                          const int y, Predictor predictor) {
  return detail::Predict</*mode=*/0>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
}
586
// Fixed-predictor prediction with the weighted predictor running (needed
// when `predictor` may be Predictor::Weighted); no properties computed.
inline PredictionResult PredictNoTreeWP(size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y, Predictor predictor,
                                        weighted::State *wp_state) {
  return detail::Predict<detail::kUseWP>(
      /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
      /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
}
596
// MA-tree-driven prediction without the weighted predictor: computes the
// properties, looks up context/predictor in the tree.
inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
                                        const pixel_type *JXL_RESTRICT pp,
                                        const intptr_t onerow, const int x,
                                        const int y,
                                        const MATreeLookup &tree_lookup,
                                        const Channel &references) {
  return detail::Predict<detail::kUseTree>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      /*wp_state=*/nullptr, /*predictions=*/nullptr);
}
607
// MA-tree-driven prediction with the weighted predictor also running.
inline PredictionResult PredictTreeWP(Properties *p, size_t w,
                                      const pixel_type *JXL_RESTRICT pp,
                                      const intptr_t onerow, const int x,
                                      const int y,
                                      const MATreeLookup &tree_lookup,
                                      const Channel &references,
                                      weighted::State *wp_state) {
  return detail::Predict<detail::kUseTree | detail::kUseWP>(
      p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
      wp_state, /*predictions=*/nullptr);
}
619
// Encoder-side: predicts with a fixed predictor while forcing all
// properties to be computed (for MA tree learning).
inline PredictionResult PredictLearn(Properties *p, size_t w,
                                     const pixel_type *JXL_RESTRICT pp,
                                     const intptr_t onerow, const int x,
                                     const int y, Predictor predictor,
                                     const Channel &references,
                                     weighted::State *wp_state) {
  return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
      p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
      wp_state, /*predictions=*/nullptr);
}
630
// Encoder-side: computes all properties and evaluates every modular
// predictor into `predictions` (for MA tree learning).
inline void PredictLearnAll(Properties *p, size_t w,
                            const pixel_type *JXL_RESTRICT pp,
                            const intptr_t onerow, const int x, const int y,
                            const Channel &references,
                            weighted::State *wp_state,
                            pixel_type_w *predictions) {
  detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
                  detail::kAllPredictions>(
      p, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr, &references, wp_state, predictions);
}
642
// Evaluates every modular predictor into `predictions` without running the
// weighted predictor (Predictor::Weighted then yields the 0 placeholder).
inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
                           const intptr_t onerow, const int x, const int y,
                           pixel_type_w *predictions) {
  detail::Predict<detail::kAllPredictions>(
      /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
      /*lookup=*/nullptr,
      /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
}
651 } // namespace jxl
652
653 #endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
654