1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <array>
17 #include <cassert>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <cstring>
21 #include <memory>
22 #include <vector>
23 
24 #include "src/buffer_pool.h"
25 #include "src/dsp/constants.h"
26 #include "src/motion_vector.h"
27 #include "src/obu_parser.h"
28 #include "src/prediction_mask.h"
29 #include "src/symbol_decoder_context.h"
30 #include "src/tile.h"
31 #include "src/utils/array_2d.h"
32 #include "src/utils/bit_mask_set.h"
33 #include "src/utils/block_parameters_holder.h"
34 #include "src/utils/common.h"
35 #include "src/utils/constants.h"
36 #include "src/utils/entropy_decoder.h"
37 #include "src/utils/logging.h"
38 #include "src/utils/segmentation.h"
39 #include "src/utils/segmentation_map.h"
40 #include "src/utils/types.h"
41 
42 namespace libgav1 {
43 namespace {
44 
// Escape value for the quantizer index delta symbol: passed as |delta_small|
// to Tile::ReadAndClipDelta(), where a decoded symbol equal to it means the
// magnitude is continued in explicitly coded literal bits.
constexpr int kDeltaQSmall = 3;
// Same escape value, used when reading the loop filter deltas.
constexpr int kDeltaLfSmall = 3;

// Maps a neighboring block's luma prediction mode to the context index used to
// select intra_frame_y_mode_cdf in Tile::ReadPredictionModeY().
constexpr uint8_t kIntraYModeContext[kIntraPredictionModesY] = {
    0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0};

// Maps a BlockSize to the index used to select y_mode_cdf in
// Tile::ReadPredictionModeY() when |intra_y_mode| is false.
constexpr uint8_t kSizeGroup[kMaxBlockSizes] = {
    0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3};

// NOTE(review): neither constant below is referenced in the part of the file
// reviewed here; going by the names, the table maps a pair of mode contexts to
// a compound mode context -- confirm at the use site.
constexpr int kCompoundModeNewMvContexts = 5;
constexpr uint8_t kCompoundModeContextMap[3][kCompoundModeNewMvContexts] = {
    {0, 1, 1, 1, 1}, {1, 2, 3, 4, 4}, {4, 4, 5, 6, 7}};
57 
// Sign of a chroma-from-luma alpha value. The first two columns of
// kCflAlphaLookup hold values of this type (see Tile::ReadCflAlpha()).
enum CflSign : uint8_t {
  kCflSignZero = 0,
  kCflSignNegative = 1,
  kCflSignPositive = 2
};
63 
// For each possible value of the combined signs (which is read from the
// bitstream), this array stores the following: sign_u, sign_v, alpha_u_context,
// alpha_v_context. Only non-negative context entries are used: a negative
// context only occurs where the corresponding sign is zero, in which case no
// alpha is read. Entry at index i is computed as follows:
// sign_u = (i + 1) / 3
// sign_v = (i + 1) % 3
// alpha_u_context = i - 2
// alpha_v_context = (sign_v - 1) * 3 + sign_u
constexpr int8_t kCflAlphaLookup[kCflAlphaSignsSymbolCount][4] = {
    {0, 1, -2, 0}, {0, 2, -1, 3}, {1, 0, 0, -2}, {1, 1, 1, 1},
    {1, 2, 2, 4},  {2, 0, 3, -1}, {2, 1, 4, 2},  {2, 2, 5, 5},
};
76 
// Set of the four prediction modes whose names contain "Near".
// NOTE(review): not referenced in the part of the file reviewed here.
constexpr BitMaskSet kPredictionModeHasNearMvMask(kPredictionModeNearMv,
                                                  kPredictionModeNearNearMv,
                                                  kPredictionModeNearNewMv,
                                                  kPredictionModeNewNearMv);

// Set of block sizes from 8x8 up to 32x32 listed below.
// NOTE(review): by its name, these are the sizes for which inter-intra
// prediction is allowed; the use site is not visible here -- confirm.
constexpr BitMaskSet kIsInterIntraModeAllowedMask(kBlock8x8, kBlock8x16,
                                                  kBlock16x8, kBlock16x16,
                                                  kBlock16x32, kBlock32x16,
                                                  kBlock32x32);
86 
IsBackwardReference(ReferenceFrameType type)87 bool IsBackwardReference(ReferenceFrameType type) {
88   return type >= kReferenceFrameBackward && type <= kReferenceFrameAlternate;
89 }
90 
IsSameDirectionReferencePair(ReferenceFrameType type1,ReferenceFrameType type2)91 bool IsSameDirectionReferencePair(ReferenceFrameType type1,
92                                   ReferenceFrameType type2) {
93   return (type1 >= kReferenceFrameBackward) ==
94          (type2 >= kReferenceFrameBackward);
95 }
96 
// This is called neg_deinterleave() in the spec.
// Maps the coded value |diff| back to a segment id, given the predicted id
// |reference| and the number of ids |max|. Small |diff| values alternate
// around |reference| (odd values above it, even values below it); once one
// side is exhausted the remaining values are assigned linearly.
int DecodeSegmentId(int diff, int reference, int max) {
  if (reference == 0) return diff;
  if (reference >= max - 1) return max - diff - 1;
  int interleaved;
  if ((diff & 1) != 0) {
    interleaved = reference + ((diff + 1) >> 1);
  } else {
    interleaved = reference - (diff >> 1);
  }
  const int doubled_reference = reference << 1;
  if (doubled_reference < max) {
    if (diff <= doubled_reference) return interleaved;
    return diff;
  }
  const int upper_span = (max - reference - 1) << 1;
  if (diff <= upper_span) return interleaved;
  return max - (diff + 1);
}
109 
// This is called DrlCtxStack in section 7.10.2.14 of the spec.
// The spec adds a bonus weight, larger than any natural weight, to every
// nearest mv and derives the context by comparing weights against that bonus.
// Here the count of nearest mvs is recorded instead (|nearest_mv_count| in
// PredictionParameters). Because the nearest mvs occupy the front of the mv
// stack, comparing |index| + 1 with |nearest_mv_count| yields the same context:
//   0 if index + 1 <  nearest_mv_count,
//   1 if index + 1 == nearest_mv_count,
//   2 otherwise.
int GetRefMvIndexContext(int nearest_mv_count, int index) {
  const int next_index = index + 1;
  if (next_index > nearest_mv_count) return 2;
  return (next_index == nearest_mv_count) ? 1 : 0;
}
127 
128 // Returns true if both the width and height of the block is less than 64.
IsBlockDimensionLessThan64(BlockSize size)129 bool IsBlockDimensionLessThan64(BlockSize size) {
130   return size <= kBlock32x32 && size != kBlock16x64;
131 }
132 
GetUseCompoundReferenceContext(const Tile::Block & block)133 int GetUseCompoundReferenceContext(const Tile::Block& block) {
134   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
135     if (block.IsTopSingle() && block.IsLeftSingle()) {
136       return static_cast<int>(IsBackwardReference(block.TopReference(0))) ^
137              static_cast<int>(IsBackwardReference(block.LeftReference(0)));
138     }
139     if (block.IsTopSingle()) {
140       return 2 + static_cast<int>(IsBackwardReference(block.TopReference(0)) ||
141                                   block.IsTopIntra());
142     }
143     if (block.IsLeftSingle()) {
144       return 2 + static_cast<int>(IsBackwardReference(block.LeftReference(0)) ||
145                                   block.IsLeftIntra());
146     }
147     return 4;
148   }
149   if (block.top_available[kPlaneY]) {
150     return block.IsTopSingle()
151                ? static_cast<int>(IsBackwardReference(block.TopReference(0)))
152                : 3;
153   }
154   if (block.left_available[kPlaneY]) {
155     return block.IsLeftSingle()
156                ? static_cast<int>(IsBackwardReference(block.LeftReference(0)))
157                : 3;
158   }
159   return 1;
160 }
161 
162 // Calculates count0 by calling block.CountReferences() on the frame types from
163 // type0_start to type0_end, inclusive, and summing the results.
164 // Calculates count1 by calling block.CountReferences() on the frame types from
165 // type1_start to type1_end, inclusive, and summing the results.
166 // Compares count0 with count1 and returns 0, 1 or 2.
167 //
168 // See count_refs and ref_count_ctx in 8.3.2.
GetReferenceContext(const Tile::Block & block,ReferenceFrameType type0_start,ReferenceFrameType type0_end,ReferenceFrameType type1_start,ReferenceFrameType type1_end)169 int GetReferenceContext(const Tile::Block& block,
170                         ReferenceFrameType type0_start,
171                         ReferenceFrameType type0_end,
172                         ReferenceFrameType type1_start,
173                         ReferenceFrameType type1_end) {
174   int count0 = 0;
175   int count1 = 0;
176   for (int type = type0_start; type <= type0_end; ++type) {
177     count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
178   }
179   for (int type = type1_start; type <= type1_end; ++type) {
180     count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
181   }
182   return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
183 }
184 
185 }  // namespace
186 
ReadSegmentId(const Block & block)187 bool Tile::ReadSegmentId(const Block& block) {
188   // These two asserts ensure that current_frame_.segmentation_map() is not
189   // nullptr.
190   assert(frame_header_.segmentation.enabled);
191   assert(frame_header_.segmentation.update_map);
192   const SegmentationMap& map = *current_frame_.segmentation_map();
193   int top_left = -1;
194   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
195     top_left = map.segment_id(block.row4x4 - 1, block.column4x4 - 1);
196   }
197   int top = -1;
198   if (block.top_available[kPlaneY]) {
199     top = map.segment_id(block.row4x4 - 1, block.column4x4);
200   }
201   int left = -1;
202   if (block.left_available[kPlaneY]) {
203     left = map.segment_id(block.row4x4, block.column4x4 - 1);
204   }
205   int pred;
206   if (top == -1) {
207     pred = (left == -1) ? 0 : left;
208   } else if (left == -1) {
209     pred = top;
210   } else {
211     pred = (top_left == top) ? top : left;
212   }
213   BlockParameters& bp = *block.bp;
214   if (bp.skip) {
215     bp.prediction_parameters->segment_id = pred;
216     return true;
217   }
218   int context = 0;
219   if (top_left < 0) {
220     context = 0;
221   } else if (top_left == top && top_left == left) {
222     context = 2;
223   } else if (top_left == top || top_left == left || top == left) {
224     context = 1;
225   }
226   uint16_t* const segment_id_cdf =
227       symbol_decoder_context_.segment_id_cdf[context];
228   const int encoded_segment_id =
229       reader_.ReadSymbol<kMaxSegments>(segment_id_cdf);
230   bp.prediction_parameters->segment_id =
231       DecodeSegmentId(encoded_segment_id, pred,
232                       frame_header_.segmentation.last_active_segment_id + 1);
233   // Check the bitstream conformance requirement in Section 6.10.8 of the spec.
234   if (bp.prediction_parameters->segment_id < 0 ||
235       bp.prediction_parameters->segment_id >
236           frame_header_.segmentation.last_active_segment_id) {
237     LIBGAV1_DLOG(
238         ERROR,
239         "Corrupted segment_ids: encoded %d, last active %d, postprocessed %d",
240         encoded_segment_id, frame_header_.segmentation.last_active_segment_id,
241         bp.prediction_parameters->segment_id);
242     return false;
243   }
244   return true;
245 }
246 
ReadIntraSegmentId(const Block & block)247 bool Tile::ReadIntraSegmentId(const Block& block) {
248   BlockParameters& bp = *block.bp;
249   if (!frame_header_.segmentation.enabled) {
250     bp.prediction_parameters->segment_id = 0;
251     return true;
252   }
253   return ReadSegmentId(block);
254 }
255 
ReadSkip(const Block & block)256 void Tile::ReadSkip(const Block& block) {
257   BlockParameters& bp = *block.bp;
258   if (frame_header_.segmentation.segment_id_pre_skip &&
259       frame_header_.segmentation.FeatureActive(
260           bp.prediction_parameters->segment_id, kSegmentFeatureSkip)) {
261     bp.skip = true;
262     return;
263   }
264   int context = 0;
265   if (block.top_available[kPlaneY] && block.bp_top->skip) {
266     ++context;
267   }
268   if (block.left_available[kPlaneY] && block.bp_left->skip) {
269     ++context;
270   }
271   uint16_t* const skip_cdf = symbol_decoder_context_.skip_cdf[context];
272   bp.skip = reader_.ReadSymbol(skip_cdf);
273 }
274 
ReadSkipMode(const Block & block)275 bool Tile::ReadSkipMode(const Block& block) {
276   BlockParameters& bp = *block.bp;
277   if (!frame_header_.skip_mode_present ||
278       frame_header_.segmentation.FeatureActive(
279           bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
280       frame_header_.segmentation.FeatureActive(
281           bp.prediction_parameters->segment_id,
282           kSegmentFeatureReferenceFrame) ||
283       frame_header_.segmentation.FeatureActive(
284           bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv) ||
285       IsBlockDimension4(block.size)) {
286     return false;
287   }
288   const int context =
289       (block.left_available[kPlaneY]
290            ? static_cast<int>(left_context_.skip_mode[block.left_context_index])
291            : 0) +
292       (block.top_available[kPlaneY]
293            ? static_cast<int>(
294                  block.top_context->skip_mode[block.top_context_index])
295            : 0);
296   return reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
297 }
298 
ReadCdef(const Block & block)299 void Tile::ReadCdef(const Block& block) {
300   BlockParameters& bp = *block.bp;
301   if (bp.skip || frame_header_.coded_lossless ||
302       !sequence_header_.enable_cdef || frame_header_.allow_intrabc ||
303       frame_header_.cdef.bits == 0) {
304     return;
305   }
306   int8_t* const cdef_index =
307       &cdef_index_[DivideBy16(block.row4x4)][DivideBy16(block.column4x4)];
308   int stride = cdef_index_.columns();
309   if (cdef_index[0] == -1) {
310     cdef_index[0] =
311         static_cast<int8_t>(reader_.ReadLiteral(frame_header_.cdef.bits));
312     if (block.size == kBlock128x128) {
313       // This condition is shorthand for block.width4x4 > 16 && block.height4x4
314       // > 16.
315       cdef_index[1] = cdef_index[0];
316       cdef_index[stride] = cdef_index[0];
317       cdef_index[stride + 1] = cdef_index[0];
318     } else if (block.width4x4 > 16) {
319       cdef_index[1] = cdef_index[0];
320     } else if (block.height4x4 > 16) {
321       cdef_index[stride] = cdef_index[0];
322     }
323   }
324 }
325 
ReadAndClipDelta(uint16_t * const cdf,int delta_small,int scale,int min_value,int max_value,int value)326 int Tile::ReadAndClipDelta(uint16_t* const cdf, int delta_small, int scale,
327                            int min_value, int max_value, int value) {
328   int abs = reader_.ReadSymbol<kDeltaSymbolCount>(cdf);
329   if (abs == delta_small) {
330     const int remaining_bit_count =
331         static_cast<int>(reader_.ReadLiteral(3)) + 1;
332     const int abs_remaining_bits =
333         static_cast<int>(reader_.ReadLiteral(remaining_bit_count));
334     abs = abs_remaining_bits + (1 << remaining_bit_count) + 1;
335   }
336   if (abs != 0) {
337     const bool sign = reader_.ReadBit() != 0;
338     const int scaled_abs = abs << scale;
339     const int reduced_delta = sign ? -scaled_abs : scaled_abs;
340     value += reduced_delta;
341     value = Clip3(value, min_value, max_value);
342   }
343   return value;
344 }
345 
ReadQuantizerIndexDelta(const Block & block)346 void Tile::ReadQuantizerIndexDelta(const Block& block) {
347   assert(read_deltas_);
348   BlockParameters& bp = *block.bp;
349   if ((block.size == SuperBlockSize() && bp.skip)) {
350     return;
351   }
352   current_quantizer_index_ =
353       ReadAndClipDelta(symbol_decoder_context_.delta_q_cdf, kDeltaQSmall,
354                        frame_header_.delta_q.scale, kMinLossyQuantizer,
355                        kMaxQuantizer, current_quantizer_index_);
356 }
357 
ReadLoopFilterDelta(const Block & block)358 void Tile::ReadLoopFilterDelta(const Block& block) {
359   assert(read_deltas_);
360   BlockParameters& bp = *block.bp;
361   if (!frame_header_.delta_lf.present ||
362       (block.size == SuperBlockSize() && bp.skip)) {
363     return;
364   }
365   int frame_lf_count = 1;
366   if (frame_header_.delta_lf.multi) {
367     frame_lf_count = kFrameLfCount - (PlaneCount() > 1 ? 0 : 2);
368   }
369   bool recompute_deblock_filter_levels = false;
370   for (int i = 0; i < frame_lf_count; ++i) {
371     uint16_t* const delta_lf_abs_cdf =
372         frame_header_.delta_lf.multi
373             ? symbol_decoder_context_.delta_lf_multi_cdf[i]
374             : symbol_decoder_context_.delta_lf_cdf;
375     const int8_t old_delta_lf = delta_lf_[i];
376     delta_lf_[i] = ReadAndClipDelta(
377         delta_lf_abs_cdf, kDeltaLfSmall, frame_header_.delta_lf.scale,
378         -kMaxLoopFilterValue, kMaxLoopFilterValue, delta_lf_[i]);
379     recompute_deblock_filter_levels =
380         recompute_deblock_filter_levels || (old_delta_lf != delta_lf_[i]);
381   }
382   delta_lf_all_zero_ =
383       (delta_lf_[0] | delta_lf_[1] | delta_lf_[2] | delta_lf_[3]) == 0;
384   if (!delta_lf_all_zero_ && recompute_deblock_filter_levels) {
385     post_filter_.ComputeDeblockFilterLevels(delta_lf_, deblock_filter_levels_);
386   }
387 }
388 
ReadPredictionModeY(const Block & block,bool intra_y_mode)389 void Tile::ReadPredictionModeY(const Block& block, bool intra_y_mode) {
390   uint16_t* cdf;
391   if (intra_y_mode) {
392     const PredictionMode top_mode =
393         block.top_available[kPlaneY] ? block.bp_top->y_mode : kPredictionModeDc;
394     const PredictionMode left_mode = block.left_available[kPlaneY]
395                                          ? block.bp_left->y_mode
396                                          : kPredictionModeDc;
397     const int top_context = kIntraYModeContext[top_mode];
398     const int left_context = kIntraYModeContext[left_mode];
399     cdf = symbol_decoder_context_
400               .intra_frame_y_mode_cdf[top_context][left_context];
401   } else {
402     cdf = symbol_decoder_context_.y_mode_cdf[kSizeGroup[block.size]];
403   }
404   block.bp->y_mode = static_cast<PredictionMode>(
405       reader_.ReadSymbol<kIntraPredictionModesY>(cdf));
406 }
407 
ReadIntraAngleInfo(const Block & block,PlaneType plane_type)408 void Tile::ReadIntraAngleInfo(const Block& block, PlaneType plane_type) {
409   BlockParameters& bp = *block.bp;
410   PredictionParameters& prediction_parameters =
411       *block.bp->prediction_parameters;
412   prediction_parameters.angle_delta[plane_type] = 0;
413   const PredictionMode mode = (plane_type == kPlaneTypeY)
414                                   ? bp.y_mode
415                                   : bp.prediction_parameters->uv_mode;
416   if (IsBlockSmallerThan8x8(block.size) || !IsDirectionalMode(mode)) return;
417   uint16_t* const cdf =
418       symbol_decoder_context_.angle_delta_cdf[mode - kPredictionModeVertical];
419   prediction_parameters.angle_delta[plane_type] =
420       reader_.ReadSymbol<kAngleDeltaSymbolCount>(cdf);
421   prediction_parameters.angle_delta[plane_type] -= kMaxAngleDelta;
422 }
423 
ReadCflAlpha(const Block & block)424 void Tile::ReadCflAlpha(const Block& block) {
425   const int signs = reader_.ReadSymbol<kCflAlphaSignsSymbolCount>(
426       symbol_decoder_context_.cfl_alpha_signs_cdf);
427   const int8_t* const cfl_lookup = kCflAlphaLookup[signs];
428   const auto sign_u = static_cast<CflSign>(cfl_lookup[0]);
429   const auto sign_v = static_cast<CflSign>(cfl_lookup[1]);
430   PredictionParameters& prediction_parameters =
431       *block.bp->prediction_parameters;
432   prediction_parameters.cfl_alpha_u = 0;
433   if (sign_u != kCflSignZero) {
434     assert(cfl_lookup[2] >= 0);
435     prediction_parameters.cfl_alpha_u =
436         reader_.ReadSymbol<kCflAlphaSymbolCount>(
437             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[2]]) +
438         1;
439     if (sign_u == kCflSignNegative) prediction_parameters.cfl_alpha_u *= -1;
440   }
441   prediction_parameters.cfl_alpha_v = 0;
442   if (sign_v != kCflSignZero) {
443     assert(cfl_lookup[3] >= 0);
444     prediction_parameters.cfl_alpha_v =
445         reader_.ReadSymbol<kCflAlphaSymbolCount>(
446             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[3]]) +
447         1;
448     if (sign_v == kCflSignNegative) prediction_parameters.cfl_alpha_v *= -1;
449   }
450 }
451 
ReadPredictionModeUV(const Block & block)452 void Tile::ReadPredictionModeUV(const Block& block) {
453   BlockParameters& bp = *block.bp;
454   bool chroma_from_luma_allowed;
455   if (frame_header_.segmentation
456           .lossless[bp.prediction_parameters->segment_id]) {
457     chroma_from_luma_allowed = block.residual_size[kPlaneU] == kBlock4x4;
458   } else {
459     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
460   }
461   uint16_t* const cdf =
462       symbol_decoder_context_
463           .uv_mode_cdf[static_cast<int>(chroma_from_luma_allowed)][bp.y_mode];
464   if (chroma_from_luma_allowed) {
465     bp.prediction_parameters->uv_mode = static_cast<PredictionMode>(
466         reader_.ReadSymbol<kIntraPredictionModesUV>(cdf));
467   } else {
468     bp.prediction_parameters->uv_mode = static_cast<PredictionMode>(
469         reader_.ReadSymbol<kIntraPredictionModesUV - 1>(cdf));
470   }
471 }
472 
ReadMotionVectorComponent(const Block & block,const int component)473 int Tile::ReadMotionVectorComponent(const Block& block, const int component) {
474   const int context =
475       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
476   const bool sign = reader_.ReadSymbol(
477       symbol_decoder_context_.mv_sign_cdf[component][context]);
478   const int mv_class = reader_.ReadSymbol<kMvClassSymbolCount>(
479       symbol_decoder_context_.mv_class_cdf[component][context]);
480   int magnitude = 1;
481   int value;
482   uint16_t* fraction_cdf;
483   uint16_t* precision_cdf;
484   if (mv_class == 0) {
485     value = static_cast<int>(reader_.ReadSymbol(
486         symbol_decoder_context_.mv_class0_bit_cdf[component][context]));
487     fraction_cdf = symbol_decoder_context_
488                        .mv_class0_fraction_cdf[component][context][value];
489     precision_cdf = symbol_decoder_context_
490                         .mv_class0_high_precision_cdf[component][context];
491   } else {
492     assert(mv_class <= kMvBitSymbolCount);
493     value = 0;
494     for (int i = 0; i < mv_class; ++i) {
495       const int bit = static_cast<int>(reader_.ReadSymbol(
496           symbol_decoder_context_.mv_bit_cdf[component][context][i]));
497       value |= bit << i;
498     }
499     magnitude += 2 << (mv_class + 2);
500     fraction_cdf = symbol_decoder_context_.mv_fraction_cdf[component][context];
501     precision_cdf =
502         symbol_decoder_context_.mv_high_precision_cdf[component][context];
503   }
504   const int fraction =
505       (frame_header_.force_integer_mv == 0)
506           ? reader_.ReadSymbol<kMvFractionSymbolCount>(fraction_cdf)
507           : 3;
508   const int precision =
509       frame_header_.allow_high_precision_mv
510           ? static_cast<int>(reader_.ReadSymbol(precision_cdf))
511           : 1;
512   magnitude += (value << 3) | (fraction << 1) | precision;
513   return sign ? -magnitude : magnitude;
514 }
515 
ReadMotionVector(const Block & block,int index)516 void Tile::ReadMotionVector(const Block& block, int index) {
517   BlockParameters& bp = *block.bp;
518   const int context =
519       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
520   const auto mv_joint =
521       static_cast<MvJointType>(reader_.ReadSymbol<kNumMvJointTypes>(
522           symbol_decoder_context_.mv_joint_cdf[context]));
523   if (mv_joint == kMvJointTypeHorizontalZeroVerticalNonZero ||
524       mv_joint == kMvJointTypeNonZero) {
525     bp.mv.mv[index].mv[0] = ReadMotionVectorComponent(block, 0);
526   }
527   if (mv_joint == kMvJointTypeHorizontalNonZeroVerticalZero ||
528       mv_joint == kMvJointTypeNonZero) {
529     bp.mv.mv[index].mv[1] = ReadMotionVectorComponent(block, 1);
530   }
531 }
532 
ReadFilterIntraModeInfo(const Block & block)533 void Tile::ReadFilterIntraModeInfo(const Block& block) {
534   BlockParameters& bp = *block.bp;
535   PredictionParameters& prediction_parameters =
536       *block.bp->prediction_parameters;
537   prediction_parameters.use_filter_intra = false;
538   if (!sequence_header_.enable_filter_intra || bp.y_mode != kPredictionModeDc ||
539       bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] != 0 ||
540       !IsBlockDimensionLessThan64(block.size)) {
541     return;
542   }
543   prediction_parameters.use_filter_intra = reader_.ReadSymbol(
544       symbol_decoder_context_.use_filter_intra_cdf[block.size]);
545   if (prediction_parameters.use_filter_intra) {
546     prediction_parameters.filter_intra_mode = static_cast<FilterIntraPredictor>(
547         reader_.ReadSymbol<kNumFilterIntraPredictors>(
548             symbol_decoder_context_.filter_intra_mode_cdf));
549   }
550 }
551 
DecodeIntraModeInfo(const Block & block)552 bool Tile::DecodeIntraModeInfo(const Block& block) {
553   BlockParameters& bp = *block.bp;
554   bp.skip = false;
555   if (frame_header_.segmentation.segment_id_pre_skip &&
556       !ReadIntraSegmentId(block)) {
557     return false;
558   }
559   SetCdfContextSkipMode(block, false);
560   ReadSkip(block);
561   if (!frame_header_.segmentation.segment_id_pre_skip &&
562       !ReadIntraSegmentId(block)) {
563     return false;
564   }
565   ReadCdef(block);
566   if (read_deltas_) {
567     ReadQuantizerIndexDelta(block);
568     ReadLoopFilterDelta(block);
569     read_deltas_ = false;
570   }
571   PredictionParameters& prediction_parameters =
572       *block.bp->prediction_parameters;
573   prediction_parameters.use_intra_block_copy = false;
574   if (frame_header_.allow_intrabc) {
575     prediction_parameters.use_intra_block_copy =
576         reader_.ReadSymbol(symbol_decoder_context_.intra_block_copy_cdf);
577   }
578   if (prediction_parameters.use_intra_block_copy) {
579     bp.is_inter = true;
580     bp.reference_frame[0] = kReferenceFrameIntra;
581     bp.reference_frame[1] = kReferenceFrameNone;
582     bp.y_mode = kPredictionModeDc;
583     bp.prediction_parameters->uv_mode = kPredictionModeDc;
584     SetCdfContextUVMode(block);
585     prediction_parameters.motion_mode = kMotionModeSimple;
586     prediction_parameters.compound_prediction_type =
587         kCompoundPredictionTypeAverage;
588     bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] = 0;
589     bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] = 0;
590     SetCdfContextPaletteSize(block);
591     bp.interpolation_filter[0] = kInterpolationFilterBilinear;
592     bp.interpolation_filter[1] = kInterpolationFilterBilinear;
593     MvContexts dummy_mode_contexts;
594     FindMvStack(block, /*is_compound=*/false, &dummy_mode_contexts);
595     return AssignIntraMv(block);
596   }
597   bp.is_inter = false;
598   return ReadIntraBlockModeInfo(block, /*intra_y_mode=*/true);
599 }
600 
ComputePredictedSegmentId(const Block & block) const601 int8_t Tile::ComputePredictedSegmentId(const Block& block) const {
602   // If prev_segment_ids_ is null, treat it as if it pointed to a segmentation
603   // map containing all 0s.
604   if (prev_segment_ids_ == nullptr) return 0;
605 
606   const int x_limit = std::min(frame_header_.columns4x4 - block.column4x4,
607                                static_cast<int>(block.width4x4));
608   const int y_limit = std::min(frame_header_.rows4x4 - block.row4x4,
609                                static_cast<int>(block.height4x4));
610   int8_t id = 7;
611   for (int y = 0; y < y_limit; ++y) {
612     for (int x = 0; x < x_limit; ++x) {
613       const int8_t prev_segment_id =
614           prev_segment_ids_->segment_id(block.row4x4 + y, block.column4x4 + x);
615       id = std::min(id, prev_segment_id);
616     }
617   }
618   return id;
619 }
620 
SetCdfContextUsePredictedSegmentId(const Block & block,bool use_predicted_segment_id)621 void Tile::SetCdfContextUsePredictedSegmentId(const Block& block,
622                                               bool use_predicted_segment_id) {
623   memset(left_context_.use_predicted_segment_id + block.left_context_index,
624          static_cast<int>(use_predicted_segment_id), block.height4x4);
625   memset(block.top_context->use_predicted_segment_id + block.top_context_index,
626          static_cast<int>(use_predicted_segment_id), block.width4x4);
627 }
628 
ReadInterSegmentId(const Block & block,bool pre_skip)629 bool Tile::ReadInterSegmentId(const Block& block, bool pre_skip) {
630   BlockParameters& bp = *block.bp;
631   if (!frame_header_.segmentation.enabled) {
632     bp.prediction_parameters->segment_id = 0;
633     return true;
634   }
635   if (!frame_header_.segmentation.update_map) {
636     bp.prediction_parameters->segment_id = ComputePredictedSegmentId(block);
637     return true;
638   }
639   if (pre_skip) {
640     if (!frame_header_.segmentation.segment_id_pre_skip) {
641       bp.prediction_parameters->segment_id = 0;
642       return true;
643     }
644   } else if (bp.skip) {
645     SetCdfContextUsePredictedSegmentId(block, false);
646     return ReadSegmentId(block);
647   }
648   if (frame_header_.segmentation.temporal_update) {
649     const int context =
650         (block.left_available[kPlaneY]
651              ? static_cast<int>(
652                    left_context_
653                        .use_predicted_segment_id[block.left_context_index])
654              : 0) +
655         (block.top_available[kPlaneY]
656              ? static_cast<int>(
657                    block.top_context
658                        ->use_predicted_segment_id[block.top_context_index])
659              : 0);
660     const bool use_predicted_segment_id = reader_.ReadSymbol(
661         symbol_decoder_context_.use_predicted_segment_id_cdf[context]);
662     SetCdfContextUsePredictedSegmentId(block, use_predicted_segment_id);
663     if (use_predicted_segment_id) {
664       bp.prediction_parameters->segment_id = ComputePredictedSegmentId(block);
665       return true;
666     }
667   }
668   return ReadSegmentId(block);
669 }
670 
ReadIsInter(const Block & block,bool skip_mode)671 void Tile::ReadIsInter(const Block& block, bool skip_mode) {
672   BlockParameters& bp = *block.bp;
673   if (skip_mode) {
674     bp.is_inter = true;
675     return;
676   }
677   if (frame_header_.segmentation.FeatureActive(
678           bp.prediction_parameters->segment_id,
679           kSegmentFeatureReferenceFrame)) {
680     bp.is_inter = frame_header_.segmentation
681                       .feature_data[bp.prediction_parameters->segment_id]
682                                    [kSegmentFeatureReferenceFrame] !=
683                   kReferenceFrameIntra;
684     return;
685   }
686   if (frame_header_.segmentation.FeatureActive(
687           bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
688     bp.is_inter = true;
689     return;
690   }
691   int context = 0;
692   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
693     context = (block.IsTopIntra() && block.IsLeftIntra())
694                   ? 3
695                   : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
696   } else if (block.top_available[kPlaneY] || block.left_available[kPlaneY]) {
697     context = 2 * static_cast<int>(block.top_available[kPlaneY]
698                                        ? block.IsTopIntra()
699                                        : block.IsLeftIntra());
700   }
701   bp.is_inter =
702       reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
703 }
704 
SetCdfContextPaletteSize(const Block & block)705 void Tile::SetCdfContextPaletteSize(const Block& block) {
706   const PaletteModeInfo& palette_mode_info =
707       block.bp->prediction_parameters->palette_mode_info;
708   for (int plane_type = kPlaneTypeY; plane_type <= kPlaneTypeUV; ++plane_type) {
709     memset(left_context_.palette_size[plane_type] + block.left_context_index,
710            palette_mode_info.size[plane_type], block.height4x4);
711     memset(
712         block.top_context->palette_size[plane_type] + block.top_context_index,
713         palette_mode_info.size[plane_type], block.width4x4);
714     if (palette_mode_info.size[plane_type] == 0) continue;
715     for (int i = block.left_context_index;
716          i < block.left_context_index + block.height4x4; ++i) {
717       memcpy(left_context_.palette_color[i][plane_type],
718              palette_mode_info.color[plane_type],
719              kMaxPaletteSize * sizeof(palette_mode_info.color[0][0]));
720     }
721     for (int i = block.top_context_index;
722          i < block.top_context_index + block.width4x4; ++i) {
723       memcpy(block.top_context->palette_color[i][plane_type],
724              palette_mode_info.color[plane_type],
725              kMaxPaletteSize * sizeof(palette_mode_info.color[0][0]));
726     }
727   }
728 }
729 
SetCdfContextUVMode(const Block & block)730 void Tile::SetCdfContextUVMode(const Block& block) {
731   // BlockCdfContext.uv_mode is only used to compute is_smooth_prediction for
732   // the intra edge upsamplers in the subsequent blocks. They have some special
733   // rules for subsampled UV planes. For subsampled UV planes, update left
734   // context only if current block contains the last odd column and update top
735   // context only if current block contains the last odd row.
736   if (subsampling_x_[kPlaneU] == 0 || (block.column4x4 & 1) == 1 ||
737       block.width4x4 > 1) {
738     memset(left_context_.uv_mode + block.left_context_index,
739            block.bp->prediction_parameters->uv_mode, block.height4x4);
740   }
741   if (subsampling_y_[kPlaneU] == 0 || (block.row4x4 & 1) == 1 ||
742       block.height4x4 > 1) {
743     memset(block.top_context->uv_mode + block.top_context_index,
744            block.bp->prediction_parameters->uv_mode, block.width4x4);
745   }
746 }
747 
ReadIntraBlockModeInfo(const Block & block,bool intra_y_mode)748 bool Tile::ReadIntraBlockModeInfo(const Block& block, bool intra_y_mode) {
749   BlockParameters& bp = *block.bp;
750   bp.reference_frame[0] = kReferenceFrameIntra;
751   bp.reference_frame[1] = kReferenceFrameNone;
752   ReadPredictionModeY(block, intra_y_mode);
753   ReadIntraAngleInfo(block, kPlaneTypeY);
754   if (block.HasChroma()) {
755     ReadPredictionModeUV(block);
756     if (bp.prediction_parameters->uv_mode == kPredictionModeChromaFromLuma) {
757       ReadCflAlpha(block);
758     }
759     if (block.left_available[kPlaneU]) {
760       const int smooth_row =
761           block.row4x4 + (~block.row4x4 & subsampling_y_[kPlaneU]);
762       const int smooth_column =
763           block.column4x4 - 1 - (block.column4x4 & subsampling_x_[kPlaneU]);
764       const BlockParameters& bp_left =
765           *block_parameters_holder_.Find(smooth_row, smooth_column);
766       bp.prediction_parameters->chroma_left_uses_smooth_prediction =
767           (bp_left.reference_frame[0] <= kReferenceFrameIntra) &&
768           kPredictionModeSmoothMask.Contains(
769               left_context_.uv_mode[CdfContextIndex(smooth_row)]);
770     }
771     if (block.top_available[kPlaneU]) {
772       const int smooth_row =
773           block.row4x4 - 1 - (block.row4x4 & subsampling_y_[kPlaneU]);
774       const int smooth_column =
775           block.column4x4 + (~block.column4x4 & subsampling_x_[kPlaneU]);
776       const BlockParameters& bp_top =
777           *block_parameters_holder_.Find(smooth_row, smooth_column);
778       bp.prediction_parameters->chroma_top_uses_smooth_prediction =
779           (bp_top.reference_frame[0] <= kReferenceFrameIntra) &&
780           kPredictionModeSmoothMask.Contains(
781               top_context_.get()[SuperBlockColumnIndex(smooth_column)]
782                   .uv_mode[CdfContextIndex(smooth_column)]);
783     }
784     SetCdfContextUVMode(block);
785     ReadIntraAngleInfo(block, kPlaneTypeUV);
786   }
787   ReadPaletteModeInfo(block);
788   SetCdfContextPaletteSize(block);
789   ReadFilterIntraModeInfo(block);
790   return true;
791 }
792 
ReadCompoundReferenceType(const Block & block)793 CompoundReferenceType Tile::ReadCompoundReferenceType(const Block& block) {
794   // compound and inter.
795   const bool top_comp_inter = block.top_available[kPlaneY] &&
796                               !block.IsTopIntra() && !block.IsTopSingle();
797   const bool left_comp_inter = block.left_available[kPlaneY] &&
798                                !block.IsLeftIntra() && !block.IsLeftSingle();
799   // unidirectional compound.
800   const bool top_uni_comp =
801       top_comp_inter && IsSameDirectionReferencePair(block.TopReference(0),
802                                                      block.TopReference(1));
803   const bool left_uni_comp =
804       left_comp_inter && IsSameDirectionReferencePair(block.LeftReference(0),
805                                                       block.LeftReference(1));
806   int context;
807   if (block.top_available[kPlaneY] && !block.IsTopIntra() &&
808       block.left_available[kPlaneY] && !block.IsLeftIntra()) {
809     const int same_direction = static_cast<int>(IsSameDirectionReferencePair(
810         block.TopReference(0), block.LeftReference(0)));
811     if (!top_comp_inter && !left_comp_inter) {
812       context = 1 + MultiplyBy2(same_direction);
813     } else if (!top_comp_inter) {
814       context = left_uni_comp ? 3 + same_direction : 1;
815     } else if (!left_comp_inter) {
816       context = top_uni_comp ? 3 + same_direction : 1;
817     } else {
818       if (!top_uni_comp && !left_uni_comp) {
819         context = 0;
820       } else if (!top_uni_comp || !left_uni_comp) {
821         context = 2;
822       } else {
823         context = 3 + static_cast<int>(
824                           (block.TopReference(0) == kReferenceFrameBackward) ==
825                           (block.LeftReference(0) == kReferenceFrameBackward));
826       }
827     }
828   } else if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
829     if (top_comp_inter) {
830       context = 1 + MultiplyBy2(static_cast<int>(top_uni_comp));
831     } else if (left_comp_inter) {
832       context = 1 + MultiplyBy2(static_cast<int>(left_uni_comp));
833     } else {
834       context = 2;
835     }
836   } else if (top_comp_inter) {
837     context = MultiplyBy4(static_cast<int>(top_uni_comp));
838   } else if (left_comp_inter) {
839     context = MultiplyBy4(static_cast<int>(left_uni_comp));
840   } else {
841     context = 2;
842   }
843   return static_cast<CompoundReferenceType>(reader_.ReadSymbol(
844       symbol_decoder_context_.compound_reference_type_cdf[context]));
845 }
846 
847 template <bool is_single, bool is_backward, int index>
GetReferenceCdf(const Block & block,CompoundReferenceType type)848 uint16_t* Tile::GetReferenceCdf(
849     const Block& block,
850     CompoundReferenceType type /*= kNumCompoundReferenceTypes*/) {
851   int context = 0;
852   if ((type == kCompoundReferenceUnidirectional && index == 0) ||
853       (is_single && index == 1)) {
854     // uni_comp_ref and single_ref_p1.
855     context =
856         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameGolden,
857                             kReferenceFrameBackward, kReferenceFrameAlternate);
858   } else if (type == kCompoundReferenceUnidirectional && index == 1) {
859     // uni_comp_ref_p1.
860     context =
861         GetReferenceContext(block, kReferenceFrameLast2, kReferenceFrameLast2,
862                             kReferenceFrameLast3, kReferenceFrameGolden);
863   } else if ((type == kCompoundReferenceUnidirectional && index == 2) ||
864              (type == kCompoundReferenceBidirectional && index == 2) ||
865              (is_single && index == 5)) {
866     // uni_comp_ref_p2, comp_ref_p2 and single_ref_p5.
867     context =
868         GetReferenceContext(block, kReferenceFrameLast3, kReferenceFrameLast3,
869                             kReferenceFrameGolden, kReferenceFrameGolden);
870   } else if ((type == kCompoundReferenceBidirectional && index == 0) ||
871              (is_single && index == 3)) {
872     // comp_ref and single_ref_p3.
873     context =
874         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast2,
875                             kReferenceFrameLast3, kReferenceFrameGolden);
876   } else if ((type == kCompoundReferenceBidirectional && index == 1) ||
877              (is_single && index == 4)) {
878     // comp_ref_p1 and single_ref_p4.
879     context =
880         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast,
881                             kReferenceFrameLast2, kReferenceFrameLast2);
882   } else if ((is_single && index == 2) || (is_backward && index == 0)) {
883     // single_ref_p2 and comp_bwdref.
884     context = GetReferenceContext(
885         block, kReferenceFrameBackward, kReferenceFrameAlternate2,
886         kReferenceFrameAlternate, kReferenceFrameAlternate);
887   } else if ((is_single && index == 6) || (is_backward && index == 1)) {
888     // single_ref_p6 and comp_bwdref_p1.
889     context = GetReferenceContext(
890         block, kReferenceFrameBackward, kReferenceFrameBackward,
891         kReferenceFrameAlternate2, kReferenceFrameAlternate2);
892   }
893   if (is_single) {
894     // The index parameter for single references is offset by one since the spec
895     // uses 1-based index for these elements.
896     return symbol_decoder_context_.single_reference_cdf[context][index - 1];
897   }
898   if (is_backward) {
899     return symbol_decoder_context_
900         .compound_backward_reference_cdf[context][index];
901   }
902   return symbol_decoder_context_.compound_reference_cdf[type][context][index];
903 }
904 
ReadReferenceFrames(const Block & block,bool skip_mode)905 void Tile::ReadReferenceFrames(const Block& block, bool skip_mode) {
906   BlockParameters& bp = *block.bp;
907   if (skip_mode) {
908     bp.reference_frame[0] = frame_header_.skip_mode_frame[0];
909     bp.reference_frame[1] = frame_header_.skip_mode_frame[1];
910     return;
911   }
912   if (frame_header_.segmentation.FeatureActive(
913           bp.prediction_parameters->segment_id,
914           kSegmentFeatureReferenceFrame)) {
915     bp.reference_frame[0] = static_cast<ReferenceFrameType>(
916         frame_header_.segmentation
917             .feature_data[bp.prediction_parameters->segment_id]
918                          [kSegmentFeatureReferenceFrame]);
919     bp.reference_frame[1] = kReferenceFrameNone;
920     return;
921   }
922   if (frame_header_.segmentation.FeatureActive(
923           bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
924       frame_header_.segmentation.FeatureActive(
925           bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
926     bp.reference_frame[0] = kReferenceFrameLast;
927     bp.reference_frame[1] = kReferenceFrameNone;
928     return;
929   }
930   const bool use_compound_reference =
931       frame_header_.reference_mode_select &&
932       std::min(block.width4x4, block.height4x4) >= 2 &&
933       reader_.ReadSymbol(symbol_decoder_context_.use_compound_reference_cdf
934                              [GetUseCompoundReferenceContext(block)]);
935   if (use_compound_reference) {
936     CompoundReferenceType reference_type = ReadCompoundReferenceType(block);
937     if (reference_type == kCompoundReferenceUnidirectional) {
938       // uni_comp_ref.
939       if (reader_.ReadSymbol(
940               GetReferenceCdf<false, false, 0>(block, reference_type))) {
941         bp.reference_frame[0] = kReferenceFrameBackward;
942         bp.reference_frame[1] = kReferenceFrameAlternate;
943         return;
944       }
945       // uni_comp_ref_p1.
946       if (!reader_.ReadSymbol(
947               GetReferenceCdf<false, false, 1>(block, reference_type))) {
948         bp.reference_frame[0] = kReferenceFrameLast;
949         bp.reference_frame[1] = kReferenceFrameLast2;
950         return;
951       }
952       // uni_comp_ref_p2.
953       if (reader_.ReadSymbol(
954               GetReferenceCdf<false, false, 2>(block, reference_type))) {
955         bp.reference_frame[0] = kReferenceFrameLast;
956         bp.reference_frame[1] = kReferenceFrameGolden;
957         return;
958       }
959       bp.reference_frame[0] = kReferenceFrameLast;
960       bp.reference_frame[1] = kReferenceFrameLast3;
961       return;
962     }
963     assert(reference_type == kCompoundReferenceBidirectional);
964     // comp_ref.
965     if (reader_.ReadSymbol(
966             GetReferenceCdf<false, false, 0>(block, reference_type))) {
967       // comp_ref_p2.
968       bp.reference_frame[0] =
969           reader_.ReadSymbol(
970               GetReferenceCdf<false, false, 2>(block, reference_type))
971               ? kReferenceFrameGolden
972               : kReferenceFrameLast3;
973     } else {
974       // comp_ref_p1.
975       bp.reference_frame[0] =
976           reader_.ReadSymbol(
977               GetReferenceCdf<false, false, 1>(block, reference_type))
978               ? kReferenceFrameLast2
979               : kReferenceFrameLast;
980     }
981     // comp_bwdref.
982     if (reader_.ReadSymbol(GetReferenceCdf<false, true, 0>(block))) {
983       bp.reference_frame[1] = kReferenceFrameAlternate;
984     } else {
985       // comp_bwdref_p1.
986       bp.reference_frame[1] =
987           reader_.ReadSymbol(GetReferenceCdf<false, true, 1>(block))
988               ? kReferenceFrameAlternate2
989               : kReferenceFrameBackward;
990     }
991     return;
992   }
993   assert(!use_compound_reference);
994   bp.reference_frame[1] = kReferenceFrameNone;
995   // single_ref_p1.
996   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 1>(block))) {
997     // single_ref_p2.
998     if (reader_.ReadSymbol(GetReferenceCdf<true, false, 2>(block))) {
999       bp.reference_frame[0] = kReferenceFrameAlternate;
1000       return;
1001     }
1002     // single_ref_p6.
1003     bp.reference_frame[0] =
1004         reader_.ReadSymbol(GetReferenceCdf<true, false, 6>(block))
1005             ? kReferenceFrameAlternate2
1006             : kReferenceFrameBackward;
1007     return;
1008   }
1009   // single_ref_p3.
1010   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 3>(block))) {
1011     // single_ref_p5.
1012     bp.reference_frame[0] =
1013         reader_.ReadSymbol(GetReferenceCdf<true, false, 5>(block))
1014             ? kReferenceFrameGolden
1015             : kReferenceFrameLast3;
1016     return;
1017   }
1018   // single_ref_p4.
1019   bp.reference_frame[0] =
1020       reader_.ReadSymbol(GetReferenceCdf<true, false, 4>(block))
1021           ? kReferenceFrameLast2
1022           : kReferenceFrameLast;
1023 }
1024 
ReadInterPredictionModeY(const Block & block,const MvContexts & mode_contexts,bool skip_mode)1025 void Tile::ReadInterPredictionModeY(const Block& block,
1026                                     const MvContexts& mode_contexts,
1027                                     bool skip_mode) {
1028   BlockParameters& bp = *block.bp;
1029   if (skip_mode) {
1030     bp.y_mode = kPredictionModeNearestNearestMv;
1031     return;
1032   }
1033   if (frame_header_.segmentation.FeatureActive(
1034           bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
1035       frame_header_.segmentation.FeatureActive(
1036           bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
1037     bp.y_mode = kPredictionModeGlobalMv;
1038     return;
1039   }
1040   if (bp.reference_frame[1] > kReferenceFrameIntra) {
1041     const int idx0 = mode_contexts.reference_mv >> 1;
1042     const int idx1 =
1043         std::min(mode_contexts.new_mv, kCompoundModeNewMvContexts - 1);
1044     const int context = kCompoundModeContextMap[idx0][idx1];
1045     const int offset = reader_.ReadSymbol<kNumCompoundInterPredictionModes>(
1046         symbol_decoder_context_.compound_prediction_mode_cdf[context]);
1047     bp.y_mode =
1048         static_cast<PredictionMode>(kPredictionModeNearestNearestMv + offset);
1049     return;
1050   }
1051   // new_mv.
1052   if (!reader_.ReadSymbol(
1053           symbol_decoder_context_.new_mv_cdf[mode_contexts.new_mv])) {
1054     bp.y_mode = kPredictionModeNewMv;
1055     return;
1056   }
1057   // zero_mv.
1058   if (!reader_.ReadSymbol(
1059           symbol_decoder_context_.zero_mv_cdf[mode_contexts.zero_mv])) {
1060     bp.y_mode = kPredictionModeGlobalMv;
1061     return;
1062   }
1063   // ref_mv.
1064   bp.y_mode =
1065       reader_.ReadSymbol(
1066           symbol_decoder_context_.reference_mv_cdf[mode_contexts.reference_mv])
1067           ? kPredictionModeNearMv
1068           : kPredictionModeNearestMv;
1069 }
1070 
ReadRefMvIndex(const Block & block)1071 void Tile::ReadRefMvIndex(const Block& block) {
1072   BlockParameters& bp = *block.bp;
1073   PredictionParameters& prediction_parameters =
1074       *block.bp->prediction_parameters;
1075   prediction_parameters.ref_mv_index = 0;
1076   if (bp.y_mode != kPredictionModeNewMv &&
1077       bp.y_mode != kPredictionModeNewNewMv &&
1078       !kPredictionModeHasNearMvMask.Contains(bp.y_mode)) {
1079     return;
1080   }
1081   const int start =
1082       static_cast<int>(kPredictionModeHasNearMvMask.Contains(bp.y_mode));
1083   prediction_parameters.ref_mv_index = start;
1084   for (int i = start; i < start + 2; ++i) {
1085     if (prediction_parameters.ref_mv_count <= i + 1) break;
1086     // drl_mode in the spec.
1087     const bool ref_mv_index_bit = reader_.ReadSymbol(
1088         symbol_decoder_context_.ref_mv_index_cdf[GetRefMvIndexContext(
1089             prediction_parameters.nearest_mv_count, i)]);
1090     prediction_parameters.ref_mv_index = i + static_cast<int>(ref_mv_index_bit);
1091     if (!ref_mv_index_bit) return;
1092   }
1093 }
1094 
ReadInterIntraMode(const Block & block,bool is_compound,bool skip_mode)1095 void Tile::ReadInterIntraMode(const Block& block, bool is_compound,
1096                               bool skip_mode) {
1097   BlockParameters& bp = *block.bp;
1098   PredictionParameters& prediction_parameters =
1099       *block.bp->prediction_parameters;
1100   prediction_parameters.inter_intra_mode = kNumInterIntraModes;
1101   prediction_parameters.is_wedge_inter_intra = false;
1102   if (skip_mode || !sequence_header_.enable_interintra_compound ||
1103       is_compound || !kIsInterIntraModeAllowedMask.Contains(block.size)) {
1104     return;
1105   }
1106   // kSizeGroup[block.size] is guaranteed to be non-zero because of the block
1107   // size constraint enforced in the above condition.
1108   assert(kSizeGroup[block.size] - 1 >= 0);
1109   if (!reader_.ReadSymbol(
1110           symbol_decoder_context_
1111               .is_inter_intra_cdf[kSizeGroup[block.size] - 1])) {
1112     prediction_parameters.inter_intra_mode = kNumInterIntraModes;
1113     return;
1114   }
1115   prediction_parameters.inter_intra_mode =
1116       static_cast<InterIntraMode>(reader_.ReadSymbol<kNumInterIntraModes>(
1117           symbol_decoder_context_
1118               .inter_intra_mode_cdf[kSizeGroup[block.size] - 1]));
1119   bp.reference_frame[1] = kReferenceFrameIntra;
1120   prediction_parameters.angle_delta[kPlaneTypeY] = 0;
1121   prediction_parameters.angle_delta[kPlaneTypeUV] = 0;
1122   prediction_parameters.use_filter_intra = false;
1123   prediction_parameters.is_wedge_inter_intra = reader_.ReadSymbol(
1124       symbol_decoder_context_.is_wedge_inter_intra_cdf[block.size]);
1125   if (!prediction_parameters.is_wedge_inter_intra) return;
1126   prediction_parameters.wedge_index =
1127       reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1128           symbol_decoder_context_.wedge_index_cdf[block.size]);
1129   prediction_parameters.wedge_sign = 0;
1130 }
1131 
ReadMotionMode(const Block & block,bool is_compound,bool skip_mode)1132 void Tile::ReadMotionMode(const Block& block, bool is_compound,
1133                           bool skip_mode) {
1134   BlockParameters& bp = *block.bp;
1135   PredictionParameters& prediction_parameters =
1136       *block.bp->prediction_parameters;
1137   const auto global_motion_type =
1138       frame_header_.global_motion[bp.reference_frame[0]].type;
1139   if (skip_mode || !frame_header_.is_motion_mode_switchable ||
1140       IsBlockDimension4(block.size) ||
1141       (frame_header_.force_integer_mv == 0 &&
1142        (bp.y_mode == kPredictionModeGlobalMv ||
1143         bp.y_mode == kPredictionModeGlobalGlobalMv) &&
1144        global_motion_type > kGlobalMotionTransformationTypeTranslation) ||
1145       is_compound || bp.reference_frame[1] == kReferenceFrameIntra ||
1146       !block.HasOverlappableCandidates()) {
1147     prediction_parameters.motion_mode = kMotionModeSimple;
1148     return;
1149   }
1150   prediction_parameters.num_warp_samples = 0;
1151   int num_samples_scanned = 0;
1152   memset(prediction_parameters.warp_estimate_candidates, 0,
1153          sizeof(prediction_parameters.warp_estimate_candidates));
1154   FindWarpSamples(block, &prediction_parameters.num_warp_samples,
1155                   &num_samples_scanned,
1156                   prediction_parameters.warp_estimate_candidates);
1157   if (frame_header_.force_integer_mv != 0 ||
1158       prediction_parameters.num_warp_samples == 0 ||
1159       !frame_header_.allow_warped_motion || IsScaled(bp.reference_frame[0])) {
1160     prediction_parameters.motion_mode =
1161         reader_.ReadSymbol(symbol_decoder_context_.use_obmc_cdf[block.size])
1162             ? kMotionModeObmc
1163             : kMotionModeSimple;
1164     return;
1165   }
1166   prediction_parameters.motion_mode =
1167       static_cast<MotionMode>(reader_.ReadSymbol<kNumMotionModes>(
1168           symbol_decoder_context_.motion_mode_cdf[block.size]));
1169 }
1170 
GetIsExplicitCompoundTypeCdf(const Block & block)1171 uint16_t* Tile::GetIsExplicitCompoundTypeCdf(const Block& block) {
1172   int context = 0;
1173   if (block.top_available[kPlaneY]) {
1174     if (!block.IsTopSingle()) {
1175       context += static_cast<int>(
1176           block.top_context
1177               ->is_explicit_compound_type[block.top_context_index]);
1178     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1179       context += 3;
1180     }
1181   }
1182   if (block.left_available[kPlaneY]) {
1183     if (!block.IsLeftSingle()) {
1184       context += static_cast<int>(
1185           left_context_.is_explicit_compound_type[block.left_context_index]);
1186     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1187       context += 3;
1188     }
1189   }
1190   return symbol_decoder_context_.is_explicit_compound_type_cdf[std::min(
1191       context, kIsExplicitCompoundTypeContexts - 1)];
1192 }
1193 
GetIsCompoundTypeAverageCdf(const Block & block)1194 uint16_t* Tile::GetIsCompoundTypeAverageCdf(const Block& block) {
1195   const BlockParameters& bp = *block.bp;
1196   const ReferenceInfo& reference_info = *current_frame_.reference_info();
1197   const int forward =
1198       std::abs(reference_info.relative_distance_from[bp.reference_frame[0]]);
1199   const int backward =
1200       std::abs(reference_info.relative_distance_from[bp.reference_frame[1]]);
1201   int context = (forward == backward) ? 3 : 0;
1202   if (block.top_available[kPlaneY]) {
1203     if (!block.IsTopSingle()) {
1204       context += static_cast<int>(
1205           block.top_context->is_compound_type_average[block.top_context_index]);
1206     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1207       ++context;
1208     }
1209   }
1210   if (block.left_available[kPlaneY]) {
1211     if (!block.IsLeftSingle()) {
1212       context += static_cast<int>(
1213           left_context_.is_compound_type_average[block.left_context_index]);
1214     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1215       ++context;
1216     }
1217   }
1218   return symbol_decoder_context_.is_compound_type_average_cdf[context];
1219 }
1220 
ReadCompoundType(const Block & block,bool is_compound,bool skip_mode,bool * const is_explicit_compound_type,bool * const is_compound_type_average)1221 void Tile::ReadCompoundType(const Block& block, bool is_compound,
1222                             bool skip_mode,
1223                             bool* const is_explicit_compound_type,
1224                             bool* const is_compound_type_average) {
1225   *is_explicit_compound_type = false;
1226   *is_compound_type_average = true;
1227   PredictionParameters& prediction_parameters =
1228       *block.bp->prediction_parameters;
1229   if (skip_mode) {
1230     prediction_parameters.compound_prediction_type =
1231         kCompoundPredictionTypeAverage;
1232     return;
1233   }
1234   if (is_compound) {
1235     if (sequence_header_.enable_masked_compound) {
1236       *is_explicit_compound_type =
1237           reader_.ReadSymbol(GetIsExplicitCompoundTypeCdf(block));
1238     }
1239     if (*is_explicit_compound_type) {
1240       if (kIsWedgeCompoundModeAllowed.Contains(block.size)) {
1241         // Only kCompoundPredictionTypeWedge and
1242         // kCompoundPredictionTypeDiffWeighted are signaled explicitly.
1243         prediction_parameters.compound_prediction_type =
1244             static_cast<CompoundPredictionType>(reader_.ReadSymbol(
1245                 symbol_decoder_context_.compound_type_cdf[block.size]));
1246       } else {
1247         prediction_parameters.compound_prediction_type =
1248             kCompoundPredictionTypeDiffWeighted;
1249       }
1250     } else {
1251       if (sequence_header_.enable_jnt_comp) {
1252         *is_compound_type_average =
1253             reader_.ReadSymbol(GetIsCompoundTypeAverageCdf(block));
1254         prediction_parameters.compound_prediction_type =
1255             *is_compound_type_average ? kCompoundPredictionTypeAverage
1256                                       : kCompoundPredictionTypeDistance;
1257       } else {
1258         prediction_parameters.compound_prediction_type =
1259             kCompoundPredictionTypeAverage;
1260         return;
1261       }
1262     }
1263     if (prediction_parameters.compound_prediction_type ==
1264         kCompoundPredictionTypeWedge) {
1265       prediction_parameters.wedge_index =
1266           reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1267               symbol_decoder_context_.wedge_index_cdf[block.size]);
1268       prediction_parameters.wedge_sign = static_cast<int>(reader_.ReadBit());
1269     } else if (prediction_parameters.compound_prediction_type ==
1270                kCompoundPredictionTypeDiffWeighted) {
1271       prediction_parameters.mask_is_inverse = reader_.ReadBit() != 0;
1272     }
1273     return;
1274   }
1275   if (prediction_parameters.inter_intra_mode != kNumInterIntraModes) {
1276     prediction_parameters.compound_prediction_type =
1277         prediction_parameters.is_wedge_inter_intra
1278             ? kCompoundPredictionTypeWedge
1279             : kCompoundPredictionTypeIntra;
1280     return;
1281   }
1282   prediction_parameters.compound_prediction_type =
1283       kCompoundPredictionTypeAverage;
1284 }
1285 
GetInterpolationFilterCdf(const Block & block,int direction)1286 uint16_t* Tile::GetInterpolationFilterCdf(const Block& block, int direction) {
1287   const BlockParameters& bp = *block.bp;
1288   int context = MultiplyBy8(direction) +
1289                 MultiplyBy4(static_cast<int>(bp.reference_frame[1] >
1290                                              kReferenceFrameIntra));
1291   int top_type = kNumExplicitInterpolationFilters;
1292   if (block.top_available[kPlaneY]) {
1293     if (block.bp_top->reference_frame[0] == bp.reference_frame[0] ||
1294         block.bp_top->reference_frame[1] == bp.reference_frame[0]) {
1295       top_type = block.bp_top->interpolation_filter[direction];
1296     }
1297   }
1298   int left_type = kNumExplicitInterpolationFilters;
1299   if (block.left_available[kPlaneY]) {
1300     if (block.bp_left->reference_frame[0] == bp.reference_frame[0] ||
1301         block.bp_left->reference_frame[1] == bp.reference_frame[0]) {
1302       left_type = block.bp_left->interpolation_filter[direction];
1303     }
1304   }
1305   if (left_type == top_type) {
1306     context += left_type;
1307   } else if (left_type == kNumExplicitInterpolationFilters) {
1308     context += top_type;
1309   } else if (top_type == kNumExplicitInterpolationFilters) {
1310     context += left_type;
1311   } else {
1312     context += kNumExplicitInterpolationFilters;
1313   }
1314   return symbol_decoder_context_.interpolation_filter_cdf[context];
1315 }
1316 
ReadInterpolationFilter(const Block & block,bool skip_mode)1317 void Tile::ReadInterpolationFilter(const Block& block, bool skip_mode) {
1318   BlockParameters& bp = *block.bp;
1319   if (frame_header_.interpolation_filter != kInterpolationFilterSwitchable) {
1320     static_assert(
1321         sizeof(bp.interpolation_filter) / sizeof(bp.interpolation_filter[0]) ==
1322             2,
1323         "Interpolation filter array size is not 2");
1324     for (auto& interpolation_filter : bp.interpolation_filter) {
1325       interpolation_filter = frame_header_.interpolation_filter;
1326     }
1327     return;
1328   }
1329   bool interpolation_filter_present = true;
1330   if (skip_mode ||
1331       block.bp->prediction_parameters->motion_mode == kMotionModeLocalWarp) {
1332     interpolation_filter_present = false;
1333   } else if (!IsBlockDimension4(block.size) &&
1334              bp.y_mode == kPredictionModeGlobalMv) {
1335     interpolation_filter_present =
1336         frame_header_.global_motion[bp.reference_frame[0]].type ==
1337         kGlobalMotionTransformationTypeTranslation;
1338   } else if (!IsBlockDimension4(block.size) &&
1339              bp.y_mode == kPredictionModeGlobalGlobalMv) {
1340     interpolation_filter_present =
1341         frame_header_.global_motion[bp.reference_frame[0]].type ==
1342             kGlobalMotionTransformationTypeTranslation ||
1343         frame_header_.global_motion[bp.reference_frame[1]].type ==
1344             kGlobalMotionTransformationTypeTranslation;
1345   }
1346   for (int i = 0; i < (sequence_header_.enable_dual_filter ? 2 : 1); ++i) {
1347     bp.interpolation_filter[i] =
1348         interpolation_filter_present
1349             ? static_cast<InterpolationFilter>(
1350                   reader_.ReadSymbol<kNumExplicitInterpolationFilters>(
1351                       GetInterpolationFilterCdf(block, i)))
1352             : kInterpolationFilterEightTap;
1353   }
1354   if (!sequence_header_.enable_dual_filter) {
1355     bp.interpolation_filter[1] = bp.interpolation_filter[0];
1356   }
1357 }
1358 
SetCdfContextCompoundType(const Block & block,bool is_explicit_compound_type,bool is_compound_type_average)1359 void Tile::SetCdfContextCompoundType(const Block& block,
1360                                      bool is_explicit_compound_type,
1361                                      bool is_compound_type_average) {
1362   memset(left_context_.is_explicit_compound_type + block.left_context_index,
1363          static_cast<int>(is_explicit_compound_type), block.height4x4);
1364   memset(left_context_.is_compound_type_average + block.left_context_index,
1365          static_cast<int>(is_compound_type_average), block.height4x4);
1366   memset(block.top_context->is_explicit_compound_type + block.top_context_index,
1367          static_cast<int>(is_explicit_compound_type), block.width4x4);
1368   memset(block.top_context->is_compound_type_average + block.top_context_index,
1369          static_cast<int>(is_compound_type_average), block.width4x4);
1370 }
1371 
ReadInterBlockModeInfo(const Block & block,bool skip_mode)1372 bool Tile::ReadInterBlockModeInfo(const Block& block, bool skip_mode) {
1373   BlockParameters& bp = *block.bp;
1374   bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] = 0;
1375   bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] = 0;
1376   SetCdfContextPaletteSize(block);
1377   ReadReferenceFrames(block, skip_mode);
1378   const bool is_compound = bp.reference_frame[1] > kReferenceFrameIntra;
1379   MvContexts mode_contexts;
1380   FindMvStack(block, is_compound, &mode_contexts);
1381   ReadInterPredictionModeY(block, mode_contexts, skip_mode);
1382   ReadRefMvIndex(block);
1383   if (!AssignInterMv(block, is_compound)) return false;
1384   ReadInterIntraMode(block, is_compound, skip_mode);
1385   ReadMotionMode(block, is_compound, skip_mode);
1386   bool is_explicit_compound_type;
1387   bool is_compound_type_average;
1388   ReadCompoundType(block, is_compound, skip_mode, &is_explicit_compound_type,
1389                    &is_compound_type_average);
1390   SetCdfContextCompoundType(block, is_explicit_compound_type,
1391                             is_compound_type_average);
1392   ReadInterpolationFilter(block, skip_mode);
1393   return true;
1394 }
1395 
SetCdfContextSkipMode(const Block & block,bool skip_mode)1396 void Tile::SetCdfContextSkipMode(const Block& block, bool skip_mode) {
1397   memset(left_context_.skip_mode + block.left_context_index,
1398          static_cast<int>(skip_mode), block.height4x4);
1399   memset(block.top_context->skip_mode + block.top_context_index,
1400          static_cast<int>(skip_mode), block.width4x4);
1401 }
1402 
DecodeInterModeInfo(const Block & block)1403 bool Tile::DecodeInterModeInfo(const Block& block) {
1404   BlockParameters& bp = *block.bp;
1405   block.bp->prediction_parameters->use_intra_block_copy = false;
1406   bp.skip = false;
1407   if (!ReadInterSegmentId(block, /*pre_skip=*/true)) return false;
1408   bool skip_mode = ReadSkipMode(block);
1409   SetCdfContextSkipMode(block, skip_mode);
1410   if (skip_mode) {
1411     bp.skip = true;
1412   } else {
1413     ReadSkip(block);
1414   }
1415   if (!frame_header_.segmentation.segment_id_pre_skip &&
1416       !ReadInterSegmentId(block, /*pre_skip=*/false)) {
1417     return false;
1418   }
1419   ReadCdef(block);
1420   if (read_deltas_) {
1421     ReadQuantizerIndexDelta(block);
1422     ReadLoopFilterDelta(block);
1423     read_deltas_ = false;
1424   }
1425   ReadIsInter(block, skip_mode);
1426   return bp.is_inter ? ReadInterBlockModeInfo(block, skip_mode)
1427                      : ReadIntraBlockModeInfo(block, /*intra_y_mode=*/false);
1428 }
1429 
DecodeModeInfo(const Block & block)1430 bool Tile::DecodeModeInfo(const Block& block) {
1431   return IsIntraFrame(frame_header_.frame_type) ? DecodeIntraModeInfo(block)
1432                                                 : DecodeInterModeInfo(block);
1433 }
1434 
1435 }  // namespace libgav1
1436