1 /*****************************************************************************
2  * This file is part of Kvazaar HEVC encoder.
3  *
4  * Copyright (c) 2021, Tampere University, ITU/ISO/IEC, project contributors
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without modification,
8  * are permitted provided that the following conditions are met:
9  *
10  * * Redistributions of source code must retain the above copyright notice, this
11  *   list of conditions and the following disclaimer.
12  *
13  * * Redistributions in binary form must reproduce the above copyright notice, this
14  *   list of conditions and the following disclaimer in the documentation and/or
15  *   other materials provided with the distribution.
16  *
17  * * Neither the name of the Tampere University or ITU/ISO/IEC nor the names of its
18  *   contributors may be used to endorse or promote products derived from
19  *   this software without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION HOWEVER CAUSED AND ON
28  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29  * INCLUDING NEGLIGENCE OR OTHERWISE ARISING IN ANY WAY OUT OF THE USE OF THIS
30  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31  ****************************************************************************/
32 
33 #include "search_inter.h"
34 
35 #include <limits.h>
36 #include <stdlib.h>
37 
38 #include "cabac.h"
39 #include "encoder.h"
40 #include "image.h"
41 #include "imagelist.h"
42 #include "inter.h"
43 #include "kvazaar.h"
44 #include "rdo.h"
45 #include "search.h"
46 #include "strategies/strategies-ipol.h"
47 #include "strategies/strategies-picture.h"
48 #include "transform.h"
49 #include "videoframe.h"
50 
51 typedef struct {
52   encoder_state_t *state;
53 
54   /**
55    * \brief Current frame
56    */
57   const kvz_picture *pic;
58   /**
59    * \brief Reference frame
60    */
61   const kvz_picture *ref;
62 
63   /**
64    * \brief Index of the reference frame
65    */
66   int32_t ref_idx;
67 
68   /**
69    * \brief Top-left corner of the PU
70    */
71   const vector2d_t origin;
72   int32_t width;
73   int32_t height;
74 
75   int16_t mv_cand[2][2];
76   inter_merge_cand_t merge_cand[MRG_MAX_NUM_CANDS];
77   int32_t num_merge_cand;
78 
79   kvz_mvd_cost_func *mvd_cost_func;
80 
81   /**
82    * \brief Best motion vector among the ones tested so far
83    */
84   vector2d_t best_mv;
85   /**
86    * \brief Cost of best_mv
87    */
88   uint32_t best_cost;
89   /**
90    * \brief Bit cost of best_mv
91    */
92   uint32_t best_bitcost;
93 
94   /**
95    * \brief Possible optimized SAD implementation for the width, leave as
96    *        NULL for arbitrary-width blocks
97    */
98   optimized_sad_func_ptr_t optimized_sad;
99 
100 } inter_search_info_t;
101 
102 
103 /**
104  * \return  True if referred block is within current tile.
105  */
fracmv_within_tile(const inter_search_info_t * info,int x,int y)106 static INLINE bool fracmv_within_tile(const inter_search_info_t *info, int x, int y)
107 {
108   const encoder_control_t *ctrl = info->state->encoder_control;
109 
110   const bool is_frac_luma   = x % 4 != 0 || y % 4 != 0;
111   const bool is_frac_chroma = x % 8 != 0 || y % 8 != 0;
112 
113   if (ctrl->cfg.owf && ctrl->cfg.wpp) {
114     // Check that the block does not reference pixels that are not final.
115 
116     // Margin as luma pixels.
117     int margin = 0;
118     if (is_frac_luma) {
119       // Fractional motion estimation needs up to 4 pixels outside the
120       // block.
121       margin = 4;
122     } else if (is_frac_chroma) {
123       // Odd chroma interpolation needs up to 2 luma pixels outside the
124       // block.
125       margin = 2;
126     }
127 
128     if (ctrl->cfg.sao_type) {
129       // Make sure we don't refer to pixels for which SAO reconstruction
130       // has not been done.
131       margin += SAO_DELAY_PX;
132     } else if (ctrl->cfg.deblock_enable) {
133       // Make sure we don't refer to pixels that have not been deblocked.
134       margin += DEBLOCK_DELAY_PX;
135     }
136 
137     // Coordinates of the top-left corner of the containing LCU.
138     const vector2d_t orig_lcu = {
139       .x = info->origin.x / LCU_WIDTH,
140       .y = info->origin.y / LCU_WIDTH,
141     };
142     // Difference between the coordinates of the LCU containing the
143     // bottom-left corner of the referenced block and the LCU containing
144     // this block.
145     const vector2d_t mv_lcu = {
146       ((info->origin.x + info->width  + margin) * 4 + x) / (LCU_WIDTH << 2) - orig_lcu.x,
147       ((info->origin.y + info->height + margin) * 4 + y) / (LCU_WIDTH << 2) - orig_lcu.y,
148     };
149 
150     if (mv_lcu.y > ctrl->max_inter_ref_lcu.down) {
151       return false;
152     }
153 
154     if (mv_lcu.x + mv_lcu.y >
155         ctrl->max_inter_ref_lcu.down + ctrl->max_inter_ref_lcu.right)
156     {
157       return false;
158     }
159   }
160 
161   if (ctrl->cfg.mv_constraint == KVZ_MV_CONSTRAIN_NONE) {
162     return true;
163   }
164 
165   // Margin as luma quater pixels.
166   int margin = 0;
167   if (ctrl->cfg.mv_constraint == KVZ_MV_CONSTRAIN_FRAME_AND_TILE_MARGIN) {
168     if (is_frac_luma) {
169       margin = 4 << 2;
170     } else if (is_frac_chroma) {
171       margin = 2 << 2;
172     }
173   }
174 
175   // TODO implement KVZ_MV_CONSTRAIN_FRAM and KVZ_MV_CONSTRAIN_TILE.
176   const vector2d_t abs_mv = {
177     info->origin.x * 4 + x,
178     info->origin.y * 4 + y,
179   };
180 
181   // Check that both margin constraints are satisfied.
182   const int from_right  =
183     (info->state->tile->frame->width  << 2) - (abs_mv.x + (info->width  << 2));
184   const int from_bottom =
185     (info->state->tile->frame->height << 2) - (abs_mv.y + (info->height << 2));
186 
187   return abs_mv.x >= margin &&
188          abs_mv.y >= margin &&
189          from_right >= margin &&
190          from_bottom >= margin;
191 }
192 
193 
194 /**
195  * \return  True if referred block is within current tile.
196  */
intmv_within_tile(const inter_search_info_t * info,int x,int y)197 static INLINE bool intmv_within_tile(const inter_search_info_t *info, int x, int y)
198 {
199   return fracmv_within_tile(info, x * 4, y * 4);
200 }
201 
202 
203 /**
204  * \brief Calculate cost for an integer motion vector.
205  *
206  * Updates info->best_mv, info->best_cost and info->best_bitcost to the new
207  * motion vector if it yields a lower cost than the current one.
208  *
209  * If the motion vector violates the MV constraints for tiles or WPP, the
210  * cost is not set.
211  *
212  * \return true if info->best_mv was changed, false otherwise
213  */
check_mv_cost(inter_search_info_t * info,int x,int y)214 static bool check_mv_cost(inter_search_info_t *info, int x, int y)
215 {
216   if (!intmv_within_tile(info, x, y)) return false;
217 
218   uint32_t bitcost = 0;
219   uint32_t cost = kvz_image_calc_sad(
220       info->pic,
221       info->ref,
222       info->origin.x,
223       info->origin.y,
224       info->state->tile->offset_x + info->origin.x + x,
225       info->state->tile->offset_y + info->origin.y + y,
226       info->width,
227       info->height,
228       info->optimized_sad
229   );
230 
231   if (cost >= info->best_cost) return false;
232 
233   cost += info->mvd_cost_func(
234       info->state,
235       x, y, 2,
236       info->mv_cand,
237       info->merge_cand,
238       info->num_merge_cand,
239       info->ref_idx,
240       &bitcost
241   );
242 
243   if (cost >= info->best_cost) return false;
244 
245   // Set to motion vector in quarter pixel precision.
246   info->best_mv.x = x * 4;
247   info->best_mv.y = y * 4;
248   info->best_cost = cost;
249   info->best_bitcost = bitcost;
250 
251   return true;
252 }
253 
254 
/**
 * \brief Approximate bin count for coding a symbol.
 *
 * Computes 2 * floor(log2(symbol + 2)) with a four-step binary search that
 * covers 16 bits, so the result never exceeds 30.
 *
 * \param symbol  value to be coded
 * \return  estimated number of bins
 */
static unsigned get_ep_ex_golomb_bitcost(unsigned symbol)
{
  // Bit widths probed by the binary search, widest first.
  static const unsigned widths[4] = { 8, 4, 2, 1 };

  unsigned value = symbol + 2;
  unsigned bins = 0;

  for (int i = 0; i < 4; ++i) {
    const unsigned width = widths[i];
    if (value >= (1u << width)) {
      // At least `width` more bits are present; each counts as two bins.
      bins += 2 * width;
      value >>= width;
    }
  }

  // TODO: It might be a good idea to put a small slope on this function to
  // make sure any search function that follows the gradient heads towards
  // a smaller MVD, but that would require fractional costs and bits being
  // used everywhere in inter search.
  // return num_bins + 0.001 * symbol;

  return bins;
}
274 
275 
276 /**
277  * \brief Checks if mv is one of the merge candidates.
278  * \return true if found else return false
279  */
mv_in_merge(const inter_search_info_t * info,vector2d_t mv)280 static bool mv_in_merge(const inter_search_info_t *info, vector2d_t mv)
281 {
282   for (int i = 0; i < info->num_merge_cand; ++i) {
283     if (info->merge_cand[i].dir == 3) continue;
284     const vector2d_t merge_mv = {
285       (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][0] + 2) >> 2,
286       (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][1] + 2) >> 2
287     };
288     if (merge_mv.x == mv.x && merge_mv.y == mv.y) {
289       return true;
290     }
291   }
292   return false;
293 }
294 
295 
296 /**
297  * \brief Select starting point for integer motion estimation search.
298  *
299  * Checks the zero vector, extra_mv and merge candidates and updates
300  * info->best_mv to the best one.
301  */
select_starting_point(inter_search_info_t * info,vector2d_t extra_mv)302 static void select_starting_point(inter_search_info_t *info, vector2d_t extra_mv)
303 {
304   // Check the 0-vector, so we can ignore all 0-vectors in the merge cand list.
305   check_mv_cost(info, 0, 0);
306 
307   // Change to integer precision.
308   extra_mv.x >>= 2;
309   extra_mv.y >>= 2;
310 
311   // Check mv_in if it's not one of the merge candidates.
312   if ((extra_mv.x != 0 || extra_mv.y != 0) && !mv_in_merge(info, extra_mv)) {
313     check_mv_cost(info, extra_mv.x, extra_mv.y);
314   }
315 
316   // Go through candidates
317   for (unsigned i = 0; i < info->num_merge_cand; ++i) {
318     if (info->merge_cand[i].dir == 3) continue;
319 
320     int x = (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][0] + 2) >> 2;
321     int y = (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][1] + 2) >> 2;
322 
323     if (x == 0 && y == 0) continue;
324 
325     check_mv_cost(info, x, y);
326   }
327 }
328 
329 
get_mvd_coding_cost(const encoder_state_t * state,const cabac_data_t * cabac,const int32_t mvd_hor,const int32_t mvd_ver)330 static uint32_t get_mvd_coding_cost(const encoder_state_t *state,
331                                     const cabac_data_t* cabac,
332                                     const int32_t mvd_hor,
333                                     const int32_t mvd_ver)
334 {
335   unsigned bitcost = 0;
336   const vector2d_t abs_mvd = { abs(mvd_hor), abs(mvd_ver) };
337 
338   bitcost += get_ep_ex_golomb_bitcost(abs_mvd.x) << CTX_FRAC_BITS;
339   bitcost += get_ep_ex_golomb_bitcost(abs_mvd.y) << CTX_FRAC_BITS;
340 
341   // Round and shift back to integer bits.
342   return (bitcost + CTX_FRAC_HALF_BIT) >> CTX_FRAC_BITS;
343 }
344 
345 
select_mv_cand(const encoder_state_t * state,int16_t mv_cand[2][2],int32_t mv_x,int32_t mv_y,uint32_t * cost_out)346 static int select_mv_cand(const encoder_state_t *state,
347                           int16_t mv_cand[2][2],
348                           int32_t mv_x,
349                           int32_t mv_y,
350                           uint32_t *cost_out)
351 {
352   const bool same_cand =
353     (mv_cand[0][0] == mv_cand[1][0] && mv_cand[0][1] == mv_cand[1][1]);
354 
355   if (same_cand && !cost_out) {
356     // Pick the first one if both candidates are the same.
357     return 0;
358   }
359 
360   uint32_t (*mvd_coding_cost)(const encoder_state_t * const state,
361                               const cabac_data_t*,
362                               int32_t, int32_t);
363   if (state->encoder_control->cfg.mv_rdo) {
364     mvd_coding_cost = kvz_get_mvd_coding_cost_cabac;
365   } else {
366     mvd_coding_cost = get_mvd_coding_cost;
367   }
368 
369   uint32_t cand1_cost = mvd_coding_cost(
370       state, &state->cabac,
371       mv_x - mv_cand[0][0],
372       mv_y - mv_cand[0][1]);
373 
374   uint32_t cand2_cost;
375   if (same_cand) {
376     cand2_cost = cand1_cost;
377   } else {
378     cand2_cost = mvd_coding_cost(
379       state, &state->cabac,
380       mv_x - mv_cand[1][0],
381       mv_y - mv_cand[1][1]);
382   }
383 
384   if (cost_out) {
385     *cost_out = MIN(cand1_cost, cand2_cost);
386   }
387 
388   // Pick the second candidate if it has lower cost.
389   return cand2_cost < cand1_cost ? 1 : 0;
390 }
391 
392 
/**
 * \brief Estimate the cost of coding a motion vector.
 *
 * If the vector, scaled by 1 << mv_shift, equals a merge candidate that
 * refers to ref_idx, the bit cost is the index of that candidate. Otherwise
 * it is the MVD cost of the cheaper predictor candidate as given by
 * select_mv_cand().
 *
 * \param state       encoder state
 * \param x           horizontal component of the motion vector
 * \param y           vertical component of the motion vector
 * \param mv_shift    left shift that brings x and y to the precision of the
 *                    candidate vectors
 * \param mv_cand     the two predictor candidates
 * \param merge_cand  merge candidate list
 * \param num_cand    number of merge candidates
 * \param ref_idx     reference index of the motion vector
 * \param bitcost     receives the estimated number of bits
 * \return  bit cost multiplied by state->lambda_sqrt rounded to an integer
 */
static uint32_t calc_mvd_cost(const encoder_state_t *state,
                              int x,
                              int y,
                              int mv_shift,
                              int16_t mv_cand[2][2],
                              inter_merge_cand_t merge_cand[MRG_MAX_NUM_CANDS],
                              int16_t num_cand,
                              int32_t ref_idx,
                              uint32_t *bitcost)
{
  // Motion vector at the precision of the candidates.
  const int scaled_x = x * (1 << mv_shift);
  const int scaled_y = y * (1 << mv_shift);

  uint32_t bits = 0;
  bool is_merge = false;

  // Look for the first merge candidate that matches the vector.
  for (uint32_t idx = 0; idx < (uint32_t)num_cand; idx++) {
    const inter_merge_cand_t *cand = &merge_cand[idx];
    if (cand->dir == 3) continue;

    const int list = cand->dir - 1;
    if (cand->mv[list][0] != scaled_x) continue;
    if (cand->mv[list][1] != scaled_y) continue;
    if (state->frame->ref_LX[list][cand->ref[list]] != ref_idx) continue;

    // The candidate index serves as the bit cost of merging.
    bits += idx;
    is_merge = true;
    break;
  }

  // The MVD has to be coded only if the vector cannot be merged.
  if (!is_merge) {
    uint32_t mvd_bits = 0;
    select_mv_cand(state, mv_cand, scaled_x, scaled_y, &mvd_bits);
    bits += mvd_bits;
  }

  *bitcost = bits;
  return bits * (int32_t)(state->lambda_sqrt + 0.5);
}
433 
434 
/**
 * \brief Decide whether the motion search can stop at the starting point.
 *
 * Runs at most two rounds of a small cross-shaped search around
 * info->best_mv. A round that does not bring info->best_cost below the
 * threshold ends the search. The threshold is the cost before the round,
 * or 95 % of it with KVZ_ME_EARLY_TERMINATION_SENSITIVE.
 *
 * Improvements found on the way are stored in info by check_mv_cost().
 *
 * \param info  search info
 * \return  true if the search should be terminated
 */
static bool early_terminate(inter_search_info_t *info)
{
  // Up, left, down and right neighbours. The first two are repeated so that
  // a window of three consecutive entries can start at any of the first
  // four. The last entry is the zero offset used when nothing improved.
  static const vector2d_t cross[7] = {
      { 0, -1 }, { -1, 0 }, { 0, 1 }, { 1, 0 },
      { 0, -1 }, { -1, 0 }, { 0, 0 },
  };

  const bool sensitive =
    info->state->encoder_control->cfg.me_early_termination ==
    KVZ_ME_EARLY_TERMINATION_SENSITIVE;

  vector2d_t center = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  // The first round tests all four neighbours.
  int window_first = 0;
  int window_last = 3;

  for (int round = 0; round < 2; ++round) {
    const double threshold =
      sensitive ? info->best_cost * 0.95 : info->best_cost;

    int winner = 6;
    for (int i = window_first; i <= window_last; i++) {
      if (check_mv_cost(info, center.x + cross[i].x, center.y + cross[i].y)) {
        winner = i;
      }
    }

    // Follow the improvement, if there was one.
    center.x += cross[winner].x;
    center.y += cross[winner].y;

    // Stop searching when the round did not beat the threshold.
    if (info->best_cost >= threshold) {
      return true;
    }

    // The next round tests three neighbours of the new center.
    window_first = (winner + 3) % 4;
    window_last = window_first + 2;
  }

  return false;
}
481 
482 
/**
 * \brief Test the points of one TZ search pattern around a motion vector.
 *
 * Every point is evaluated with check_mv_cost(), which stores improvements
 * in info.
 *
 * \param info          search info
 * \param pattern_type  0 = diamond, 1 = square, 2 = octagon, 3 = hexagon
 * \param iDist         distance of the pattern points from the center
 * \param mv            center of the pattern in full pixels
 * \param best_dist     set to iDist if any point improved the best cost,
 *                      left untouched otherwise
 */
void kvz_tz_pattern_search(inter_search_info_t *info,
                           unsigned pattern_type,
                           const int iDist,
                           vector2d_t mv,
                           int *best_dist)
{
  assert(pattern_type < 4);

  //implemented search patterns
  const vector2d_t pattern[4][8] = {
      //diamond (8 points)
      //[ ][ ][ ][ ][1][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][8][ ][ ][ ][5][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[4][ ][ ][ ][o][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][7][ ][ ][ ][6][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][3][ ][ ][ ][ ]
      {
        { 0, iDist }, { iDist, 0 }, { 0, -iDist }, { -iDist, 0 },
        { iDist / 2, iDist / 2 }, { iDist / 2, -iDist / 2 }, { -iDist / 2, -iDist / 2 }, { -iDist / 2, iDist / 2 }
      },

      //square (8 points)
      //[8][ ][ ][ ][1][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[7][ ][ ][ ][o][ ][ ][ ][3]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[6][ ][ ][ ][5][ ][ ][ ][4]
      {
        { 0, iDist }, { iDist, iDist }, { iDist, 0 }, { iDist, -iDist }, { 0, -iDist },
        { -iDist, -iDist }, { -iDist, 0 }, { -iDist, iDist }
      },

      //octagon (8 points)
      //[ ][ ][5][ ][ ][ ][1][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][2]
      //[4][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][o][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[8][ ][ ][ ][ ][ ][ ][ ][6]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][7][ ][ ][ ][3][ ][ ]
      {
        { iDist / 2, iDist }, { iDist, iDist / 2 }, { iDist / 2, -iDist }, { -iDist, iDist / 2 },
        { -iDist / 2, iDist }, { iDist, -iDist / 2 }, { -iDist / 2, -iDist }, { -iDist, -iDist / 2 }
      },

      //hexagon (6 points)
      //[ ][ ][5][ ][ ][ ][1][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[4][ ][ ][ ][o][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][6][ ][ ][ ][3][ ][ ]
      {
        { iDist / 2, iDist }, { iDist, 0 }, { iDist / 2, -iDist }, { -iDist, 0 },
        // Point 5 is the upper-left vertex; it must not repeat point 1.
        { -iDist / 2, iDist }, { -iDist / 2, -iDist }, { 0, 0 }, { 0, 0 }
      }
  };

  // Set the number of points to be checked.
  int n_points;
  if (iDist == 1) {
    // At distance 1 the half-distance offsets are zero, so only the first
    // four points of the diamond, octagon and hexagon are tested.
    switch (pattern_type) {
      case 0:
      case 2:
      case 3:
        n_points = 4;
        break;
      default:
        n_points = 8;
        break;
    }
  } else {
    switch (pattern_type) {
      case 3:
        // The last two hexagon entries are padding.
        n_points = 6;
        break;
      default:
        n_points = 8;
        break;
    }
  }

  // Compute SAD values for all chosen points.
  bool improved = false;
  for (int i = 0; i < n_points; i++) {
    const vector2d_t offset = pattern[pattern_type][i];

    if (check_mv_cost(info, mv.x + offset.x, mv.y + offset.y)) {
      improved = true;
    }
  }

  if (improved) {
    *best_dist = iDist;
  }
}
598 
599 
/**
 * \brief Test a regular grid of motion vectors around info->best_mv.
 *
 * The grid covers offsets from -iSearchRange to iSearchRange in both
 * directions with a spacing of iRaster. Rows are visited from the largest
 * vertical offset to the smallest and each row from left to right.
 *
 * \param info          search info
 * \param iSearchRange  largest offset from the center in full pixels
 * \param iRaster       spacing of the grid in full pixels
 */
void kvz_tz_raster_search(inter_search_info_t *info,
                          int iSearchRange,
                          int iRaster)
{
  // The center stays fixed even if check_mv_cost() updates info->best_mv.
  const vector2d_t center = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  int dy = iSearchRange;
  while (dy >= -iSearchRange) {
    int dx = -iSearchRange;
    while (dx <= iSearchRange) {
      check_mv_cost(info, center.x + dx, center.y + dy);
      dx += iRaster;
    }
    dy -= iRaster;
  }
}
613 
614 
/**
 * \brief Do motion search using the TZ algorithm.
 *
 * The best of the zero vector, extra_mv and the merge candidates is used as
 * the starting point. A pattern search with doubling distance is run around
 * it and, if the starting point was not the zero vector, also around the
 * zero vector. With the parameters fixed below, the result is then refined
 * by repeating the pattern search around the best vector until a pass
 * brings no improvement.
 *
 * \param info      search info
 * \param extra_mv  extra motion vector to check, in quarter pixels
 */
static void tz_search(inter_search_info_t *info, vector2d_t extra_mv)
{
  // TZ parameters
  const int      search_range     = 96;     // search range for each stage
  const int      raster_step      = 5;      // distance limit and downsampling factor for step 3
  const unsigned grid_pattern     = 0;      // search pattern for step 2
  const unsigned refine_pattern   = 0;      // search pattern for step 4
  const bool     do_raster_scan   = false;  // enable step 3
  const bool     do_raster_refine = false;  // enable step 4 mode 1
  const bool     do_star_refine   = true;   // enable step 4 mode 2 (only one mode will be executed)

  int best_dist = 0;
  info->best_cost = UINT32_MAX;

  // Step 1: pick the starting point among the zero vector, extra_mv and the
  // merge candidates.
  select_starting_point(info, extra_mv);

  // Check if we should stop search
  if (info->state->encoder_control->cfg.me_early_termination) {
    if (early_terminate(info)) {
      return;
    }
  }

  vector2d_t start = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  // Step 2: grid search with doubling distance. The loop ends early once
  // three rounds have failed to produce a better vector.
  int failed_rounds = 0;
  for (int dist = 1; dist <= search_range; dist *= 2) {
    kvz_tz_pattern_search(info, grid_pattern, dist, start, &best_dist);

    if (best_dist != dist && ++failed_rounds >= 3) break;
  }

  const bool started_off_origin = (start.x != 0 || start.y != 0);
  if (started_off_origin) {
    // Repeat step 2 around the zero vector with half the range.
    start.x = 0;
    start.y = 0;
    failed_rounds = 0;
    for (int dist = 1; dist <= search_range / 2; dist *= 2) {
      kvz_tz_pattern_search(info, grid_pattern, dist, start, &best_dist);

      if (best_dist != dist && ++failed_rounds >= 3) break;
    }
  }

  // Step 3: raster scan
  if (do_raster_scan && best_dist > raster_step) {
    best_dist = raster_step;
    kvz_tz_raster_search(info, search_range, raster_step);
  }

  // Step 4, mode 1: raster refinement with halving distance
  if (do_raster_refine && best_dist > 0) {
    for (int dist = best_dist >> 1; dist > 0; dist >>= 1) {
      start = (vector2d_t){ info->best_mv.x >> 2, info->best_mv.y >> 2 };
      kvz_tz_pattern_search(info, refine_pattern, dist, start, &best_dist);
    }
  }

  // Step 4, mode 2: star refinement. Step 2 is repeated around the current
  // best vector until a full pass leaves best_dist at zero.
  while (do_star_refine && best_dist > 0) {
    best_dist = 0;
    start = (vector2d_t){ info->best_mv.x >> 2, info->best_mv.y >> 2 };
    for (int dist = 1; dist <= search_range; dist *= 2) {
      kvz_tz_pattern_search(info, refine_pattern, dist, start, &best_dist);
    }
  }
}
692 
693 
694 /**
695  * \brief Do motion search using the HEXBS algorithm.
696  *
697  * \param info      search info
698  * \param extra_mv  extra motion vector to check
699  * \param steps     how many steps are done at maximum before exiting, does not affect the final step
700  *
701  * Motion vector is searched by first searching iteratively with the large
702  * hexagon pattern until the best match is at the center of the hexagon.
703  * As a final step a smaller hexagon is used to check the adjacent pixels.
704  *
705  * If a non 0,0 predicted motion vector predictor is given as extra_mv,
706  * the 0,0 vector is also tried. This is hoped to help in the case where
707  * the predicted motion vector is way off. In the future even more additional
708  * points like 0,0 might be used, such as vectors from top or left.
709  */
hexagon_search(inter_search_info_t * info,vector2d_t extra_mv,uint32_t steps)710 static void hexagon_search(inter_search_info_t *info, vector2d_t extra_mv, uint32_t steps)
711 {
712   // The start of the hexagonal pattern has been repeated at the end so that
713   // the indices between 1-6 can be used as the start of a 3-point list of new
714   // points to search.
715   //   6--1,7
716   //  /     \    =)
717   // 5   0  2,8
718   //  \     /
719   //   4---3
720   static const vector2d_t large_hexbs[9] = {
721       { 0, 0 },
722       { 1, -2 }, { 2, 0 }, { 1, 2 }, { -1, 2 }, { -2, 0 }, { -1, -2 },
723       { 1, -2 }, { 2, 0 }
724   };
725   // This is used as the last step of the hexagon search.
726   //   1
727   // 2 0 3
728   //   4
729   static const vector2d_t small_hexbs[9] = {
730       { 0, 0 },
731       { 0, -1 }, { -1, 0 }, { 1, 0 }, { 0, 1 },
732       { -1, -1 }, { 1, -1 }, { -1, 1 }, { 1, 1 }
733   };
734 
735   info->best_cost = UINT32_MAX;
736 
737   // Select starting point from among merge candidates. These should
738   // include both mv_cand vectors and (0, 0).
739   select_starting_point(info, extra_mv);
740 
741   // Check if we should stop search
742   if (info->state->encoder_control->cfg.me_early_termination &&
743       early_terminate(info))
744   {
745     return;
746   }
747 
748   vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
749 
750   // Current best index, either to merge_cands, large_hexbs or small_hexbs.
751   int best_index = 0;
752 
753   // Search the initial 7 points of the hexagon.
754   for (int i = 1; i < 7; ++i) {
755     if (check_mv_cost(info, mv.x + large_hexbs[i].x, mv.y + large_hexbs[i].y)) {
756       best_index = i;
757     }
758   }
759 
760   // Iteratively search the 3 new points around the best match, until the best
761   // match is in the center.
762   while (best_index != 0 && steps != 0) {
763     // decrement count if enabled
764     if (steps > 0) steps -= 1;
765 
766     // Starting point of the 3 offsets to be searched.
767     unsigned start;
768     if (best_index == 1) {
769       start = 6;
770     } else if (best_index == 8) {
771       start = 1;
772     } else {
773       start = best_index - 1;
774     }
775 
776     // Move the center to the best match.
777     mv.x += large_hexbs[best_index].x;
778     mv.y += large_hexbs[best_index].y;
779     best_index = 0;
780 
781     // Iterate through the next 3 points.
782     for (int i = 0; i < 3; ++i) {
783       vector2d_t offset = large_hexbs[start + i];
784       if (check_mv_cost(info, mv.x + offset.x, mv.y + offset.y)) {
785         best_index = start + i;
786       }
787     }
788   }
789 
790   // Move the center to the best match.
791   //mv.x += large_hexbs[best_index].x;
792   //mv.y += large_hexbs[best_index].y;
793 
794   // Do the final step of the search with a small pattern.
795   for (int i = 1; i < 9; ++i) {
796     check_mv_cost(info, mv.x + small_hexbs[i].x, mv.y + small_hexbs[i].y);
797   }
798 }
799 
800 /**
801 * \brief Do motion search using the diamond algorithm.
802 *
803 * \param info      search info
804 * \param extra_mv  extra motion vector to check
805 * \param steps     how many steps are done at maximum before exiting
806 *
807 * Motion vector is searched by searching iteratively with a diamond-shaped
808 * pattern. We take care of not checking the direction we came from, but
809 * further checking for avoiding visits to already visited points is not done.
810 *
811 * If a non 0,0 predicted motion vector predictor is given as extra_mv,
812 * the 0,0 vector is also tried. This is hoped to help in the case where
813 * the predicted motion vector is way off. In the future even more additional
814 * points like 0,0 might be used, such as vectors from top or left.
815 **/
diamond_search(inter_search_info_t * info,vector2d_t extra_mv,uint32_t steps)816 static void diamond_search(inter_search_info_t *info, vector2d_t extra_mv, uint32_t steps)
817 {
818   enum diapos {
819     DIA_UP = 0,
820     DIA_RIGHT = 1,
821     DIA_LEFT = 2,
822     DIA_DOWN = 3,
823     DIA_CENTER = 4,
824   };
825 
826   // a diamond shape with the center included
827   //   0
828   // 2 4 1
829   //   3
830   static const vector2d_t diamond[5] = {
831     {0, -1}, {1, 0}, {0, 1}, {-1, 0},
832     {0, 0}
833   };
834 
835   info->best_cost = UINT32_MAX;
836 
837   // Select starting point from among merge candidates. These should
838   // include both mv_cand vectors and (0, 0).
839   select_starting_point(info, extra_mv);
840 
841   // Check if we should stop search
842   if (info->state->encoder_control->cfg.me_early_termination &&
843     early_terminate(info))
844   {
845     return;
846   }
847 
848   // current motion vector
849   vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
850 
851   // current best index
852   enum diapos best_index = DIA_CENTER;
853 
854   // initial search of the points of the diamond
855   for (int i = 0; i < 5; ++i) {
856     if (check_mv_cost(info, mv.x + diamond[i].x, mv.y + diamond[i].y)) {
857       best_index = i;
858     }
859   }
860 
861   if (best_index == DIA_CENTER) {
862     // the center point was the best in initial check
863     return;
864   }
865 
866   // Move the center to the best match.
867   mv.x += diamond[best_index].x;
868   mv.y += diamond[best_index].y;
869 
870   // the arrival direction, the index of the diamond member that will be excluded
871   enum diapos from_dir = DIA_CENTER;
872 
873   // whether we found a better candidate this iteration
874   uint8_t better_found;
875 
876   do {
877     better_found = 0;
878     // decrement count if enabled
879     if (steps > 0) steps -= 1;
880 
881     // search the points of the diamond
882     for (int i = 0; i < 4; ++i) {
883       // this is where we came from so it's checked already
884       if (i == from_dir) continue;
885 
886       if (check_mv_cost(info, mv.x + diamond[i].x, mv.y + diamond[i].y)) {
887         best_index = i;
888         better_found = 1;
889       }
890     }
891 
892     if (better_found) {
893       // Move the center to the best match.
894       mv.x += diamond[best_index].x;
895       mv.y += diamond[best_index].y;
896 
897       // record where we came from to the next iteration
898       // the xor operation flips the orientation
899       from_dir = best_index ^ 0x3;
900     }
901   } while (better_found && steps != 0);
902   // and we're done
903 }
904 
905 
/**
 * \brief Exhaustive integer motion vector search.
 *
 * Passes every integer position of a square window of +-search_range to
 * check_mv_cost(). Windows are placed around the zero vector, around
 * extra_mv (unless mv_in_merge() reports it) and around every merge
 * candidate that is not bi-directional.
 *
 * \param info          inter search state, handed on to check_mv_cost()
 * \param search_range  half of the window side length in integer pixels
 * \param extra_mv      extra starting point; shifted right by two before use
 */
static void search_mv_full(inter_search_info_t *info,
                           int32_t search_range,
                           vector2d_t extra_mv)
{
  // Window centered on the 0-vector.
  for (int dy = -search_range; dy <= search_range; dy++) {
    for (int dx = -search_range; dx <= search_range; dx++) {
      check_mv_cost(info, dx, dy);
    }
  }

  // Drop the two fractional bits to get integer precision.
  extra_mv.x >>= 2;
  extra_mv.y >>= 2;

  // Window centered on extra_mv, skipped when extra_mv is found among the
  // merge candidates.
  if (!mv_in_merge(info, extra_mv)) {
    for (int dy = -search_range; dy <= search_range; dy++) {
      for (int dx = -search_range; dx <= search_range; dx++) {
        check_mv_cost(info, extra_mv.x + dx, extra_mv.y + dy);
      }
    }
  }

  // Windows centered on the merge candidates. These should include
  // both mv_cand vectors and (0, 0).
  for (int cand = 0; cand < info->num_merge_cand; ++cand) {
    const inter_merge_cand_t *cur = &info->merge_cand[cand];
    if (cur->dir == 3) continue;

    const vector2d_t center = {
      .x = cur->mv[cur->dir - 1][0] >> 2,
      .y = cur->mv[cur->dir - 1][1] >> 2,
    };

    // The window around the 0-vector has been searched already.
    if (center.x == 0 && center.y == 0) continue;

    const vector2d_t first = { center.x - search_range, center.y - search_range };
    const vector2d_t last  = { center.x + search_range, center.y + search_range };

    for (int y = first.y; y <= last.y; ++y) {
      for (int x = first.x; x <= last.x; ++x) {
        if (!intmv_within_tile(info, x, y)) {
          continue;
        }

        // Skip points covered by an earlier window. Index -1 stands for the
        // window around the 0-vector, the others for earlier candidates.
        bool covered = false;
        for (int prev = -1; prev < cand; ++prev) {
          int prev_x = 0;
          int prev_y = 0;
          if (prev >= 0) {
            const inter_merge_cand_t *other = &info->merge_cand[prev];
            if (other->dir == 3) continue;
            prev_x = other->mv[other->dir - 1][0] >> 2;
            prev_y = other->mv[other->dir - 1][1] >> 2;
          }

          if (x < prev_x - search_range || x > prev_x + search_range ||
              y < prev_y - search_range || y > prev_y + search_range)
          {
            continue;
          }

          // Jump to the right edge of the earlier window so that the rest
          // of its row is skipped as well.
          covered = true;
          x = prev_x + search_range;
          break;
        }

        if (!covered) {
          check_mv_cost(info, x, y);
        }
      }
    }
  }
}
977 
978 
979 /**
980  * \brief Do fractional motion estimation
981  *
982  * Algoritm first searches 1/2-pel positions around integer mv and after best match is found,
983  * refines the search by searching best 1/4-pel postion around best 1/2-pel position.
984  */
search_frac(inter_search_info_t * info)985 static void search_frac(inter_search_info_t *info)
986 {
987   // Map indexes to relative coordinates in the following way:
988   // 5 3 6
989   // 1 0 2
990   // 7 4 8
991   static const vector2d_t square[9] = {
992       {  0,  0 },  { -1,  0 },  {  1,  0 },
993       {  0, -1 },  {  0,  1 },  { -1, -1 },
994       {  1, -1 },  { -1,  1 },  {  1,  1 }
995   };
996 
997   // Set mv to pixel precision
998   vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
999 
1000   unsigned best_cost = UINT32_MAX;
1001   uint32_t best_bitcost = 0;
1002   uint32_t bitcosts[4] = { 0 };
1003   unsigned best_index = 0;
1004 
1005   unsigned costs[4] = { 0 };
1006 
1007   ALIGNED(64) kvz_pixel filtered[4][LCU_LUMA_SIZE];
1008 
1009   // Storage buffers for intermediate horizontally filtered results.
1010   // Have the first columns in contiguous memory for vectorization.
1011   ALIGNED(64) int16_t intermediate[5][KVZ_IPOL_MAX_IM_SIZE_LUMA_SIMD];
1012   int16_t hor_first_cols[5][KVZ_EXT_BLOCK_W_LUMA + 1];
1013 
1014   const kvz_picture *ref = info->ref;
1015   const kvz_picture *pic = info->pic;
1016   vector2d_t orig = info->origin;
1017   const int width = info->width;
1018   const int height = info->height;
1019   const int internal_width  = ((width  + 7) >> 3) << 3; // Round up to closest 8
1020   const int internal_height = ((height + 7) >> 3) << 3;
1021 
1022   const encoder_state_t *state = info->state;
1023   int fme_level = state->encoder_control->cfg.fme_level;
1024   int8_t sample_off_x = 0;
1025   int8_t sample_off_y = 0;
1026 
1027   // Space for (possibly) extrapolated pixels and the part from the picture
1028   // One extra row and column compared to normal interpolation and some extra for AVX2.
1029   // The extrapolation function will set the pointers and stride.
1030   kvz_pixel ext_buffer[KVZ_FME_MAX_INPUT_SIZE_SIMD];
1031   kvz_pixel *ext = NULL;
1032   kvz_pixel *ext_origin = NULL;
1033   int ext_s = 0;
1034   kvz_epol_args epol_args = {
1035     .src = ref->y,
1036     .src_w = ref->width,
1037     .src_h = ref->height,
1038     .src_s = ref->stride,
1039     .blk_x = state->tile->offset_x + orig.x + mv.x - 1,
1040     .blk_y = state->tile->offset_y + orig.y + mv.y - 1,
1041     .blk_w = internal_width + 1,  // TODO: real width
1042     .blk_h = internal_height + 1, // TODO: real height
1043     .pad_l = KVZ_LUMA_FILTER_OFFSET,
1044     .pad_r = KVZ_EXT_PADDING_LUMA - KVZ_LUMA_FILTER_OFFSET,
1045     .pad_t = KVZ_LUMA_FILTER_OFFSET,
1046     .pad_b = KVZ_EXT_PADDING_LUMA - KVZ_LUMA_FILTER_OFFSET,
1047     .pad_b_simd = 0 // AVX2 padding unnecessary because of blk_h
1048   };
1049 
1050   // Initialize separately. Gets rid of warning
1051   // about using nonstandard extension.
1052   epol_args.buf = ext_buffer;
1053   epol_args.ext = &ext;
1054   epol_args.ext_origin = &ext_origin;
1055   epol_args.ext_s = &ext_s;
1056 
1057   kvz_get_extended_block(&epol_args);
1058 
1059   kvz_pixel *tmp_pic = pic->y + orig.y * pic->stride + orig.x;
1060   int tmp_stride = pic->stride;
1061 
1062   // Search integer position
1063   costs[0] = kvz_satd_any_size(width, height,
1064     tmp_pic, tmp_stride,
1065     ext_origin + ext_s + 1, ext_s);
1066 
1067   costs[0] += info->mvd_cost_func(state,
1068                                   mv.x, mv.y, 2,
1069                                   info->mv_cand,
1070                                   info->merge_cand,
1071                                   info->num_merge_cand,
1072                                   info->ref_idx,
1073                                   &bitcosts[0]);
1074   best_cost = costs[0];
1075   best_bitcost = bitcosts[0];
1076 
1077   //Set mv to half-pixel precision
1078   mv.x *= 2;
1079   mv.y *= 2;
1080 
1081   ipol_blocks_func * filter_steps[4] = {
1082     kvz_filter_hpel_blocks_hor_ver_luma,
1083     kvz_filter_hpel_blocks_diag_luma,
1084     kvz_filter_qpel_blocks_hor_ver_luma,
1085     kvz_filter_qpel_blocks_diag_luma,
1086   };
1087 
1088   // Search halfpel positions around best integer mv
1089   int i = 1;
1090   for (int step = 0; step < fme_level; ++step){
1091 
1092     const int mv_shift = (step < 2) ? 1 : 0;
1093 
1094     filter_steps[step](state->encoder_control,
1095       ext_origin,
1096       ext_s,
1097       internal_width,
1098       internal_height,
1099       filtered,
1100       intermediate,
1101       fme_level,
1102       hor_first_cols,
1103       sample_off_x,
1104       sample_off_y);
1105 
1106     const vector2d_t *pattern[4] = { &square[i], &square[i + 1], &square[i + 2], &square[i + 3] };
1107 
1108     int8_t within_tile[4];
1109     for (int j = 0; j < 4; j++) {
1110       within_tile[j] =
1111         fracmv_within_tile(info, (mv.x + pattern[j]->x) * (1 << mv_shift), (mv.y + pattern[j]->y) * (1 << mv_shift));
1112     };
1113 
1114     kvz_pixel *filtered_pos[4] = { 0 };
1115     filtered_pos[0] = &filtered[0][0];
1116     filtered_pos[1] = &filtered[1][0];
1117     filtered_pos[2] = &filtered[2][0];
1118     filtered_pos[3] = &filtered[3][0];
1119 
1120     kvz_satd_any_size_quad(width, height, (const kvz_pixel **)filtered_pos, LCU_WIDTH, tmp_pic, tmp_stride, 4, costs, within_tile);
1121 
1122     for (int j = 0; j < 4; j++) {
1123       if (within_tile[j]) {
1124         costs[j] += info->mvd_cost_func(
1125             state,
1126             mv.x + pattern[j]->x,
1127             mv.y + pattern[j]->y,
1128             mv_shift,
1129             info->mv_cand,
1130             info->merge_cand,
1131             info->num_merge_cand,
1132             info->ref_idx,
1133             &bitcosts[j]
1134         );
1135       }
1136     }
1137 
1138     for (int j = 0; j < 4; ++j) {
1139       if (within_tile[j] && costs[j] < best_cost) {
1140         best_cost = costs[j];
1141         best_bitcost = bitcosts[j];
1142         best_index = i + j;
1143       }
1144     }
1145 
1146     i += 4;
1147 
1148     // Update mv for the best position on current precision
1149     if (step == 1 || step == fme_level - 1) {
1150       // Move search to best_index
1151       mv.x += square[best_index].x;
1152       mv.y += square[best_index].y;
1153 
1154       // On last hpel step...
1155       if (step == MIN(fme_level - 1, 1)) {
1156         //Set mv to quarterpel precision
1157         mv.x *= 2;
1158         mv.y *= 2;
1159         sample_off_x = square[best_index].x;
1160         sample_off_y = square[best_index].y;
1161         best_index = 0;
1162         i = 1;
1163       }
1164     }
1165   }
1166 
1167   info->best_mv = mv;
1168   info->best_cost = best_cost;
1169   info->best_bitcost = best_bitcost;
1170 }
1171 
1172 /**
1173 * \brief Calculate the scaled MV
1174 */
get_scaled_mv(int16_t mv,int scale)1175 static INLINE int16_t get_scaled_mv(int16_t mv, int scale)
1176 {
1177   int32_t scaled = scale * mv;
1178   return CLIP(-32768, 32767, (scaled + 127 + (scaled < 0)) >> 8);
1179 }
1180 /**
1181 * \brief Scale the MV according to the POC difference
1182 *
1183 * \param current_poc        POC of current frame
1184 * \param current_ref_poc    POC of reference frame
1185 * \param neighbor_poc       POC of neighbor frame
1186 * \param neighbor_ref_poc   POC of neighbors reference frame
1187 * \param mv_cand            MV candidates to scale
1188 */
apply_mv_scaling(int32_t current_poc,int32_t current_ref_poc,int32_t neighbor_poc,int32_t neighbor_ref_poc,vector2d_t * mv_cand)1189 static void apply_mv_scaling(int32_t current_poc,
1190   int32_t current_ref_poc,
1191   int32_t neighbor_poc,
1192   int32_t neighbor_ref_poc,
1193   vector2d_t* mv_cand)
1194 {
1195   int32_t diff_current = current_poc - current_ref_poc;
1196   int32_t diff_neighbor = neighbor_poc - neighbor_ref_poc;
1197 
1198   if (diff_current == diff_neighbor) return;
1199   if (diff_neighbor == 0) return;
1200 
1201   diff_current = CLIP(-128, 127, diff_current);
1202   diff_neighbor = CLIP(-128, 127, diff_neighbor);
1203 
1204   int scale = CLIP(-4096, 4095,
1205     (diff_current * ((0x4000 + (abs(diff_neighbor) >> 1)) / diff_neighbor) + 32) >> 6);
1206 
1207   mv_cand->x = get_scaled_mv(mv_cand->x, scale);
1208   mv_cand->y = get_scaled_mv(mv_cand->y, scale);
1209 }
1210 
1211 
1212 /**
1213  * \brief Perform inter search for a single reference frame.
1214  */
search_pu_inter_ref(inter_search_info_t * info,int depth,lcu_t * lcu,cu_info_t * cur_cu,double * inter_cost,uint32_t * inter_bitcost,double * best_LX_cost,cu_info_t * unipred_LX)1215 static void search_pu_inter_ref(inter_search_info_t *info,
1216   int depth,
1217   lcu_t *lcu, cu_info_t *cur_cu,
1218   double *inter_cost,
1219   uint32_t *inter_bitcost,
1220   double *best_LX_cost,
1221   cu_info_t *unipred_LX)
1222 {
1223   const kvz_config *cfg = &info->state->encoder_control->cfg;
1224 
1225   // which list, L0 or L1, ref_idx is in and in what index
1226   int8_t ref_list = -1;
1227   // the index of the ref_idx in L0 or L1 list
1228   int8_t LX_idx;
1229   // max value of LX_idx plus one
1230   const int8_t LX_IDX_MAX_PLUS_1 = MAX(info->state->frame->ref_LX_size[0],
1231     info->state->frame->ref_LX_size[1]);
1232 
1233   for (LX_idx = 0; LX_idx < LX_IDX_MAX_PLUS_1; LX_idx++)
1234   {
1235     // check if ref_idx is in L0
1236     if (LX_idx < info->state->frame->ref_LX_size[0] &&
1237       info->state->frame->ref_LX[0][LX_idx] == info->ref_idx) {
1238       ref_list = 0;
1239       break;
1240     }
1241 
1242     // check if ref_idx is in L1
1243     if (LX_idx < info->state->frame->ref_LX_size[1] &&
1244       info->state->frame->ref_LX[1][LX_idx] == info->ref_idx) {
1245       ref_list = 1;
1246       break;
1247     }
1248   }
1249   // ref_idx has to be found in either L0 or L1
1250   assert(LX_idx < LX_IDX_MAX_PLUS_1);
1251 
1252   // store temp values to be stored back later
1253   int8_t temp_ref_idx = cur_cu->inter.mv_ref[ref_list];
1254 
1255   // Get MV candidates
1256   cur_cu->inter.mv_ref[ref_list] = LX_idx;
1257 
1258   kvz_inter_get_mv_cand(info->state,
1259     info->origin.x,
1260     info->origin.y,
1261     info->width,
1262     info->height,
1263     info->mv_cand,
1264     cur_cu,
1265     lcu,
1266     ref_list);
1267 
1268   // store old values back
1269   cur_cu->inter.mv_ref[ref_list] = temp_ref_idx;
1270 
1271   vector2d_t mv = { 0, 0 };
1272 
1273   // Take starting point for MV search from previous frame.
1274   // When temporal motion vector candidates are added, there is probably
1275   // no point to this anymore, but for now it helps.
1276   const int mid_x = info->state->tile->offset_x + info->origin.x + (info->width >> 1);
1277   const int mid_y = info->state->tile->offset_y + info->origin.y + (info->height >> 1);
1278   const cu_array_t* ref_array = info->state->frame->ref->cu_arrays[info->ref_idx];
1279   const cu_info_t* ref_cu = kvz_cu_array_at_const(ref_array, mid_x, mid_y);
1280   if (ref_cu->type == CU_INTER) {
1281     vector2d_t mv_previous = { 0, 0 };
1282     if (ref_cu->inter.mv_dir & 1) {
1283       mv_previous.x = ref_cu->inter.mv[0][0];
1284       mv_previous.y = ref_cu->inter.mv[0][1];
1285     }
1286     else {
1287       mv_previous.x = ref_cu->inter.mv[1][0];
1288       mv_previous.y = ref_cu->inter.mv[1][1];
1289     }
1290     // Apply mv scaling if neighbor poc is available
1291     if (info->state->frame->ref_LX_size[ref_list] > 0) {
1292       // When there are reference pictures from the future (POC > current POC)
1293       // in L0 or L1, the primary list for the colocated PU is the inverse of
1294       // collocated_from_l0_flag. Otherwise it is equal to reflist.
1295       //
1296       // Kvazaar always sets collocated_from_l0_flag so the list is L1 when
1297       // there are future references.
1298       int col_list = ref_list;
1299       for (int i = 0; i < info->state->frame->ref->used_size; i++) {
1300         if (info->state->frame->ref->pocs[i] > info->state->frame->poc) {
1301           col_list = 1;
1302           break;
1303         }
1304       }
1305       if ((ref_cu->inter.mv_dir & (col_list + 1)) == 0) {
1306         // Use the other list if the colocated PU does not have a MV for the
1307         // primary list.
1308         col_list = 1 - col_list;
1309       }
1310 
1311       uint8_t neighbor_poc_index = info->state->frame->ref_LX[ref_list][LX_idx];
1312       // Scaling takes current POC, reference POC, neighbor POC and neighbor reference POC as argument
1313       apply_mv_scaling(
1314         info->state->frame->poc,
1315         info->state->frame->ref->pocs[info->state->frame->ref_LX[ref_list][LX_idx]],
1316         info->state->frame->ref->pocs[neighbor_poc_index],
1317         info->state->frame->ref->images[neighbor_poc_index]->ref_pocs[
1318           info->state->frame->ref->ref_LXs[neighbor_poc_index]
1319           [col_list]
1320           [ref_cu->inter.mv_ref[col_list]]
1321         ],
1322         &mv_previous
1323       );
1324     }
1325 
1326     // Check if the mv is valid after scaling
1327     if (fracmv_within_tile(info, mv_previous.x, mv_previous.y)) {
1328       mv = mv_previous;
1329     }
1330   }
1331 
1332   int search_range = 32;
1333   switch (cfg->ime_algorithm) {
1334     case KVZ_IME_FULL64: search_range = 64; break;
1335     case KVZ_IME_FULL32: search_range = 32; break;
1336     case KVZ_IME_FULL16: search_range = 16; break;
1337     case KVZ_IME_FULL8: search_range = 8; break;
1338     default: break;
1339   }
1340 
1341   info->best_cost = UINT32_MAX;
1342 
1343   switch (cfg->ime_algorithm) {
1344     case KVZ_IME_TZ:
1345       tz_search(info, mv);
1346       break;
1347 
1348     case KVZ_IME_FULL64:
1349     case KVZ_IME_FULL32:
1350     case KVZ_IME_FULL16:
1351     case KVZ_IME_FULL8:
1352     case KVZ_IME_FULL:
1353       search_mv_full(info, search_range, mv);
1354       break;
1355 
1356     case KVZ_IME_DIA:
1357       diamond_search(info, mv, info->state->encoder_control->cfg.me_max_steps);
1358       break;
1359 
1360     default:
1361       hexagon_search(info, mv, info->state->encoder_control->cfg.me_max_steps);
1362       break;
1363   }
1364 
1365   if (cfg->fme_level > 0 && info->best_cost < *inter_cost) {
1366     search_frac(info);
1367 
1368   } else if (info->best_cost < UINT32_MAX) {
1369     // Recalculate inter cost with SATD.
1370     info->best_cost = kvz_image_calc_satd(
1371         info->state->tile->frame->source,
1372         info->ref,
1373         info->origin.x,
1374         info->origin.y,
1375         info->state->tile->offset_x + info->origin.x + (info->best_mv.x >> 2),
1376         info->state->tile->offset_y + info->origin.y + (info->best_mv.y >> 2),
1377         info->width,
1378         info->height);
1379     info->best_cost += info->best_bitcost * (int)(info->state->lambda_sqrt + 0.5);
1380   }
1381 
1382   mv = info->best_mv;
1383 
1384   int merged = 0;
1385   int merge_idx = 0;
1386   // Check every candidate to find a match
1387   for (merge_idx = 0; merge_idx < info->num_merge_cand; merge_idx++) {
1388     if (info->merge_cand[merge_idx].dir != 3 &&
1389         info->merge_cand[merge_idx].mv[info->merge_cand[merge_idx].dir - 1][0] == mv.x &&
1390         info->merge_cand[merge_idx].mv[info->merge_cand[merge_idx].dir - 1][1] == mv.y &&
1391         (uint32_t)info->state->frame->ref_LX[info->merge_cand[merge_idx].dir - 1][
1392         info->merge_cand[merge_idx].ref[info->merge_cand[merge_idx].dir - 1]] == info->ref_idx)
1393     {
1394       merged = 1;
1395       break;
1396     }
1397   }
1398 
1399   // Only check when candidates are different
1400   int cu_mv_cand = 0;
1401   if (!merged) {
1402     cu_mv_cand =
1403       select_mv_cand(info->state, info->mv_cand, mv.x, mv.y, NULL);
1404   }
1405 
1406   if (info->best_cost < *inter_cost) {
1407     // Map reference index to L0/L1 pictures
1408     cur_cu->inter.mv_dir = ref_list+1;
1409     uint8_t mv_ref_coded = LX_idx;
1410 
1411     cur_cu->merged                  = merged;
1412     cur_cu->merge_idx               = merge_idx;
1413     cur_cu->inter.mv_ref[ref_list]  = LX_idx;
1414     cur_cu->inter.mv[ref_list][0]   = (int16_t)mv.x;
1415     cur_cu->inter.mv[ref_list][1]   = (int16_t)mv.y;
1416 
1417     CU_SET_MV_CAND(cur_cu, ref_list, cu_mv_cand);
1418 
1419     *inter_cost = info->best_cost;
1420     *inter_bitcost = info->best_bitcost + cur_cu->inter.mv_dir - 1 + mv_ref_coded;
1421   }
1422 
1423 
1424   // Update best unipreds for biprediction
1425   if (info->best_cost < best_LX_cost[ref_list]) {
1426     bool valid_mv = fracmv_within_tile(info, mv.x, mv.y);
1427     if (valid_mv) {
1428       // Map reference index to L0/L1 pictures
1429       unipred_LX[ref_list].inter.mv_dir = ref_list + 1;
1430       unipred_LX[ref_list].inter.mv_ref[ref_list] = LX_idx;
1431       unipred_LX[ref_list].inter.mv[ref_list][0] = (int16_t)mv.x;
1432       unipred_LX[ref_list].inter.mv[ref_list][1] = (int16_t)mv.y;
1433 
1434       CU_SET_MV_CAND(&unipred_LX[ref_list], ref_list, cu_mv_cand);
1435 
1436       best_LX_cost[ref_list] = info->best_cost;
1437     }
1438   }
1439 }
1440 
1441 
1442 /**
1443  * \brief Search bipred modes for a PU.
1444  */
search_pu_inter_bipred(inter_search_info_t * info,int depth,lcu_t * lcu,cu_info_t * cur_cu,double * inter_cost,uint32_t * inter_bitcost)1445 static void search_pu_inter_bipred(inter_search_info_t *info,
1446                                    int depth,
1447                                    lcu_t *lcu, cu_info_t *cur_cu,
1448                                    double *inter_cost,
1449                                    uint32_t *inter_bitcost)
1450 {
1451   const image_list_t *const ref = info->state->frame->ref;
1452   uint8_t (*ref_LX)[16] = info->state->frame->ref_LX;
1453   const videoframe_t * const frame = info->state->tile->frame;
1454   const int x         = info->origin.x;
1455   const int y         = info->origin.y;
1456   const int width     = info->width;
1457   const int height    = info->height;
1458 
1459   static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 };
1460   static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 };
1461   const unsigned num_cand_pairs =
1462     MIN(info->num_merge_cand * (info->num_merge_cand - 1), 12);
1463 
1464   inter_merge_cand_t *merge_cand = info->merge_cand;
1465 
1466   for (int32_t idx = 0; idx < num_cand_pairs; idx++) {
1467     uint8_t i = priorityList0[idx];
1468     uint8_t j = priorityList1[idx];
1469     if (i >= info->num_merge_cand || j >= info->num_merge_cand) break;
1470 
1471     // Find one L0 and L1 candidate according to the priority list
1472     if (!(merge_cand[i].dir & 0x1) || !(merge_cand[j].dir & 0x2)) continue;
1473 
1474     if (ref_LX[0][merge_cand[i].ref[0]] == ref_LX[1][merge_cand[j].ref[1]] &&
1475         merge_cand[i].mv[0][0] == merge_cand[j].mv[1][0] &&
1476         merge_cand[i].mv[0][1] == merge_cand[j].mv[1][1])
1477     {
1478       continue;
1479     }
1480 
1481     int16_t mv[2][2];
1482     mv[0][0] = merge_cand[i].mv[0][0];
1483     mv[0][1] = merge_cand[i].mv[0][1];
1484     mv[1][0] = merge_cand[j].mv[1][0];
1485     mv[1][1] = merge_cand[j].mv[1][1];
1486 
1487     // Don't try merge candidates that don't satisfy mv constraints.
1488     if (!fracmv_within_tile(info, mv[0][0], mv[0][1]) ||
1489         !fracmv_within_tile(info, mv[1][0], mv[1][1]))
1490     {
1491       continue;
1492     }
1493 
1494     kvz_inter_recon_bipred(info->state,
1495                            ref->images[ref_LX[0][merge_cand[i].ref[0]]],
1496                            ref->images[ref_LX[1][merge_cand[j].ref[1]]],
1497                            x, y,
1498                            width,
1499                            height,
1500                            mv,
1501                            lcu,
1502                            true,
1503                            false);
1504 
1505     const kvz_pixel *rec = &lcu->rec.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1506     const kvz_pixel *src = &frame->source->y[x + y * frame->source->width];
1507     uint32_t cost =
1508       kvz_satd_any_size(width, height, rec, LCU_WIDTH, src, frame->source->width);
1509 
1510     uint32_t bitcost[2] = { 0, 0 };
1511 
1512     cost += info->mvd_cost_func(info->state,
1513                                merge_cand[i].mv[0][0],
1514                                merge_cand[i].mv[0][1],
1515                                0,
1516                                info->mv_cand,
1517                                NULL, 0, 0,
1518                                &bitcost[0]);
1519     cost += info->mvd_cost_func(info->state,
1520                                merge_cand[i].mv[1][0],
1521                                merge_cand[i].mv[1][1],
1522                                0,
1523                                info->mv_cand,
1524                                NULL, 0, 0,
1525                                &bitcost[1]);
1526 
1527     const uint8_t mv_ref_coded[2] = {
1528       merge_cand[i].ref[0],
1529       merge_cand[j].ref[1]
1530     };
1531     const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */;
1532     cost += info->state->lambda_sqrt * extra_bits + 0.5;
1533 
1534     if (cost < *inter_cost) {
1535       cur_cu->inter.mv_dir = 3;
1536 
1537       cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0];
1538       cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1];
1539 
1540       cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0];
1541       cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1];
1542       cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0];
1543       cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1];
1544       cur_cu->merged = 0;
1545 
1546       // Check every candidate to find a match
1547       for (int merge_idx = 0; merge_idx < info->num_merge_cand; merge_idx++) {
1548         if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] &&
1549             merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] &&
1550             merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] &&
1551             merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] &&
1552             merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] &&
1553             merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1])
1554         {
1555           cur_cu->merged = 1;
1556           cur_cu->merge_idx = merge_idx;
1557           break;
1558         }
1559       }
1560 
1561       // Each motion vector has its own candidate
1562       for (int reflist = 0; reflist < 2; reflist++) {
1563         kvz_inter_get_mv_cand(info->state, x, y, width, height, info->mv_cand, cur_cu, lcu, reflist);
1564         int cu_mv_cand = select_mv_cand(
1565             info->state,
1566             info->mv_cand,
1567             cur_cu->inter.mv[reflist][0],
1568             cur_cu->inter.mv[reflist][1],
1569             NULL);
1570         CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand);
1571       }
1572 
1573       *inter_cost = cost;
1574       *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits;
1575     }
1576   }
1577 }
1578 
1579 /**
1580  * \brief Check if an identical merge candidate exists in a list
1581  *
1582  * \param all_cand        Full list of available merge candidates
1583  * \param cand_to_add     Merge candidate to be checked for duplicates
1584  * \param added_idx_list  List of indices of unique merge candidates
1585  * \param list_size       Size of the list
1586  *
1587  * \return                Does an identical candidate exist in list
1588  */
merge_candidate_in_list(inter_merge_cand_t * all_cands,inter_merge_cand_t * cand_to_add,int8_t * added_idx_list,int list_size)1589 static bool merge_candidate_in_list(inter_merge_cand_t * all_cands,
1590                                     inter_merge_cand_t * cand_to_add,
1591                                     int8_t * added_idx_list,
1592                                     int list_size)
1593 {
1594   bool found = false;
1595   for (int i = 0; i < list_size && !found; ++i) {
1596     inter_merge_cand_t * list_cand = &all_cands[added_idx_list[i]];
1597 
1598     found = cand_to_add->dir == list_cand->dir &&
1599         cand_to_add->ref[0] == list_cand->ref[0] &&
1600         cand_to_add->mv[0][0] == list_cand->mv[0][0] &&
1601         cand_to_add->mv[0][1] == list_cand->mv[0][1] &&
1602         cand_to_add->ref[1] == list_cand->ref[1] &&
1603         cand_to_add->mv[1][0] == list_cand->mv[1][0] &&
1604         cand_to_add->mv[1][1] == list_cand->mv[1][1];
1605   }
1606 
1607   return found;
1608 }
1609 
1610 /**
1611  * \brief Update PU to have best modes at this depth.
1612  *
1613  * \param state       encoder state
1614  * \param x_cu        x-coordinate of the containing CU
1615  * \param y_cu        y-coordinate of the containing CU
1616  * \param depth       depth of the CU in the quadtree
1617  * \param part_mode   partition mode of the CU
1618  * \param i_pu        index of the PU in the CU
1619  * \param lcu         containing LCU
1620  *
1621  * \param inter_cost    Return inter cost of the best mode
1622  * \param inter_bitcost Return inter bitcost of the best mode
1623  */
search_pu_inter(encoder_state_t * const state,int x_cu,int y_cu,int depth,part_mode_t part_mode,int i_pu,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1624 static void search_pu_inter(encoder_state_t * const state,
1625                             int x_cu, int y_cu,
1626                             int depth,
1627                             part_mode_t part_mode,
1628                             int i_pu,
1629                             lcu_t *lcu,
1630                             double *inter_cost,
1631                             uint32_t *inter_bitcost)
1632 {
1633   *inter_cost = MAX_INT;
1634   *inter_bitcost = MAX_INT;
1635 
1636   const kvz_config *cfg = &state->encoder_control->cfg;
1637   const videoframe_t * const frame = state->tile->frame;
1638   const int width_cu  = LCU_WIDTH >> depth;
1639   const int x         = PU_GET_X(part_mode, width_cu, x_cu, i_pu);
1640   const int y         = PU_GET_Y(part_mode, width_cu, y_cu, i_pu);
1641   const int width     = PU_GET_W(part_mode, width_cu, i_pu);
1642   const int height    = PU_GET_H(part_mode, width_cu, i_pu);
1643 
1644   // Merge candidate A1 may not be used for the second PU of Nx2N, nLx2N and
1645   // nRx2N partitions.
1646   const bool merge_a1 = i_pu == 0 || width >= height;
1647   // Merge candidate B1 may not be used for the second PU of 2NxN, 2NxnU and
1648   // 2NxnD partitions.
1649   const bool merge_b1 = i_pu == 0 || width <= height;
1650 
1651   const int x_local   = SUB_SCU(x);
1652   const int y_local   = SUB_SCU(y);
1653   cu_info_t *cur_cu   = LCU_GET_CU_AT_PX(lcu, x_local, y_local);
1654 
1655   inter_search_info_t info = {
1656     .state          = state,
1657     .pic            = frame->source,
1658     .origin         = { x, y },
1659     .width          = width,
1660     .height         = height,
1661     .mvd_cost_func  = cfg->mv_rdo ? kvz_calc_mvd_cost_cabac : calc_mvd_cost,
1662     .optimized_sad  = kvz_get_optimized_sad(width),
1663   };
1664 
1665   // Search for merge mode candidates
1666   info.num_merge_cand = kvz_inter_get_merge_cand(
1667       state,
1668       x, y,
1669       width, height,
1670       merge_a1, merge_b1,
1671       info.merge_cand,
1672       lcu
1673   );
1674 
1675   // Default to candidate 0
1676   CU_SET_MV_CAND(cur_cu, 0, 0);
1677   CU_SET_MV_CAND(cur_cu, 1, 0);
1678 
1679   // Merge Analysis starts here
1680   int8_t mrg_cands[MRG_MAX_NUM_CANDS];
1681   double mrg_costs[MRG_MAX_NUM_CANDS];
1682   for (int i = 0; i < MRG_MAX_NUM_CANDS; ++i) {
1683     mrg_cands[i] = -1;
1684     mrg_costs[i] = MAX_DOUBLE;
1685   }
1686 
1687   int num_rdo_cands = 0;
1688 
1689   // Check motion vector constraints and perform rough search
1690   for (int merge_idx = 0; merge_idx < info.num_merge_cand; ++merge_idx) {
1691 
1692     inter_merge_cand_t *cur_cand = &info.merge_cand[merge_idx];
1693     cur_cu->inter.mv_dir = cur_cand->dir;
1694     cur_cu->inter.mv_ref[0] = cur_cand->ref[0];
1695     cur_cu->inter.mv_ref[1] = cur_cand->ref[1];
1696     cur_cu->inter.mv[0][0] = cur_cand->mv[0][0];
1697     cur_cu->inter.mv[0][1] = cur_cand->mv[0][1];
1698     cur_cu->inter.mv[1][0] = cur_cand->mv[1][0];
1699     cur_cu->inter.mv[1][1] = cur_cand->mv[1][1];
1700 
1701     // If bipred is not enabled, do not try candidates with mv_dir == 3.
1702     // Bipred is also forbidden for 4x8 and 8x4 blocks by the standard.
1703     if (cur_cu->inter.mv_dir == 3 && !state->encoder_control->cfg.bipred) continue;
1704     if (cur_cu->inter.mv_dir == 3 && !(width + height > 12)) continue;
1705 
1706     bool is_duplicate = merge_candidate_in_list(info.merge_cand, cur_cand,
1707       mrg_cands,
1708       num_rdo_cands);
1709 
1710     // Don't try merge candidates that don't satisfy mv constraints.
1711     // Don't add duplicates to list
1712     if (!fracmv_within_tile(&info, cur_cu->inter.mv[0][0], cur_cu->inter.mv[0][1]) ||
1713         !fracmv_within_tile(&info, cur_cu->inter.mv[1][0], cur_cu->inter.mv[1][1]) ||
1714         is_duplicate)
1715     {
1716       continue;
1717     }
1718 
1719     kvz_inter_pred_pu(state, lcu, x_cu, y_cu, width_cu, true, false, i_pu);
1720     mrg_costs[num_rdo_cands] = kvz_satd_any_size(width, height,
1721       lcu->rec.y + y_local * LCU_WIDTH + x_local, LCU_WIDTH,
1722       lcu->ref.y + y_local * LCU_WIDTH + x_local, LCU_WIDTH);
1723 
1724     // Add cost of coding the merge index
1725     mrg_costs[num_rdo_cands] += merge_idx * info.state->lambda_sqrt;
1726 
1727     mrg_cands[num_rdo_cands] = merge_idx;
1728     num_rdo_cands++;
1729   }
1730 
1731   // Sort candidates by cost
1732   kvz_sort_modes(mrg_cands, mrg_costs, num_rdo_cands);
1733 
1734   // Limit by availability
1735   // TODO: Do not limit to just 1
1736   num_rdo_cands = MIN(1, num_rdo_cands);
1737 
1738   // Early Skip Mode Decision
1739   bool has_chroma = state->encoder_control->chroma_format != KVZ_CSP_400;
1740   if (cfg->early_skip && cur_cu->part_size == SIZE_2Nx2N) {
1741     for (int merge_rdo_idx = 0; merge_rdo_idx < num_rdo_cands; ++merge_rdo_idx) {
1742 
1743       // Reconstruct blocks with merge candidate.
1744       // Check luma CBF. Then, check chroma CBFs if luma CBF is not set
1745       // and chroma exists.
1746       // Early terminate if merge candidate with zero CBF is found.
1747       int merge_idx = mrg_cands[merge_rdo_idx];
1748       cur_cu->inter.mv_dir = info.merge_cand[merge_idx].dir;
1749       cur_cu->inter.mv_ref[0] = info.merge_cand[merge_idx].ref[0];
1750       cur_cu->inter.mv_ref[1] = info.merge_cand[merge_idx].ref[1];
1751       cur_cu->inter.mv[0][0] = info.merge_cand[merge_idx].mv[0][0];
1752       cur_cu->inter.mv[0][1] = info.merge_cand[merge_idx].mv[0][1];
1753       cur_cu->inter.mv[1][0] = info.merge_cand[merge_idx].mv[1][0];
1754       cur_cu->inter.mv[1][1] = info.merge_cand[merge_idx].mv[1][1];
1755       kvz_lcu_fill_trdepth(lcu, x, y, depth, MAX(1, depth));
1756       kvz_inter_recon_cu(state, lcu, x, y, width, true, false);
1757       kvz_quantize_lcu_residual(state, true, false, x, y, depth, cur_cu, lcu, true);
1758 
1759       if (cbf_is_set(cur_cu->cbf, depth, COLOR_Y)) {
1760         continue;
1761       }
1762       else if (has_chroma) {
1763         kvz_inter_recon_cu(state, lcu, x, y, width, false, has_chroma);
1764         kvz_quantize_lcu_residual(state, false, has_chroma, x, y, depth, cur_cu, lcu, true);
1765         if (!cbf_is_set_any(cur_cu->cbf, depth)) {
1766           cur_cu->type = CU_INTER;
1767           cur_cu->merge_idx = merge_idx;
1768           cur_cu->skipped = true;
1769           *inter_cost = 0.0;  // TODO: Check this
1770           *inter_bitcost = merge_idx; // TODO: Check this
1771           return;
1772         }
1773       }
1774     }
1775   }
1776 
1777   // AMVP search starts here
1778 
1779   // Store unipred information of L0 and L1 for biprediction
1780   // Best cost will be left at MAX_DOUBLE if no valid CU is found
1781   double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE };
1782   cu_info_t unipreds[2];
1783 
1784   for (int ref_idx = 0; ref_idx < state->frame->ref->used_size; ref_idx++) {
1785     info.ref_idx = ref_idx;
1786     info.ref = state->frame->ref->images[ref_idx];
1787 
1788     search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, unipreds);
1789   }
1790 
1791   // Search bi-pred positions
1792   bool can_use_bipred = state->frame->slicetype == KVZ_SLICE_B
1793     && cfg->bipred
1794     && width + height >= 16; // 4x8 and 8x4 PBs are restricted to unipred
1795 
1796   if (can_use_bipred) {
1797 
1798     // Try biprediction from valid acquired unipreds.
1799     if (best_cost_LX[0] != MAX_DOUBLE && best_cost_LX[1] != MAX_DOUBLE) {
1800 
1801       // TODO: logic is copy paste from search_pu_inter_bipred.
1802       // Get rid of duplicate code asap.
1803       const image_list_t *const ref = info.state->frame->ref;
1804       uint8_t(*ref_LX)[16] = info.state->frame->ref_LX;
1805 
1806       inter_merge_cand_t *merge_cand = info.merge_cand;
1807 
1808       int16_t mv[2][2];
1809       mv[0][0] = unipreds[0].inter.mv[0][0];
1810       mv[0][1] = unipreds[0].inter.mv[0][1];
1811       mv[1][0] = unipreds[1].inter.mv[1][0];
1812       mv[1][1] = unipreds[1].inter.mv[1][1];
1813 
1814       kvz_inter_recon_bipred(info.state,
1815         ref->images[ref_LX[0][unipreds[0].inter.mv_ref[0]]],
1816         ref->images[ref_LX[1][unipreds[1].inter.mv_ref[1]]],
1817         x, y,
1818         width,
1819         height,
1820         mv,
1821         lcu,
1822         true,
1823         false);
1824 
1825       const kvz_pixel *rec = &lcu->rec.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1826       const kvz_pixel *src = &lcu->ref.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1827       uint32_t cost =
1828         kvz_satd_any_size(width, height, rec, LCU_WIDTH, src, LCU_WIDTH);
1829 
1830       uint32_t bitcost[2] = { 0, 0 };
1831 
1832       cost += info.mvd_cost_func(info.state,
1833         unipreds[0].inter.mv[0][0],
1834         unipreds[0].inter.mv[0][1],
1835         0,
1836         info.mv_cand,
1837         NULL, 0, 0,
1838         &bitcost[0]);
1839       cost += info.mvd_cost_func(info.state,
1840         unipreds[1].inter.mv[1][0],
1841         unipreds[1].inter.mv[1][1],
1842         0,
1843         info.mv_cand,
1844         NULL, 0, 0,
1845         &bitcost[1]);
1846 
1847       const uint8_t mv_ref_coded[2] = {
1848         unipreds[0].inter.mv_ref[0],
1849         unipreds[1].inter.mv_ref[1]
1850       };
1851       const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */;
1852       cost += info.state->lambda_sqrt * extra_bits + 0.5;
1853 
1854       if (cost < *inter_cost) {
1855         cur_cu->inter.mv_dir = 3;
1856 
1857         cur_cu->inter.mv_ref[0] = unipreds[0].inter.mv_ref[0];
1858         cur_cu->inter.mv_ref[1] = unipreds[1].inter.mv_ref[1];
1859 
1860         cur_cu->inter.mv[0][0] = unipreds[0].inter.mv[0][0];
1861         cur_cu->inter.mv[0][1] = unipreds[0].inter.mv[0][1];
1862         cur_cu->inter.mv[1][0] = unipreds[1].inter.mv[1][0];
1863         cur_cu->inter.mv[1][1] = unipreds[1].inter.mv[1][1];
1864         cur_cu->merged = 0;
1865 
1866         // Check every candidate to find a match
1867         for (int merge_idx = 0; merge_idx < info.num_merge_cand; merge_idx++) {
1868           if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] &&
1869             merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] &&
1870             merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] &&
1871             merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] &&
1872             merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] &&
1873             merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1])
1874           {
1875             cur_cu->merged = 1;
1876             cur_cu->merge_idx = merge_idx;
1877             break;
1878           }
1879         }
1880 
1881         // Each motion vector has its own candidate
1882         for (int reflist = 0; reflist < 2; reflist++) {
1883           kvz_inter_get_mv_cand(info.state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist);
1884           int cu_mv_cand = select_mv_cand(
1885             info.state,
1886             info.mv_cand,
1887             cur_cu->inter.mv[reflist][0],
1888             cur_cu->inter.mv[reflist][1],
1889             NULL);
1890           CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand);
1891         }
1892 
1893         *inter_cost = cost;
1894         *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits;
1895       }
1896     }
1897 
1898     // TODO: this probably should have a separate command line option
1899     if (cfg->rdo == 3) {
1900       search_pu_inter_bipred(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost);
1901     }
1902   }
1903 
1904   // Compare best merge cost to amvp cost
1905   if (mrg_costs[0] < *inter_cost) {
1906     *inter_cost = mrg_costs[0];
1907     *inter_bitcost = 0; // TODO: Check this
1908     int merge_idx = mrg_cands[0];
1909     cur_cu->type = CU_INTER;
1910     cur_cu->merge_idx = merge_idx;
1911     cur_cu->inter.mv_dir = info.merge_cand[merge_idx].dir;
1912     cur_cu->inter.mv_ref[0] = info.merge_cand[merge_idx].ref[0];
1913     cur_cu->inter.mv_ref[1] = info.merge_cand[merge_idx].ref[1];
1914     cur_cu->inter.mv[0][0] = info.merge_cand[merge_idx].mv[0][0];
1915     cur_cu->inter.mv[0][1] = info.merge_cand[merge_idx].mv[0][1];
1916     cur_cu->inter.mv[1][0] = info.merge_cand[merge_idx].mv[1][0];
1917     cur_cu->inter.mv[1][1] = info.merge_cand[merge_idx].mv[1][1];
1918     cur_cu->merged = true;
1919     cur_cu->skipped = false;
1920   }
1921 
1922   if (*inter_cost < INT_MAX && cur_cu->inter.mv_dir == 1) {
1923     assert(fracmv_within_tile(&info, cur_cu->inter.mv[0][0], cur_cu->inter.mv[0][1]));
1924   }
1925 }
1926 
1927 /**
1928 * \brief Calculate inter coding cost for luma and chroma CBs (--rd=2 accuracy).
1929 *
1930 * Calculate inter coding cost of each CB. This should match the intra coding cost
1931 * calculation that is used on this RDO accuracy, since CU type decision is based
1932 * on this.
1933 *
1934 * The cost includes SSD distortion, transform unit tree bits and motion vector bits
1935 * for both luma and chroma if enabled.
1936 *
1937 * \param state       encoder state
1938 * \param x           x-coordinate of the CU
1939 * \param y           y-coordinate of the CU
1940 * \param depth       depth of the CU in the quadtree
1941 * \param lcu         containing LCU
1942 *
1943 * \param inter_cost    Return inter cost
1944 * \param inter_bitcost Return inter bitcost
1945 */
kvz_cu_cost_inter_rd2(encoder_state_t * const state,int x,int y,int depth,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1946 void kvz_cu_cost_inter_rd2(encoder_state_t * const state,
1947   int x, int y, int depth,
1948   lcu_t *lcu,
1949   double   *inter_cost,
1950   uint32_t *inter_bitcost){
1951 
1952   cu_info_t *cur_cu = LCU_GET_CU_AT_PX(lcu, SUB_SCU(x), SUB_SCU(y));
1953   int tr_depth = MAX(1, depth);
1954   if (cur_cu->part_size != SIZE_2Nx2N) {
1955     tr_depth = depth + 1;
1956   }
1957   kvz_lcu_fill_trdepth(lcu, x, y, depth, tr_depth);
1958 
1959   const bool reconstruct_chroma = state->encoder_control->chroma_format != KVZ_CSP_400;
1960   kvz_inter_recon_cu(state, lcu, x, y, CU_WIDTH_FROM_DEPTH(depth), true, reconstruct_chroma);
1961   kvz_quantize_lcu_residual(state, true, reconstruct_chroma,
1962     x, y, depth,
1963     NULL,
1964     lcu,
1965     false);
1966 
1967   *inter_cost = kvz_cu_rd_cost_luma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu);
1968   if (reconstruct_chroma) {
1969     *inter_cost += kvz_cu_rd_cost_chroma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu);
1970   }
1971 
1972   *inter_cost += *inter_bitcost * state->lambda;
1973 }
1974 
1975 
1976 /**
1977  * \brief Update CU to have best modes at this depth.
1978  *
1979  * Only searches the 2Nx2N partition mode.
1980  *
1981  * \param state       encoder state
1982  * \param x           x-coordinate of the CU
1983  * \param y           y-coordinate of the CU
1984  * \param depth       depth of the CU in the quadtree
1985  * \param lcu         containing LCU
1986  *
1987  * \param inter_cost    Return inter cost
1988  * \param inter_bitcost Return inter bitcost
1989  */
kvz_search_cu_inter(encoder_state_t * const state,int x,int y,int depth,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1990 void kvz_search_cu_inter(encoder_state_t * const state,
1991                          int x, int y, int depth,
1992                          lcu_t *lcu,
1993                          double   *inter_cost,
1994                          uint32_t *inter_bitcost)
1995 {
1996   search_pu_inter(state,
1997                   x, y, depth,
1998                   SIZE_2Nx2N, 0,
1999                   lcu,
2000                   inter_cost,
2001                   inter_bitcost);
2002 
2003   // Calculate more accurate cost when needed
2004   if (state->encoder_control->cfg.rdo >= 2) {
2005     kvz_cu_cost_inter_rd2(state,
2006       x, y, depth,
2007       lcu,
2008       inter_cost,
2009       inter_bitcost);
2010   }
2011 }
2012 
2013 
2014 /**
2015  * \brief Update CU to have best modes at this depth.
2016  *
2017  * Only searches the given partition mode.
2018  *
2019  * \param state       encoder state
2020  * \param x           x-coordinate of the CU
2021  * \param y           y-coordinate of the CU
2022  * \param depth       depth of the CU in the quadtree
2023  * \param part_mode   partition mode to search
2024  * \param lcu         containing LCU
2025  *
2026  * \param inter_cost    Return inter cost
2027  * \param inter_bitcost Return inter bitcost
2028  */
kvz_search_cu_smp(encoder_state_t * const state,int x,int y,int depth,part_mode_t part_mode,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)2029 void kvz_search_cu_smp(encoder_state_t * const state,
2030                        int x, int y,
2031                        int depth,
2032                        part_mode_t part_mode,
2033                        lcu_t *lcu,
2034                        double *inter_cost,
2035                        uint32_t *inter_bitcost)
2036 {
2037   const int num_pu  = kvz_part_mode_num_parts[part_mode];
2038   const int width   = LCU_WIDTH >> depth;
2039   const int y_local = SUB_SCU(y);
2040   const int x_local = SUB_SCU(x);
2041 
2042   *inter_cost    = 0;
2043   *inter_bitcost = 0;
2044 
2045   for (int i = 0; i < num_pu; ++i) {
2046     const int x_pu      = PU_GET_X(part_mode, width, x_local, i);
2047     const int y_pu      = PU_GET_Y(part_mode, width, y_local, i);
2048     const int width_pu  = PU_GET_W(part_mode, width, i);
2049     const int height_pu = PU_GET_H(part_mode, width, i);
2050     cu_info_t *cur_pu   = LCU_GET_CU_AT_PX(lcu, x_pu, y_pu);
2051 
2052     cur_pu->type      = CU_INTER;
2053     cur_pu->part_size = part_mode;
2054     cur_pu->depth     = depth;
2055     cur_pu->qp        = state->qp;
2056 
2057     double cost      = MAX_INT;
2058     uint32_t bitcost = MAX_INT;
2059 
2060     search_pu_inter(state, x, y, depth, part_mode, i, lcu, &cost, &bitcost);
2061 
2062     if (cost >= MAX_INT) {
2063       // Could not find any motion vector.
2064       *inter_cost    = MAX_INT;
2065       *inter_bitcost = MAX_INT;
2066       return;
2067     }
2068 
2069     *inter_cost    += cost;
2070     *inter_bitcost += bitcost;
2071 
2072     for (int y = y_pu; y < y_pu + height_pu; y += SCU_WIDTH) {
2073       for (int x = x_pu; x < x_pu + width_pu; x += SCU_WIDTH) {
2074         cu_info_t *scu = LCU_GET_CU_AT_PX(lcu, x, y);
2075         scu->type = CU_INTER;
2076         scu->inter = cur_pu->inter;
2077       }
2078     }
2079   }
2080 
2081   // Calculate more accurate cost when needed
2082   if (state->encoder_control->cfg.rdo >= 2) {
2083     kvz_cu_cost_inter_rd2(state,
2084       x, y, depth,
2085       lcu,
2086       inter_cost,
2087       inter_bitcost);
2088   }
2089 
2090   // Count bits spent for coding the partition mode.
2091   int smp_extra_bits = 1; // horizontal or vertical
2092   if (state->encoder_control->cfg.amp_enable) {
2093     smp_extra_bits += 1; // symmetric or asymmetric
2094     if (part_mode != SIZE_2NxN && part_mode != SIZE_Nx2N) {
2095       smp_extra_bits += 1; // U,L or D,R
2096     }
2097   }
2098   // The transform is split for SMP and AMP blocks so we need more bits for
2099   // coding the CBF.
2100   smp_extra_bits += 6;
2101 
2102   *inter_cost += (state->encoder_control->cfg.rdo >= 2 ? state->lambda : state->lambda_sqrt) * smp_extra_bits;
2103   *inter_bitcost += smp_extra_bits;
2104 }
2105