1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <array>
17 #include <cassert>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <cstring>
21 #include <memory>
22 #include <vector>
23 
24 #include "src/buffer_pool.h"
25 #include "src/dsp/constants.h"
26 #include "src/motion_vector.h"
27 #include "src/obu_parser.h"
28 #include "src/prediction_mask.h"
29 #include "src/symbol_decoder_context.h"
30 #include "src/tile.h"
31 #include "src/utils/array_2d.h"
32 #include "src/utils/bit_mask_set.h"
33 #include "src/utils/block_parameters_holder.h"
34 #include "src/utils/common.h"
35 #include "src/utils/constants.h"
36 #include "src/utils/entropy_decoder.h"
37 #include "src/utils/logging.h"
38 #include "src/utils/segmentation.h"
39 #include "src/utils/segmentation_map.h"
40 #include "src/utils/types.h"
41 
42 namespace libgav1 {
43 namespace {
44 
// Value of the decoded delta_q_abs symbol that acts as an escape: the
// magnitude then continues in explicitly coded literal bits (see
// Tile::ReadAndClipDelta()).
constexpr int kDeltaQSmall = 3;
// Same escape value, used for the loop filter delta symbols.
constexpr int kDeltaLfSmall = 3;
// Fixed-point 1.0 with kReferenceFrameScalePrecision fractional bits. Not used
// in this part of the file; presumably marks an unscaled reference -- confirm
// at the use site.
constexpr int kNoScale = 1 << kReferenceFrameScalePrecision;

// Context index for each luma intra prediction mode of a neighboring block.
// Used to select intra_frame_y_mode_cdf in Tile::ReadPredictionModeY().
constexpr uint8_t kIntraYModeContext[kIntraPredictionModesY] = {
    0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0};

// Size group of each block size. Used to select y_mode_cdf in
// Tile::ReadPredictionModeY() when not reading an intra-frame luma mode.
constexpr uint8_t kSizeGroup[kMaxBlockSizes] = {
    0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3};

// Compound mode context lookup table (3 rows of kCompoundModeNewMvContexts
// entries). Not used in this part of the file; presumably the spec's
// Compound_Mode_Ctx_Map -- confirm at the use site.
constexpr int kCompoundModeNewMvContexts = 5;
constexpr uint8_t kCompoundModeContextMap[3][kCompoundModeNewMvContexts] = {
    {0, 1, 1, 1, 1}, {1, 2, 3, 4, 4}, {4, 4, 5, 6, 7}};

// Sign of a CfL alpha value, as stored in the first two columns of
// kCflAlphaLookup. A zero sign means the alpha is 0 and is not read.
enum CflSign : uint8_t {
  kCflSignZero = 0,
  kCflSignNegative = 1,
  kCflSignPositive = 2
};

// For each possible value of the combined signs (cfl_alpha_signs, which is
// read from the bitstream), this array stores the following: sign_u, sign_v,
// alpha_u_context, alpha_v_context. A context is only used when the
// corresponding sign is non-zero, and in that case it is non-negative. Entry
// at index i is computed as follows:
//   sign_u = (i + 1) / 3
//   sign_v = (i + 1) % 3
//   alpha_u_context = (sign_u - 1) * 3 + sign_v  (which equals i - 2)
//   alpha_v_context = (sign_v - 1) * 3 + sign_u
constexpr int8_t kCflAlphaLookup[kCflAlphaSignsSymbolCount][4] = {
    {0, 1, -2, 0}, {0, 2, -1, 3}, {1, 0, 0, -2}, {1, 1, 1, 1},
    {1, 2, 2, 4},  {2, 0, 3, -1}, {2, 1, 4, 2},  {2, 2, 5, 5},
};

// Prediction modes that include a NEAR motion vector for at least one
// reference (NearMv, NearNearMv, NearNewMv, NewNearMv). Not used in this part
// of the file.
constexpr BitMaskSet kPredictionModeHasNearMvMask(kPredictionModeNearMv,
                                                  kPredictionModeNearNearMv,
                                                  kPredictionModeNearNewMv,
                                                  kPredictionModeNewNearMv);

// Block sizes (8x8 through 32x32, excluding any dimension of 4 or 64) for
// which, per the mask's name, inter-intra prediction may be used. Not used in
// this part of the file.
constexpr BitMaskSet kIsInterIntraModeAllowedMask(kBlock8x8, kBlock8x16,
                                                  kBlock16x8, kBlock16x16,
                                                  kBlock16x32, kBlock32x16,
                                                  kBlock32x32);
87 
IsBackwardReference(ReferenceFrameType type)88 bool IsBackwardReference(ReferenceFrameType type) {
89   return type >= kReferenceFrameBackward && type <= kReferenceFrameAlternate;
90 }
91 
IsSameDirectionReferencePair(ReferenceFrameType type1,ReferenceFrameType type2)92 bool IsSameDirectionReferencePair(ReferenceFrameType type1,
93                                   ReferenceFrameType type2) {
94   return (type1 >= kReferenceFrameBackward) ==
95          (type2 >= kReferenceFrameBackward);
96 }
97 
// Maps a decoded segment id symbol back to a segment id. This is called
// neg_deinterleave() in the spec.
//   |diff|      the decoded symbol.
//   |reference| the predicted segment id.
//   |max|       the number of allowed ids (callers pass
//               last_active_segment_id + 1).
// Small values of |diff| alternate above (odd) and below (even) |reference|.
// Once one side of |reference| is exhausted, the remaining values continue
// monotonically.
int DecodeSegmentId(int diff, int reference, int max) {
  if (reference == 0) return diff;
  if (reference >= max - 1) return max - diff - 1;
  const bool reference_in_lower_half = 2 * reference < max;
  // Largest |diff| that still alternates around |reference|.
  const int alternating_span =
      reference_in_lower_half ? 2 * reference : 2 * (max - reference - 1);
  if (diff > alternating_span) {
    return reference_in_lower_half ? diff : max - (diff + 1);
  }
  return ((diff & 1) != 0) ? reference + ((diff + 1) >> 1)
                           : reference - (diff >> 1);
}
110 
// This is called DrlCtxStack in section 7.10.2.14 of the spec.
// The spec adds a bonus weight (larger than any natural weight) to every
// nearest mv, then derives contexts by comparing the weights with that bonus.
// Here that procedure is replaced by |nearest_mv_count| in
// PredictionParameters, the number of nearest mvs. The nearest mvs occupy the
// beginning of the mv stack, so the context of the mv at |index| follows from
// comparing index + 1 with that count:
//   0 while the next mv is still a nearest mv,
//   1 for the last nearest mv,
//   2 afterwards.
int GetRefMvIndexContext(int nearest_mv_count, int index) {
  const int position = index + 1;
  if (position == nearest_mv_count) return 1;
  return (position < nearest_mv_count) ? 0 : 2;
}
128 
129 // Returns true if both the width and height of the block is less than 64.
IsBlockDimensionLessThan64(BlockSize size)130 bool IsBlockDimensionLessThan64(BlockSize size) {
131   return size <= kBlock32x32 && size != kBlock16x64;
132 }
133 
134 }  // namespace
135 
ReadSegmentId(const Block & block)136 bool Tile::ReadSegmentId(const Block& block) {
137   int top_left = -1;
138   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
139     top_left =
140         block_parameters_holder_.Find(block.row4x4 - 1, block.column4x4 - 1)
141             ->segment_id;
142   }
143   int top = -1;
144   if (block.top_available[kPlaneY]) {
145     top = block.bp_top->segment_id;
146   }
147   int left = -1;
148   if (block.left_available[kPlaneY]) {
149     left = block.bp_left->segment_id;
150   }
151   int pred;
152   if (top == -1) {
153     pred = (left == -1) ? 0 : left;
154   } else if (left == -1) {
155     pred = top;
156   } else {
157     pred = (top_left == top) ? top : left;
158   }
159   BlockParameters& bp = *block.bp;
160   if (bp.skip) {
161     bp.segment_id = pred;
162     return true;
163   }
164   int context = 0;
165   if (top_left < 0) {
166     context = 0;
167   } else if (top_left == top && top_left == left) {
168     context = 2;
169   } else if (top_left == top || top_left == left || top == left) {
170     context = 1;
171   }
172   uint16_t* const segment_id_cdf =
173       symbol_decoder_context_.segment_id_cdf[context];
174   const int encoded_segment_id =
175       reader_.ReadSymbol<kMaxSegments>(segment_id_cdf);
176   bp.segment_id =
177       DecodeSegmentId(encoded_segment_id, pred,
178                       frame_header_.segmentation.last_active_segment_id + 1);
179   // Check the bitstream conformance requirement in Section 6.10.8 of the spec.
180   if (bp.segment_id < 0 ||
181       bp.segment_id > frame_header_.segmentation.last_active_segment_id) {
182     LIBGAV1_DLOG(
183         ERROR,
184         "Corrupted segment_ids: encoded %d, last active %d, postprocessed %d",
185         encoded_segment_id, frame_header_.segmentation.last_active_segment_id,
186         bp.segment_id);
187     return false;
188   }
189   return true;
190 }
191 
ReadIntraSegmentId(const Block & block)192 bool Tile::ReadIntraSegmentId(const Block& block) {
193   BlockParameters& bp = *block.bp;
194   if (!frame_header_.segmentation.enabled) {
195     bp.segment_id = 0;
196     return true;
197   }
198   return ReadSegmentId(block);
199 }
200 
ReadSkip(const Block & block)201 void Tile::ReadSkip(const Block& block) {
202   BlockParameters& bp = *block.bp;
203   if (frame_header_.segmentation.segment_id_pre_skip &&
204       frame_header_.segmentation.FeatureActive(bp.segment_id,
205                                                kSegmentFeatureSkip)) {
206     bp.skip = true;
207     return;
208   }
209   int context = 0;
210   if (block.top_available[kPlaneY] && block.bp_top->skip) {
211     ++context;
212   }
213   if (block.left_available[kPlaneY] && block.bp_left->skip) {
214     ++context;
215   }
216   uint16_t* const skip_cdf = symbol_decoder_context_.skip_cdf[context];
217   bp.skip = reader_.ReadSymbol(skip_cdf);
218 }
219 
ReadSkipMode(const Block & block)220 void Tile::ReadSkipMode(const Block& block) {
221   BlockParameters& bp = *block.bp;
222   if (!frame_header_.skip_mode_present ||
223       frame_header_.segmentation.FeatureActive(bp.segment_id,
224                                                kSegmentFeatureSkip) ||
225       frame_header_.segmentation.FeatureActive(bp.segment_id,
226                                                kSegmentFeatureReferenceFrame) ||
227       frame_header_.segmentation.FeatureActive(bp.segment_id,
228                                                kSegmentFeatureGlobalMv) ||
229       IsBlockDimension4(block.size)) {
230     bp.skip_mode = false;
231     return;
232   }
233   const int context =
234       (block.left_available[kPlaneY]
235            ? static_cast<int>(block.bp_left->skip_mode)
236            : 0) +
237       (block.top_available[kPlaneY] ? static_cast<int>(block.bp_top->skip_mode)
238                                     : 0);
239   bp.skip_mode =
240       reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
241 }
242 
ReadCdef(const Block & block)243 void Tile::ReadCdef(const Block& block) {
244   BlockParameters& bp = *block.bp;
245   if (bp.skip || frame_header_.coded_lossless ||
246       !sequence_header_.enable_cdef || frame_header_.allow_intrabc) {
247     return;
248   }
249   const int cdef_size4x4 = kNum4x4BlocksWide[kBlock64x64];
250   const int cdef_mask4x4 = ~(cdef_size4x4 - 1);
251   const int row4x4 = block.row4x4 & cdef_mask4x4;
252   const int column4x4 = block.column4x4 & cdef_mask4x4;
253   const int row = DivideBy16(row4x4);
254   const int column = DivideBy16(column4x4);
255   if (cdef_index_[row][column] == -1) {
256     cdef_index_[row][column] =
257         frame_header_.cdef.bits > 0
258             ? static_cast<int16_t>(reader_.ReadLiteral(frame_header_.cdef.bits))
259             : 0;
260     for (int i = row4x4; i < row4x4 + block.height4x4; i += cdef_size4x4) {
261       for (int j = column4x4; j < column4x4 + block.width4x4;
262            j += cdef_size4x4) {
263         cdef_index_[DivideBy16(i)][DivideBy16(j)] = cdef_index_[row][column];
264       }
265     }
266   }
267 }
268 
ReadAndClipDelta(uint16_t * const cdf,int delta_small,int scale,int min_value,int max_value,int value)269 int Tile::ReadAndClipDelta(uint16_t* const cdf, int delta_small, int scale,
270                            int min_value, int max_value, int value) {
271   int abs = reader_.ReadSymbol<kDeltaSymbolCount>(cdf);
272   if (abs == delta_small) {
273     const int remaining_bit_count =
274         static_cast<int>(reader_.ReadLiteral(3)) + 1;
275     const int abs_remaining_bits =
276         static_cast<int>(reader_.ReadLiteral(remaining_bit_count));
277     abs = abs_remaining_bits + (1 << remaining_bit_count) + 1;
278   }
279   if (abs != 0) {
280     const bool sign = static_cast<bool>(reader_.ReadBit());
281     const int scaled_abs = abs << scale;
282     const int reduced_delta = sign ? -scaled_abs : scaled_abs;
283     value += reduced_delta;
284     value = Clip3(value, min_value, max_value);
285   }
286   return value;
287 }
288 
ReadQuantizerIndexDelta(const Block & block)289 void Tile::ReadQuantizerIndexDelta(const Block& block) {
290   assert(read_deltas_);
291   BlockParameters& bp = *block.bp;
292   if ((block.size == SuperBlockSize() && bp.skip)) {
293     return;
294   }
295   current_quantizer_index_ =
296       ReadAndClipDelta(symbol_decoder_context_.delta_q_cdf, kDeltaQSmall,
297                        frame_header_.delta_q.scale, kMinLossyQuantizer,
298                        kMaxQuantizer, current_quantizer_index_);
299 }
300 
ReadLoopFilterDelta(const Block & block)301 void Tile::ReadLoopFilterDelta(const Block& block) {
302   assert(read_deltas_);
303   BlockParameters& bp = *block.bp;
304   if (!frame_header_.delta_lf.present ||
305       (block.size == SuperBlockSize() && bp.skip)) {
306     return;
307   }
308   int frame_lf_count = 1;
309   if (frame_header_.delta_lf.multi) {
310     frame_lf_count = kFrameLfCount - (PlaneCount() > 1 ? 0 : 2);
311   }
312   bool recompute_deblock_filter_levels = false;
313   for (int i = 0; i < frame_lf_count; ++i) {
314     uint16_t* const delta_lf_abs_cdf =
315         frame_header_.delta_lf.multi
316             ? symbol_decoder_context_.delta_lf_multi_cdf[i]
317             : symbol_decoder_context_.delta_lf_cdf;
318     const int8_t old_delta_lf = delta_lf_[i];
319     delta_lf_[i] = ReadAndClipDelta(
320         delta_lf_abs_cdf, kDeltaLfSmall, frame_header_.delta_lf.scale,
321         -kMaxLoopFilterValue, kMaxLoopFilterValue, delta_lf_[i]);
322     recompute_deblock_filter_levels =
323         recompute_deblock_filter_levels || (old_delta_lf != delta_lf_[i]);
324   }
325   delta_lf_all_zero_ =
326       (delta_lf_[0] | delta_lf_[1] | delta_lf_[2] | delta_lf_[3]) == 0;
327   if (!delta_lf_all_zero_ && recompute_deblock_filter_levels) {
328     post_filter_.ComputeDeblockFilterLevels(delta_lf_, deblock_filter_levels_);
329   }
330 }
331 
ReadPredictionModeY(const Block & block,bool intra_y_mode)332 void Tile::ReadPredictionModeY(const Block& block, bool intra_y_mode) {
333   uint16_t* cdf;
334   if (intra_y_mode) {
335     const PredictionMode top_mode =
336         block.top_available[kPlaneY] ? block.bp_top->y_mode : kPredictionModeDc;
337     const PredictionMode left_mode = block.left_available[kPlaneY]
338                                          ? block.bp_left->y_mode
339                                          : kPredictionModeDc;
340     const int top_context = kIntraYModeContext[top_mode];
341     const int left_context = kIntraYModeContext[left_mode];
342     cdf = symbol_decoder_context_
343               .intra_frame_y_mode_cdf[top_context][left_context];
344   } else {
345     cdf = symbol_decoder_context_.y_mode_cdf[kSizeGroup[block.size]];
346   }
347   block.bp->y_mode = static_cast<PredictionMode>(
348       reader_.ReadSymbol<kIntraPredictionModesY>(cdf));
349 }
350 
ReadIntraAngleInfo(const Block & block,PlaneType plane_type)351 void Tile::ReadIntraAngleInfo(const Block& block, PlaneType plane_type) {
352   BlockParameters& bp = *block.bp;
353   PredictionParameters& prediction_parameters =
354       *block.bp->prediction_parameters;
355   prediction_parameters.angle_delta[plane_type] = 0;
356   const PredictionMode mode =
357       (plane_type == kPlaneTypeY) ? bp.y_mode : bp.uv_mode;
358   if (IsBlockSmallerThan8x8(block.size) || !IsDirectionalMode(mode)) return;
359   uint16_t* const cdf =
360       symbol_decoder_context_.angle_delta_cdf[mode - kPredictionModeVertical];
361   prediction_parameters.angle_delta[plane_type] =
362       reader_.ReadSymbol<kAngleDeltaSymbolCount>(cdf);
363   prediction_parameters.angle_delta[plane_type] -= kMaxAngleDelta;
364 }
365 
ReadCflAlpha(const Block & block)366 void Tile::ReadCflAlpha(const Block& block) {
367   const int signs = reader_.ReadSymbol<kCflAlphaSignsSymbolCount>(
368       symbol_decoder_context_.cfl_alpha_signs_cdf);
369   const int8_t* const cfl_lookup = kCflAlphaLookup[signs];
370   const auto sign_u = static_cast<CflSign>(cfl_lookup[0]);
371   const auto sign_v = static_cast<CflSign>(cfl_lookup[1]);
372   PredictionParameters& prediction_parameters =
373       *block.bp->prediction_parameters;
374   prediction_parameters.cfl_alpha_u = 0;
375   if (sign_u != kCflSignZero) {
376     assert(cfl_lookup[2] >= 0);
377     prediction_parameters.cfl_alpha_u =
378         reader_.ReadSymbol<kCflAlphaSymbolCount>(
379             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[2]]) +
380         1;
381     if (sign_u == kCflSignNegative) prediction_parameters.cfl_alpha_u *= -1;
382   }
383   prediction_parameters.cfl_alpha_v = 0;
384   if (sign_v != kCflSignZero) {
385     assert(cfl_lookup[3] >= 0);
386     prediction_parameters.cfl_alpha_v =
387         reader_.ReadSymbol<kCflAlphaSymbolCount>(
388             symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[3]]) +
389         1;
390     if (sign_v == kCflSignNegative) prediction_parameters.cfl_alpha_v *= -1;
391   }
392 }
393 
ReadPredictionModeUV(const Block & block)394 void Tile::ReadPredictionModeUV(const Block& block) {
395   BlockParameters& bp = *block.bp;
396   bool chroma_from_luma_allowed;
397   if (frame_header_.segmentation.lossless[bp.segment_id]) {
398     chroma_from_luma_allowed = block.residual_size[kPlaneU] == kBlock4x4;
399   } else {
400     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
401   }
402   uint16_t* const cdf =
403       symbol_decoder_context_
404           .uv_mode_cdf[static_cast<int>(chroma_from_luma_allowed)][bp.y_mode];
405   if (chroma_from_luma_allowed) {
406     bp.uv_mode = static_cast<PredictionMode>(
407         reader_.ReadSymbol<kIntraPredictionModesUV>(cdf));
408   } else {
409     bp.uv_mode = static_cast<PredictionMode>(
410         reader_.ReadSymbol<kIntraPredictionModesUV - 1>(cdf));
411   }
412 }
413 
// Reads one component of a motion vector difference.
//
// |component| selects the cdf set; ReadMotionVector() passes 0 and 1. The cdf
// context is 1 when the block uses intra block copy and 0 otherwise.
//
// Returns the signed component. The magnitude is assembled as:
//   1 + class offset + ((integer bits << 3) | (fraction << 1) | precision)
// so the value is presumably in 1/8-pel units -- confirm against the callers'
// motion vector conventions.
int Tile::ReadMotionVectorComponent(const Block& block, const int component) {
  const int context =
      static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
  // The read order (sign, class, integer bits, fraction, precision) is the
  // bitstream order.
  const bool sign = reader_.ReadSymbol(
      symbol_decoder_context_.mv_sign_cdf[component][context]);
  const int mv_class = reader_.ReadSymbol<kMvClassSymbolCount>(
      symbol_decoder_context_.mv_class_cdf[component][context]);
  int magnitude = 1;
  int value;
  uint16_t* fraction_cdf;
  uint16_t* precision_cdf;
  if (mv_class == 0) {
    // Class 0: a single integer bit, which also selects the fraction cdf.
    value = static_cast<int>(reader_.ReadSymbol(
        symbol_decoder_context_.mv_class0_bit_cdf[component][context]));
    fraction_cdf = symbol_decoder_context_
                       .mv_class0_fraction_cdf[component][context][value];
    precision_cdf = symbol_decoder_context_
                        .mv_class0_high_precision_cdf[component][context];
  } else {
    // Other classes: |mv_class| integer bits, least significant first, each
    // with its own cdf.
    assert(mv_class <= kMvBitSymbolCount);
    value = 0;
    for (int i = 0; i < mv_class; ++i) {
      const int bit = static_cast<int>(reader_.ReadSymbol(
          symbol_decoder_context_.mv_bit_cdf[component][context][i]));
      value |= bit << i;
    }
    // Base offset for this class.
    magnitude += 2 << (mv_class + 2);
    fraction_cdf = symbol_decoder_context_.mv_fraction_cdf[component][context];
    precision_cdf =
        symbol_decoder_context_.mv_high_precision_cdf[component][context];
  }
  // With force_integer_mv the fraction is not coded and defaults to 3.
  const int fraction =
      (frame_header_.force_integer_mv == 0)
          ? reader_.ReadSymbol<kMvFractionSymbolCount>(fraction_cdf)
          : 3;
  // Without allow_high_precision_mv the precision bit is not coded and
  // defaults to 1.
  const int precision =
      frame_header_.allow_high_precision_mv
          ? static_cast<int>(reader_.ReadSymbol(precision_cdf))
          : 1;
  magnitude += (value << 3) | (fraction << 1) | precision;
  return sign ? -magnitude : magnitude;
}
456 
ReadMotionVector(const Block & block,int index)457 void Tile::ReadMotionVector(const Block& block, int index) {
458   BlockParameters& bp = *block.bp;
459   const int context =
460       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
461   const auto mv_joint = static_cast<MvJointType>(
462       reader_.ReadSymbol(symbol_decoder_context_.mv_joint_cdf[context],
463                          static_cast<int>(kNumMvJointTypes)));
464   if (mv_joint == kMvJointTypeHorizontalZeroVerticalNonZero ||
465       mv_joint == kMvJointTypeNonZero) {
466     bp.mv.mv[index].mv[0] = ReadMotionVectorComponent(block, 0);
467   }
468   if (mv_joint == kMvJointTypeHorizontalNonZeroVerticalZero ||
469       mv_joint == kMvJointTypeNonZero) {
470     bp.mv.mv[index].mv[1] = ReadMotionVectorComponent(block, 1);
471   }
472 }
473 
ReadFilterIntraModeInfo(const Block & block)474 void Tile::ReadFilterIntraModeInfo(const Block& block) {
475   BlockParameters& bp = *block.bp;
476   PredictionParameters& prediction_parameters =
477       *block.bp->prediction_parameters;
478   prediction_parameters.use_filter_intra = false;
479   if (!sequence_header_.enable_filter_intra || bp.y_mode != kPredictionModeDc ||
480       bp.palette_mode_info.size[kPlaneTypeY] != 0 ||
481       !IsBlockDimensionLessThan64(block.size)) {
482     return;
483   }
484   prediction_parameters.use_filter_intra = reader_.ReadSymbol(
485       symbol_decoder_context_.use_filter_intra_cdf[block.size]);
486   if (prediction_parameters.use_filter_intra) {
487     prediction_parameters.filter_intra_mode = static_cast<FilterIntraPredictor>(
488         reader_.ReadSymbol<kNumFilterIntraPredictors>(
489             symbol_decoder_context_.filter_intra_mode_cdf));
490   }
491 }
492 
// Decodes the mode info of a block in an intra frame.
//
// The statement order below is the bitstream order and must not change.
// Returns false when reading the segment id fails. For intra block copy
// blocks it returns the result of AssignIntraMv(); otherwise the result of
// ReadIntraBlockModeInfo().
bool Tile::DecodeIntraModeInfo(const Block& block) {
  BlockParameters& bp = *block.bp;
  // ReadSegmentId() consults bp.skip, so clear it before a pre-skip read.
  bp.skip = false;
  // The segment id is read either before or after the skip flag, depending
  // on segment_id_pre_skip.
  if (frame_header_.segmentation.segment_id_pre_skip &&
      !ReadIntraSegmentId(block)) {
    return false;
  }
  bp.skip_mode = false;
  ReadSkip(block);
  if (!frame_header_.segmentation.segment_id_pre_skip &&
      !ReadIntraSegmentId(block)) {
    return false;
  }
  ReadCdef(block);
  // The quantizer and loop filter deltas are read at most once while
  // read_deltas_ is set. It is presumably re-armed at each superblock
  // elsewhere -- confirm.
  if (read_deltas_) {
    ReadQuantizerIndexDelta(block);
    ReadLoopFilterDelta(block);
    read_deltas_ = false;
  }
  PredictionParameters& prediction_parameters =
      *block.bp->prediction_parameters;
  prediction_parameters.use_intra_block_copy = false;
  if (frame_header_.allow_intrabc) {
    prediction_parameters.use_intra_block_copy =
        reader_.ReadSymbol(symbol_decoder_context_.intra_block_copy_cdf);
  }
  if (prediction_parameters.use_intra_block_copy) {
    // Intra block copy is coded as an inter block that uses
    // kReferenceFrameIntra as its only reference:
    //   - DC modes, simple motion, average compound;
    //   - no palettes, bilinear filtering.
    bp.is_inter = true;
    bp.reference_frame[0] = kReferenceFrameIntra;
    bp.reference_frame[1] = kReferenceFrameNone;
    bp.y_mode = kPredictionModeDc;
    bp.uv_mode = kPredictionModeDc;
    prediction_parameters.motion_mode = kMotionModeSimple;
    prediction_parameters.compound_prediction_type =
        kCompoundPredictionTypeAverage;
    bp.palette_mode_info.size[kPlaneTypeY] = 0;
    bp.palette_mode_info.size[kPlaneTypeUV] = 0;
    bp.interpolation_filter[0] = kInterpolationFilterBilinear;
    bp.interpolation_filter[1] = kInterpolationFilterBilinear;
    // FindMvStack() requires an output for mode contexts. They are not used
    // afterwards here, hence the dummy.
    MvContexts dummy_mode_contexts;
    FindMvStack(block, /*is_compound=*/false, &dummy_mode_contexts);
    return AssignIntraMv(block);
  }
  bp.is_inter = false;
  return ReadIntraBlockModeInfo(block, /*intra_y_mode=*/true);
}
539 
ComputePredictedSegmentId(const Block & block) const540 int8_t Tile::ComputePredictedSegmentId(const Block& block) const {
541   // If prev_segment_ids_ is null, treat it as if it pointed to a segmentation
542   // map containing all 0s.
543   if (prev_segment_ids_ == nullptr) return 0;
544 
545   const int x_limit = std::min(frame_header_.columns4x4 - block.column4x4,
546                                static_cast<int>(block.width4x4));
547   const int y_limit = std::min(frame_header_.rows4x4 - block.row4x4,
548                                static_cast<int>(block.height4x4));
549   int8_t id = 7;
550   for (int y = 0; y < y_limit; ++y) {
551     for (int x = 0; x < x_limit; ++x) {
552       const int8_t prev_segment_id =
553           prev_segment_ids_->segment_id(block.row4x4 + y, block.column4x4 + x);
554       id = std::min(id, prev_segment_id);
555     }
556   }
557   return id;
558 }
559 
ReadInterSegmentId(const Block & block,bool pre_skip)560 bool Tile::ReadInterSegmentId(const Block& block, bool pre_skip) {
561   BlockParameters& bp = *block.bp;
562   if (!frame_header_.segmentation.enabled) {
563     bp.segment_id = 0;
564     return true;
565   }
566   if (!frame_header_.segmentation.update_map) {
567     bp.segment_id = ComputePredictedSegmentId(block);
568     return true;
569   }
570   if (pre_skip) {
571     if (!frame_header_.segmentation.segment_id_pre_skip) {
572       bp.segment_id = 0;
573       return true;
574     }
575   } else if (bp.skip) {
576     bp.use_predicted_segment_id = false;
577     return ReadSegmentId(block);
578   }
579   if (frame_header_.segmentation.temporal_update) {
580     const int context =
581         (block.left_available[kPlaneY]
582              ? static_cast<int>(block.bp_left->use_predicted_segment_id)
583              : 0) +
584         (block.top_available[kPlaneY]
585              ? static_cast<int>(block.bp_top->use_predicted_segment_id)
586              : 0);
587     bp.use_predicted_segment_id = reader_.ReadSymbol(
588         symbol_decoder_context_.use_predicted_segment_id_cdf[context]);
589     if (bp.use_predicted_segment_id) {
590       bp.segment_id = ComputePredictedSegmentId(block);
591       return true;
592     }
593   }
594   return ReadSegmentId(block);
595 }
596 
ReadIsInter(const Block & block)597 void Tile::ReadIsInter(const Block& block) {
598   BlockParameters& bp = *block.bp;
599   if (bp.skip_mode) {
600     bp.is_inter = true;
601     return;
602   }
603   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
604                                                kSegmentFeatureReferenceFrame)) {
605     bp.is_inter =
606         frame_header_.segmentation
607             .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame] !=
608         kReferenceFrameIntra;
609     return;
610   }
611   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
612                                                kSegmentFeatureGlobalMv)) {
613     bp.is_inter = true;
614     return;
615   }
616   int context = 0;
617   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
618     context = (block.IsTopIntra() && block.IsLeftIntra())
619                   ? 3
620                   : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
621   } else if (block.top_available[kPlaneY] || block.left_available[kPlaneY]) {
622     context = 2 * static_cast<int>(block.top_available[kPlaneY]
623                                        ? block.IsTopIntra()
624                                        : block.IsLeftIntra());
625   }
626   bp.is_inter =
627       reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
628 }
629 
// Reads the intra prediction mode info of |block|. The calls below follow the
// bitstream order and must not be reordered.
//
// |intra_y_mode| is forwarded to ReadPredictionModeY(), which uses it to
// choose between the neighbor-based and the size-group luma cdfs.
// Always returns true.
bool Tile::ReadIntraBlockModeInfo(const Block& block, bool intra_y_mode) {
  BlockParameters& bp = *block.bp;
  bp.reference_frame[0] = kReferenceFrameIntra;
  bp.reference_frame[1] = kReferenceFrameNone;
  // Luma mode and its angle delta.
  ReadPredictionModeY(block, intra_y_mode);
  ReadIntraAngleInfo(block, kPlaneTypeY);
  // Chroma mode, CfL alphas (CfL mode only) and chroma angle delta, only for
  // blocks that carry chroma.
  if (block.HasChroma()) {
    ReadPredictionModeUV(block);
    if (bp.uv_mode == kPredictionModeChromaFromLuma) {
      ReadCflAlpha(block);
    }
    ReadIntraAngleInfo(block, kPlaneTypeUV);
  }
  // Palette info must precede filter intra: ReadFilterIntraModeInfo() checks
  // the luma palette size.
  ReadPaletteModeInfo(block);
  ReadFilterIntraModeInfo(block);
  return true;
}
647 
GetUseCompoundReferenceContext(const Block & block)648 int Tile::GetUseCompoundReferenceContext(const Block& block) {
649   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
650     if (block.IsTopSingle() && block.IsLeftSingle()) {
651       return static_cast<int>(IsBackwardReference(block.TopReference(0))) ^
652              static_cast<int>(IsBackwardReference(block.LeftReference(0)));
653     }
654     if (block.IsTopSingle()) {
655       return 2 + static_cast<int>(IsBackwardReference(block.TopReference(0)) ||
656                                   block.IsTopIntra());
657     }
658     if (block.IsLeftSingle()) {
659       return 2 + static_cast<int>(IsBackwardReference(block.LeftReference(0)) ||
660                                   block.IsLeftIntra());
661     }
662     return 4;
663   }
664   if (block.top_available[kPlaneY]) {
665     return block.IsTopSingle()
666                ? static_cast<int>(IsBackwardReference(block.TopReference(0)))
667                : 3;
668   }
669   if (block.left_available[kPlaneY]) {
670     return block.IsLeftSingle()
671                ? static_cast<int>(IsBackwardReference(block.LeftReference(0)))
672                : 3;
673   }
674   return 1;
675 }
676 
// Reads whether the compound prediction of |block| uses a unidirectional or a
// bidirectional reference pair. The cdf context depends on the above and left
// neighbors: availability, whether they are intra, whether they are compound,
// and whether their compound references lie in the same direction.
CompoundReferenceType Tile::ReadCompoundReferenceType(const Block& block) {
  // compound and inter.
  const bool top_comp_inter = block.top_available[kPlaneY] &&
                              !block.IsTopIntra() && !block.IsTopSingle();
  const bool left_comp_inter = block.left_available[kPlaneY] &&
                               !block.IsLeftIntra() && !block.IsLeftSingle();
  // unidirectional compound.
  const bool top_uni_comp =
      top_comp_inter && IsSameDirectionReferencePair(block.TopReference(0),
                                                     block.TopReference(1));
  const bool left_uni_comp =
      left_comp_inter && IsSameDirectionReferencePair(block.LeftReference(0),
                                                      block.LeftReference(1));
  int context;
  if (block.top_available[kPlaneY] && !block.IsTopIntra() &&
      block.left_available[kPlaneY] && !block.IsLeftIntra()) {
    // Both neighbors are available inter blocks. |same_direction| compares
    // their first references.
    const int same_direction = static_cast<int>(IsSameDirectionReferencePair(
        block.TopReference(0), block.LeftReference(0)));
    if (!top_comp_inter && !left_comp_inter) {
      context = 1 + MultiplyBy2(same_direction);
    } else if (!top_comp_inter) {
      context = left_uni_comp ? 3 + same_direction : 1;
    } else if (!left_comp_inter) {
      context = top_uni_comp ? 3 + same_direction : 1;
    } else {
      // Both neighbors are compound.
      if (!top_uni_comp && !left_uni_comp) {
        context = 0;
      } else if (!top_uni_comp || !left_uni_comp) {
        context = 2;
      } else {
        // Both are unidirectional: compare whether each starts with
        // kReferenceFrameBackward.
        context = 3 + static_cast<int>(
                          (block.TopReference(0) == kReferenceFrameBackward) ==
                          (block.LeftReference(0) == kReferenceFrameBackward));
      }
    }
  } else if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
    // Both neighbors are available and at least one of them is intra.
    if (top_comp_inter) {
      context = 1 + MultiplyBy2(static_cast<int>(top_uni_comp));
    } else if (left_comp_inter) {
      context = 1 + MultiplyBy2(static_cast<int>(left_uni_comp));
    } else {
      context = 2;
    }
  } else if (top_comp_inter) {
    // At most one neighbor is available.
    context = MultiplyBy4(static_cast<int>(top_uni_comp));
  } else if (left_comp_inter) {
    context = MultiplyBy4(static_cast<int>(left_uni_comp));
  } else {
    context = 2;
  }
  return static_cast<CompoundReferenceType>(reader_.ReadSymbol(
      symbol_decoder_context_.compound_reference_type_cdf[context]));
}
730 
GetReferenceContext(const Block & block,ReferenceFrameType type0_start,ReferenceFrameType type0_end,ReferenceFrameType type1_start,ReferenceFrameType type1_end) const731 int Tile::GetReferenceContext(const Block& block,
732                               ReferenceFrameType type0_start,
733                               ReferenceFrameType type0_end,
734                               ReferenceFrameType type1_start,
735                               ReferenceFrameType type1_end) const {
736   int count0 = 0;
737   int count1 = 0;
738   for (int type = type0_start; type <= type0_end; ++type) {
739     count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
740   }
741   for (int type = type1_start; type <= type1_end; ++type) {
742     count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
743   }
744   return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
745 }
746 
747 template <bool is_single, bool is_backward, int index>
GetReferenceCdf(const Block & block,CompoundReferenceType type)748 uint16_t* Tile::GetReferenceCdf(
749     const Block& block,
750     CompoundReferenceType type /*= kNumCompoundReferenceTypes*/) {
751   int context = 0;
752   if ((type == kCompoundReferenceUnidirectional && index == 0) ||
753       (is_single && index == 1)) {
754     // uni_comp_ref and single_ref_p1.
755     context =
756         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameGolden,
757                             kReferenceFrameBackward, kReferenceFrameAlternate);
758   } else if (type == kCompoundReferenceUnidirectional && index == 1) {
759     // uni_comp_ref_p1.
760     context =
761         GetReferenceContext(block, kReferenceFrameLast2, kReferenceFrameLast2,
762                             kReferenceFrameLast3, kReferenceFrameGolden);
763   } else if ((type == kCompoundReferenceUnidirectional && index == 2) ||
764              (type == kCompoundReferenceBidirectional && index == 2) ||
765              (is_single && index == 5)) {
766     // uni_comp_ref_p2, comp_ref_p2 and single_ref_p5.
767     context =
768         GetReferenceContext(block, kReferenceFrameLast3, kReferenceFrameLast3,
769                             kReferenceFrameGolden, kReferenceFrameGolden);
770   } else if ((type == kCompoundReferenceBidirectional && index == 0) ||
771              (is_single && index == 3)) {
772     // comp_ref and single_ref_p3.
773     context =
774         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast2,
775                             kReferenceFrameLast3, kReferenceFrameGolden);
776   } else if ((type == kCompoundReferenceBidirectional && index == 1) ||
777              (is_single && index == 4)) {
778     // comp_ref_p1 and single_ref_p4.
779     context =
780         GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast,
781                             kReferenceFrameLast2, kReferenceFrameLast2);
782   } else if ((is_single && index == 2) || (is_backward && index == 0)) {
783     // single_ref_p2 and comp_bwdref.
784     context = GetReferenceContext(
785         block, kReferenceFrameBackward, kReferenceFrameAlternate2,
786         kReferenceFrameAlternate, kReferenceFrameAlternate);
787   } else if ((is_single && index == 6) || (is_backward && index == 1)) {
788     // single_ref_p6 and comp_bwdref_p1.
789     context = GetReferenceContext(
790         block, kReferenceFrameBackward, kReferenceFrameBackward,
791         kReferenceFrameAlternate2, kReferenceFrameAlternate2);
792   }
793   if (is_single) {
794     // The index parameter for single references is offset by one since the spec
795     // uses 1-based index for these elements.
796     return symbol_decoder_context_.single_reference_cdf[context][index - 1];
797   }
798   if (is_backward) {
799     return symbol_decoder_context_
800         .compound_backward_reference_cdf[context][index];
801   }
802   return symbol_decoder_context_.compound_reference_cdf[type][context][index];
803 }
804 
ReadReferenceFrames(const Block & block)805 void Tile::ReadReferenceFrames(const Block& block) {
806   BlockParameters& bp = *block.bp;
807   if (bp.skip_mode) {
808     bp.reference_frame[0] = frame_header_.skip_mode_frame[0];
809     bp.reference_frame[1] = frame_header_.skip_mode_frame[1];
810     return;
811   }
812   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
813                                                kSegmentFeatureReferenceFrame)) {
814     bp.reference_frame[0] = static_cast<ReferenceFrameType>(
815         frame_header_.segmentation
816             .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame]);
817     bp.reference_frame[1] = kReferenceFrameNone;
818     return;
819   }
820   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
821                                                kSegmentFeatureSkip) ||
822       frame_header_.segmentation.FeatureActive(bp.segment_id,
823                                                kSegmentFeatureGlobalMv)) {
824     bp.reference_frame[0] = kReferenceFrameLast;
825     bp.reference_frame[1] = kReferenceFrameNone;
826     return;
827   }
828   const bool use_compound_reference =
829       frame_header_.reference_mode_select &&
830       std::min(block.width4x4, block.height4x4) >= 2 &&
831       reader_.ReadSymbol(symbol_decoder_context_.use_compound_reference_cdf
832                              [GetUseCompoundReferenceContext(block)]);
833   if (use_compound_reference) {
834     CompoundReferenceType reference_type = ReadCompoundReferenceType(block);
835     if (reference_type == kCompoundReferenceUnidirectional) {
836       // uni_comp_ref.
837       if (reader_.ReadSymbol(
838               GetReferenceCdf<false, false, 0>(block, reference_type))) {
839         bp.reference_frame[0] = kReferenceFrameBackward;
840         bp.reference_frame[1] = kReferenceFrameAlternate;
841         return;
842       }
843       // uni_comp_ref_p1.
844       if (!reader_.ReadSymbol(
845               GetReferenceCdf<false, false, 1>(block, reference_type))) {
846         bp.reference_frame[0] = kReferenceFrameLast;
847         bp.reference_frame[1] = kReferenceFrameLast2;
848         return;
849       }
850       // uni_comp_ref_p2.
851       if (reader_.ReadSymbol(
852               GetReferenceCdf<false, false, 2>(block, reference_type))) {
853         bp.reference_frame[0] = kReferenceFrameLast;
854         bp.reference_frame[1] = kReferenceFrameGolden;
855         return;
856       }
857       bp.reference_frame[0] = kReferenceFrameLast;
858       bp.reference_frame[1] = kReferenceFrameLast3;
859       return;
860     }
861     assert(reference_type == kCompoundReferenceBidirectional);
862     // comp_ref.
863     if (reader_.ReadSymbol(
864             GetReferenceCdf<false, false, 0>(block, reference_type))) {
865       // comp_ref_p2.
866       bp.reference_frame[0] =
867           reader_.ReadSymbol(
868               GetReferenceCdf<false, false, 2>(block, reference_type))
869               ? kReferenceFrameGolden
870               : kReferenceFrameLast3;
871     } else {
872       // comp_ref_p1.
873       bp.reference_frame[0] =
874           reader_.ReadSymbol(
875               GetReferenceCdf<false, false, 1>(block, reference_type))
876               ? kReferenceFrameLast2
877               : kReferenceFrameLast;
878     }
879     // comp_bwdref.
880     if (reader_.ReadSymbol(GetReferenceCdf<false, true, 0>(block))) {
881       bp.reference_frame[1] = kReferenceFrameAlternate;
882     } else {
883       // comp_bwdref_p1.
884       bp.reference_frame[1] =
885           reader_.ReadSymbol(GetReferenceCdf<false, true, 1>(block))
886               ? kReferenceFrameAlternate2
887               : kReferenceFrameBackward;
888     }
889     return;
890   }
891   assert(!use_compound_reference);
892   bp.reference_frame[1] = kReferenceFrameNone;
893   // single_ref_p1.
894   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 1>(block))) {
895     // single_ref_p2.
896     if (reader_.ReadSymbol(GetReferenceCdf<true, false, 2>(block))) {
897       bp.reference_frame[0] = kReferenceFrameAlternate;
898       return;
899     }
900     // single_ref_p6.
901     bp.reference_frame[0] =
902         reader_.ReadSymbol(GetReferenceCdf<true, false, 6>(block))
903             ? kReferenceFrameAlternate2
904             : kReferenceFrameBackward;
905     return;
906   }
907   // single_ref_p3.
908   if (reader_.ReadSymbol(GetReferenceCdf<true, false, 3>(block))) {
909     // single_ref_p5.
910     bp.reference_frame[0] =
911         reader_.ReadSymbol(GetReferenceCdf<true, false, 5>(block))
912             ? kReferenceFrameGolden
913             : kReferenceFrameLast3;
914     return;
915   }
916   // single_ref_p4.
917   bp.reference_frame[0] =
918       reader_.ReadSymbol(GetReferenceCdf<true, false, 4>(block))
919           ? kReferenceFrameLast2
920           : kReferenceFrameLast;
921 }
922 
ReadInterPredictionModeY(const Block & block,const MvContexts & mode_contexts)923 void Tile::ReadInterPredictionModeY(const Block& block,
924                                     const MvContexts& mode_contexts) {
925   BlockParameters& bp = *block.bp;
926   if (bp.skip_mode) {
927     bp.y_mode = kPredictionModeNearestNearestMv;
928     return;
929   }
930   if (frame_header_.segmentation.FeatureActive(bp.segment_id,
931                                                kSegmentFeatureSkip) ||
932       frame_header_.segmentation.FeatureActive(bp.segment_id,
933                                                kSegmentFeatureGlobalMv)) {
934     bp.y_mode = kPredictionModeGlobalMv;
935     return;
936   }
937   if (bp.reference_frame[1] > kReferenceFrameIntra) {
938     const int idx0 = mode_contexts.reference_mv >> 1;
939     const int idx1 =
940         std::min(mode_contexts.new_mv, kCompoundModeNewMvContexts - 1);
941     const int context = kCompoundModeContextMap[idx0][idx1];
942     const int offset = reader_.ReadSymbol<kNumCompoundInterPredictionModes>(
943         symbol_decoder_context_.compound_prediction_mode_cdf[context]);
944     bp.y_mode =
945         static_cast<PredictionMode>(kPredictionModeNearestNearestMv + offset);
946     return;
947   }
948   // new_mv.
949   if (!reader_.ReadSymbol(
950           symbol_decoder_context_.new_mv_cdf[mode_contexts.new_mv])) {
951     bp.y_mode = kPredictionModeNewMv;
952     return;
953   }
954   // zero_mv.
955   if (!reader_.ReadSymbol(
956           symbol_decoder_context_.zero_mv_cdf[mode_contexts.zero_mv])) {
957     bp.y_mode = kPredictionModeGlobalMv;
958     return;
959   }
960   // ref_mv.
961   bp.y_mode =
962       reader_.ReadSymbol(
963           symbol_decoder_context_.reference_mv_cdf[mode_contexts.reference_mv])
964           ? kPredictionModeNearMv
965           : kPredictionModeNearestMv;
966 }
967 
ReadRefMvIndex(const Block & block)968 void Tile::ReadRefMvIndex(const Block& block) {
969   BlockParameters& bp = *block.bp;
970   PredictionParameters& prediction_parameters =
971       *block.bp->prediction_parameters;
972   prediction_parameters.ref_mv_index = 0;
973   if (bp.y_mode != kPredictionModeNewMv &&
974       bp.y_mode != kPredictionModeNewNewMv &&
975       !kPredictionModeHasNearMvMask.Contains(bp.y_mode)) {
976     return;
977   }
978   const int start =
979       static_cast<int>(kPredictionModeHasNearMvMask.Contains(bp.y_mode));
980   prediction_parameters.ref_mv_index = start;
981   for (int i = start; i < start + 2; ++i) {
982     if (prediction_parameters.ref_mv_count <= i + 1) break;
983     // drl_mode in the spec.
984     const bool ref_mv_index_bit = reader_.ReadSymbol(
985         symbol_decoder_context_.ref_mv_index_cdf[GetRefMvIndexContext(
986             prediction_parameters.nearest_mv_count, i)]);
987     prediction_parameters.ref_mv_index = i + static_cast<int>(ref_mv_index_bit);
988     if (!ref_mv_index_bit) return;
989   }
990 }
991 
ReadInterIntraMode(const Block & block,bool is_compound)992 void Tile::ReadInterIntraMode(const Block& block, bool is_compound) {
993   BlockParameters& bp = *block.bp;
994   PredictionParameters& prediction_parameters =
995       *block.bp->prediction_parameters;
996   prediction_parameters.inter_intra_mode = kNumInterIntraModes;
997   prediction_parameters.is_wedge_inter_intra = false;
998   if (bp.skip_mode || !sequence_header_.enable_interintra_compound ||
999       is_compound || !kIsInterIntraModeAllowedMask.Contains(block.size)) {
1000     return;
1001   }
1002   // kSizeGroup[block.size] is guaranteed to be non-zero because of the block
1003   // size constraint enforced in the above condition.
1004   assert(kSizeGroup[block.size] - 1 >= 0);
1005   if (!reader_.ReadSymbol(
1006           symbol_decoder_context_
1007               .is_inter_intra_cdf[kSizeGroup[block.size] - 1])) {
1008     prediction_parameters.inter_intra_mode = kNumInterIntraModes;
1009     return;
1010   }
1011   prediction_parameters.inter_intra_mode =
1012       static_cast<InterIntraMode>(reader_.ReadSymbol<kNumInterIntraModes>(
1013           symbol_decoder_context_
1014               .inter_intra_mode_cdf[kSizeGroup[block.size] - 1]));
1015   bp.reference_frame[1] = kReferenceFrameIntra;
1016   prediction_parameters.angle_delta[kPlaneTypeY] = 0;
1017   prediction_parameters.angle_delta[kPlaneTypeUV] = 0;
1018   prediction_parameters.use_filter_intra = false;
1019   prediction_parameters.is_wedge_inter_intra = reader_.ReadSymbol(
1020       symbol_decoder_context_.is_wedge_inter_intra_cdf[block.size]);
1021   if (!prediction_parameters.is_wedge_inter_intra) return;
1022   prediction_parameters.wedge_index =
1023       reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1024           symbol_decoder_context_.wedge_index_cdf[block.size]);
1025   prediction_parameters.wedge_sign = 0;
1026 }
1027 
IsScaled(ReferenceFrameType type) const1028 bool Tile::IsScaled(ReferenceFrameType type) const {
1029   const int index =
1030       frame_header_.reference_frame_index[type - kReferenceFrameLast];
1031   const int x_scale = ((reference_frames_[index]->upscaled_width()
1032                         << kReferenceFrameScalePrecision) +
1033                        DivideBy2(frame_header_.width)) /
1034                       frame_header_.width;
1035   if (x_scale != kNoScale) return true;
1036   const int y_scale = ((reference_frames_[index]->frame_height()
1037                         << kReferenceFrameScalePrecision) +
1038                        DivideBy2(frame_header_.height)) /
1039                       frame_header_.height;
1040   return y_scale != kNoScale;
1041 }
1042 
ReadMotionMode(const Block & block,bool is_compound)1043 void Tile::ReadMotionMode(const Block& block, bool is_compound) {
1044   BlockParameters& bp = *block.bp;
1045   PredictionParameters& prediction_parameters =
1046       *block.bp->prediction_parameters;
1047   const auto global_motion_type =
1048       frame_header_.global_motion[bp.reference_frame[0]].type;
1049   if (bp.skip_mode || !frame_header_.is_motion_mode_switchable ||
1050       IsBlockDimension4(block.size) ||
1051       (frame_header_.force_integer_mv == 0 &&
1052        (bp.y_mode == kPredictionModeGlobalMv ||
1053         bp.y_mode == kPredictionModeGlobalGlobalMv) &&
1054        global_motion_type > kGlobalMotionTransformationTypeTranslation) ||
1055       is_compound || bp.reference_frame[1] == kReferenceFrameIntra ||
1056       !block.HasOverlappableCandidates()) {
1057     prediction_parameters.motion_mode = kMotionModeSimple;
1058     return;
1059   }
1060   prediction_parameters.num_warp_samples = 0;
1061   int num_samples_scanned = 0;
1062   memset(prediction_parameters.warp_estimate_candidates, 0,
1063          sizeof(prediction_parameters.warp_estimate_candidates));
1064   FindWarpSamples(block, &prediction_parameters.num_warp_samples,
1065                   &num_samples_scanned,
1066                   prediction_parameters.warp_estimate_candidates);
1067   if (frame_header_.force_integer_mv != 0 ||
1068       prediction_parameters.num_warp_samples == 0 ||
1069       !frame_header_.allow_warped_motion || IsScaled(bp.reference_frame[0])) {
1070     prediction_parameters.motion_mode =
1071         reader_.ReadSymbol(symbol_decoder_context_.use_obmc_cdf[block.size])
1072             ? kMotionModeObmc
1073             : kMotionModeSimple;
1074     return;
1075   }
1076   prediction_parameters.motion_mode =
1077       static_cast<MotionMode>(reader_.ReadSymbol<kNumMotionModes>(
1078           symbol_decoder_context_.motion_mode_cdf[block.size]));
1079 }
1080 
GetIsExplicitCompoundTypeCdf(const Block & block)1081 uint16_t* Tile::GetIsExplicitCompoundTypeCdf(const Block& block) {
1082   int context = 0;
1083   if (block.top_available[kPlaneY]) {
1084     if (!block.IsTopSingle()) {
1085       context += static_cast<int>(block.bp_top->is_explicit_compound_type);
1086     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1087       context += 3;
1088     }
1089   }
1090   if (block.left_available[kPlaneY]) {
1091     if (!block.IsLeftSingle()) {
1092       context += static_cast<int>(block.bp_left->is_explicit_compound_type);
1093     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1094       context += 3;
1095     }
1096   }
1097   return symbol_decoder_context_.is_explicit_compound_type_cdf[std::min(
1098       context, kIsExplicitCompoundTypeContexts - 1)];
1099 }
1100 
GetIsCompoundTypeAverageCdf(const Block & block)1101 uint16_t* Tile::GetIsCompoundTypeAverageCdf(const Block& block) {
1102   const BlockParameters& bp = *block.bp;
1103   const int forward = std::abs(GetRelativeDistance(
1104       current_frame_.order_hint(bp.reference_frame[0]),
1105       frame_header_.order_hint, sequence_header_.order_hint_shift_bits));
1106   const int backward = std::abs(GetRelativeDistance(
1107       current_frame_.order_hint(bp.reference_frame[1]),
1108       frame_header_.order_hint, sequence_header_.order_hint_shift_bits));
1109   int context = (forward == backward) ? 3 : 0;
1110   if (block.top_available[kPlaneY]) {
1111     if (!block.IsTopSingle()) {
1112       context += static_cast<int>(block.bp_top->is_compound_type_average);
1113     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
1114       ++context;
1115     }
1116   }
1117   if (block.left_available[kPlaneY]) {
1118     if (!block.IsLeftSingle()) {
1119       context += static_cast<int>(block.bp_left->is_compound_type_average);
1120     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
1121       ++context;
1122     }
1123   }
1124   return symbol_decoder_context_.is_compound_type_average_cdf[context];
1125 }
1126 
ReadCompoundType(const Block & block,bool is_compound)1127 void Tile::ReadCompoundType(const Block& block, bool is_compound) {
1128   BlockParameters& bp = *block.bp;
1129   bp.is_explicit_compound_type = false;
1130   bp.is_compound_type_average = true;
1131   PredictionParameters& prediction_parameters =
1132       *block.bp->prediction_parameters;
1133   if (bp.skip_mode) {
1134     prediction_parameters.compound_prediction_type =
1135         kCompoundPredictionTypeAverage;
1136     return;
1137   }
1138   if (is_compound) {
1139     if (sequence_header_.enable_masked_compound) {
1140       bp.is_explicit_compound_type =
1141           reader_.ReadSymbol(GetIsExplicitCompoundTypeCdf(block));
1142     }
1143     if (bp.is_explicit_compound_type) {
1144       if (kIsWedgeCompoundModeAllowed.Contains(block.size)) {
1145         // Only kCompoundPredictionTypeWedge and
1146         // kCompoundPredictionTypeDiffWeighted are signaled explicitly.
1147         prediction_parameters.compound_prediction_type =
1148             static_cast<CompoundPredictionType>(reader_.ReadSymbol(
1149                 symbol_decoder_context_.compound_type_cdf[block.size]));
1150       } else {
1151         prediction_parameters.compound_prediction_type =
1152             kCompoundPredictionTypeDiffWeighted;
1153       }
1154     } else {
1155       if (sequence_header_.enable_jnt_comp) {
1156         bp.is_compound_type_average =
1157             reader_.ReadSymbol(GetIsCompoundTypeAverageCdf(block));
1158         prediction_parameters.compound_prediction_type =
1159             bp.is_compound_type_average ? kCompoundPredictionTypeAverage
1160                                         : kCompoundPredictionTypeDistance;
1161       } else {
1162         prediction_parameters.compound_prediction_type =
1163             kCompoundPredictionTypeAverage;
1164         return;
1165       }
1166     }
1167     if (prediction_parameters.compound_prediction_type ==
1168         kCompoundPredictionTypeWedge) {
1169       prediction_parameters.wedge_index =
1170           reader_.ReadSymbol<kWedgeIndexSymbolCount>(
1171               symbol_decoder_context_.wedge_index_cdf[block.size]);
1172       prediction_parameters.wedge_sign = static_cast<int>(reader_.ReadBit());
1173     } else if (prediction_parameters.compound_prediction_type ==
1174                kCompoundPredictionTypeDiffWeighted) {
1175       prediction_parameters.mask_is_inverse =
1176           static_cast<bool>(reader_.ReadBit());
1177     }
1178     return;
1179   }
1180   if (prediction_parameters.inter_intra_mode != kNumInterIntraModes) {
1181     prediction_parameters.compound_prediction_type =
1182         prediction_parameters.is_wedge_inter_intra
1183             ? kCompoundPredictionTypeWedge
1184             : kCompoundPredictionTypeIntra;
1185     return;
1186   }
1187   prediction_parameters.compound_prediction_type =
1188       kCompoundPredictionTypeAverage;
1189 }
1190 
GetInterpolationFilterCdf(const Block & block,int direction)1191 uint16_t* Tile::GetInterpolationFilterCdf(const Block& block, int direction) {
1192   const BlockParameters& bp = *block.bp;
1193   int context = MultiplyBy8(direction) +
1194                 MultiplyBy4(static_cast<int>(bp.reference_frame[1] >
1195                                              kReferenceFrameIntra));
1196   int top_type = kNumExplicitInterpolationFilters;
1197   if (block.top_available[kPlaneY]) {
1198     if (block.bp_top->reference_frame[0] == bp.reference_frame[0] ||
1199         block.bp_top->reference_frame[1] == bp.reference_frame[0]) {
1200       top_type = block.bp_top->interpolation_filter[direction];
1201     }
1202   }
1203   int left_type = kNumExplicitInterpolationFilters;
1204   if (block.left_available[kPlaneY]) {
1205     if (block.bp_left->reference_frame[0] == bp.reference_frame[0] ||
1206         block.bp_left->reference_frame[1] == bp.reference_frame[0]) {
1207       left_type = block.bp_left->interpolation_filter[direction];
1208     }
1209   }
1210   if (left_type == top_type) {
1211     context += left_type;
1212   } else if (left_type == kNumExplicitInterpolationFilters) {
1213     context += top_type;
1214   } else if (top_type == kNumExplicitInterpolationFilters) {
1215     context += left_type;
1216   } else {
1217     context += kNumExplicitInterpolationFilters;
1218   }
1219   return symbol_decoder_context_.interpolation_filter_cdf[context];
1220 }
1221 
ReadInterpolationFilter(const Block & block)1222 void Tile::ReadInterpolationFilter(const Block& block) {
1223   BlockParameters& bp = *block.bp;
1224   if (frame_header_.interpolation_filter != kInterpolationFilterSwitchable) {
1225     static_assert(
1226         sizeof(bp.interpolation_filter) / sizeof(bp.interpolation_filter[0]) ==
1227             2,
1228         "Interpolation filter array size is not 2");
1229     for (auto& interpolation_filter : bp.interpolation_filter) {
1230       interpolation_filter = frame_header_.interpolation_filter;
1231     }
1232     return;
1233   }
1234   bool interpolation_filter_present = true;
1235   if (bp.skip_mode ||
1236       block.bp->prediction_parameters->motion_mode == kMotionModeLocalWarp) {
1237     interpolation_filter_present = false;
1238   } else if (!IsBlockDimension4(block.size) &&
1239              bp.y_mode == kPredictionModeGlobalMv) {
1240     interpolation_filter_present =
1241         frame_header_.global_motion[bp.reference_frame[0]].type ==
1242         kGlobalMotionTransformationTypeTranslation;
1243   } else if (!IsBlockDimension4(block.size) &&
1244              bp.y_mode == kPredictionModeGlobalGlobalMv) {
1245     interpolation_filter_present =
1246         frame_header_.global_motion[bp.reference_frame[0]].type ==
1247             kGlobalMotionTransformationTypeTranslation ||
1248         frame_header_.global_motion[bp.reference_frame[1]].type ==
1249             kGlobalMotionTransformationTypeTranslation;
1250   }
1251   for (int i = 0; i < (sequence_header_.enable_dual_filter ? 2 : 1); ++i) {
1252     bp.interpolation_filter[i] =
1253         interpolation_filter_present
1254             ? static_cast<InterpolationFilter>(
1255                   reader_.ReadSymbol<kNumExplicitInterpolationFilters>(
1256                       GetInterpolationFilterCdf(block, i)))
1257             : kInterpolationFilterEightTap;
1258   }
1259   if (!sequence_header_.enable_dual_filter) {
1260     bp.interpolation_filter[1] = bp.interpolation_filter[0];
1261   }
1262 }
1263 
ReadInterBlockModeInfo(const Block & block)1264 bool Tile::ReadInterBlockModeInfo(const Block& block) {
1265   BlockParameters& bp = *block.bp;
1266   bp.palette_mode_info.size[kPlaneTypeY] = 0;
1267   bp.palette_mode_info.size[kPlaneTypeUV] = 0;
1268   ReadReferenceFrames(block);
1269   const bool is_compound = bp.reference_frame[1] > kReferenceFrameIntra;
1270   MvContexts mode_contexts;
1271   FindMvStack(block, is_compound, &mode_contexts);
1272   ReadInterPredictionModeY(block, mode_contexts);
1273   ReadRefMvIndex(block);
1274   if (!AssignInterMv(block, is_compound)) return false;
1275   ReadInterIntraMode(block, is_compound);
1276   ReadMotionMode(block, is_compound);
1277   ReadCompoundType(block, is_compound);
1278   ReadInterpolationFilter(block);
1279   return true;
1280 }
1281 
DecodeInterModeInfo(const Block & block)1282 bool Tile::DecodeInterModeInfo(const Block& block) {
1283   BlockParameters& bp = *block.bp;
1284   block.bp->prediction_parameters->use_intra_block_copy = false;
1285   bp.skip = false;
1286   if (!ReadInterSegmentId(block, /*pre_skip=*/true)) return false;
1287   ReadSkipMode(block);
1288   if (bp.skip_mode) {
1289     bp.skip = true;
1290   } else {
1291     ReadSkip(block);
1292   }
1293   if (!frame_header_.segmentation.segment_id_pre_skip &&
1294       !ReadInterSegmentId(block, /*pre_skip=*/false)) {
1295     return false;
1296   }
1297   ReadCdef(block);
1298   if (read_deltas_) {
1299     ReadQuantizerIndexDelta(block);
1300     ReadLoopFilterDelta(block);
1301     read_deltas_ = false;
1302   }
1303   ReadIsInter(block);
1304   return bp.is_inter ? ReadInterBlockModeInfo(block)
1305                      : ReadIntraBlockModeInfo(block, /*intra_y_mode=*/false);
1306 }
1307 
DecodeModeInfo(const Block & block)1308 bool Tile::DecodeModeInfo(const Block& block) {
1309   return IsIntraFrame(frame_header_.frame_type) ? DecodeIntraModeInfo(block)
1310                                                 : DecodeInterModeInfo(block);
1311 }
1312 
1313 }  // namespace libgav1
1314