1 /*****************************************************************************
2 * This file is part of Kvazaar HEVC encoder.
3 *
4 * Copyright (c) 2021, Tampere University, ITU/ISO/IEC, project contributors
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without modification,
8 * are permitted provided that the following conditions are met:
9 *
10 * * Redistributions of source code must retain the above copyright notice, this
11 * list of conditions and the following disclaimer.
12 *
13 * * Redistributions in binary form must reproduce the above copyright notice, this
14 * list of conditions and the following disclaimer in the documentation and/or
15 * other materials provided with the distribution.
16 *
17 * * Neither the name of the Tampere University or ITU/ISO/IEC nor the names of its
18 * contributors may be used to endorse or promote products derived from
19 * this software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 * INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION HOWEVER CAUSED AND ON
28 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * INCLUDING NEGLIGENCE OR OTHERWISE ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 ****************************************************************************/
32
33 #include "search_inter.h"
34
35 #include <limits.h>
36 #include <stdlib.h>
37
38 #include "cabac.h"
39 #include "encoder.h"
40 #include "image.h"
41 #include "imagelist.h"
42 #include "inter.h"
43 #include "kvazaar.h"
44 #include "rdo.h"
45 #include "search.h"
46 #include "strategies/strategies-ipol.h"
47 #include "strategies/strategies-picture.h"
48 #include "transform.h"
49 #include "videoframe.h"
50
/**
 * \brief Working state of a motion search for a single PU.
 *
 * The search functions in this file read the block geometry and the
 * candidate lists from this struct and write the best motion vector found
 * so far back into best_mv, best_cost and best_bitcost.
 */
typedef struct {
  /**
   * \brief Encoder state the search is run for
   */
  encoder_state_t *state;

  /**
   * \brief Current frame
   */
  const kvz_picture *pic;
  /**
   * \brief Reference frame
   */
  const kvz_picture *ref;

  /**
   * \brief Index of the reference frame
   */
  int32_t ref_idx;

  /**
   * \brief Top-left corner of the PU
   */
  const vector2d_t origin;
  /**
   * \brief Width of the PU in luma pixels
   */
  int32_t width;
  /**
   * \brief Height of the PU in luma pixels
   */
  int32_t height;

  /**
   * \brief Two motion vector predictor candidates, stored as
   *        mv_cand[candidate][0] = x and mv_cand[candidate][1] = y
   */
  int16_t mv_cand[2][2];
  /**
   * \brief Merge candidates; only the first num_merge_cand entries are used
   */
  inter_merge_cand_t merge_cand[MRG_MAX_NUM_CANDS];
  /**
   * \brief Number of valid entries in merge_cand
   */
  int32_t num_merge_cand;

  /**
   * \brief Function used to get the cost of coding a motion vector
   */
  kvz_mvd_cost_func *mvd_cost_func;

  /**
   * \brief Best motion vector among the ones tested so far,
   *        in quarter pixel precision
   */
  vector2d_t best_mv;
  /**
   * \brief Cost of best_mv
   */
  uint32_t best_cost;
  /**
   * \brief Bit cost of best_mv
   */
  uint32_t best_bitcost;

  /**
   * \brief Possible optimized SAD implementation for the width, leave as
   *        NULL for arbitrary-width blocks
   */
  optimized_sad_func_ptr_t optimized_sad;

} inter_search_info_t;
101
102
103 /**
104 * \return True if referred block is within current tile.
105 */
fracmv_within_tile(const inter_search_info_t * info,int x,int y)106 static INLINE bool fracmv_within_tile(const inter_search_info_t *info, int x, int y)
107 {
108 const encoder_control_t *ctrl = info->state->encoder_control;
109
110 const bool is_frac_luma = x % 4 != 0 || y % 4 != 0;
111 const bool is_frac_chroma = x % 8 != 0 || y % 8 != 0;
112
113 if (ctrl->cfg.owf && ctrl->cfg.wpp) {
114 // Check that the block does not reference pixels that are not final.
115
116 // Margin as luma pixels.
117 int margin = 0;
118 if (is_frac_luma) {
119 // Fractional motion estimation needs up to 4 pixels outside the
120 // block.
121 margin = 4;
122 } else if (is_frac_chroma) {
123 // Odd chroma interpolation needs up to 2 luma pixels outside the
124 // block.
125 margin = 2;
126 }
127
128 if (ctrl->cfg.sao_type) {
129 // Make sure we don't refer to pixels for which SAO reconstruction
130 // has not been done.
131 margin += SAO_DELAY_PX;
132 } else if (ctrl->cfg.deblock_enable) {
133 // Make sure we don't refer to pixels that have not been deblocked.
134 margin += DEBLOCK_DELAY_PX;
135 }
136
137 // Coordinates of the top-left corner of the containing LCU.
138 const vector2d_t orig_lcu = {
139 .x = info->origin.x / LCU_WIDTH,
140 .y = info->origin.y / LCU_WIDTH,
141 };
142 // Difference between the coordinates of the LCU containing the
143 // bottom-left corner of the referenced block and the LCU containing
144 // this block.
145 const vector2d_t mv_lcu = {
146 ((info->origin.x + info->width + margin) * 4 + x) / (LCU_WIDTH << 2) - orig_lcu.x,
147 ((info->origin.y + info->height + margin) * 4 + y) / (LCU_WIDTH << 2) - orig_lcu.y,
148 };
149
150 if (mv_lcu.y > ctrl->max_inter_ref_lcu.down) {
151 return false;
152 }
153
154 if (mv_lcu.x + mv_lcu.y >
155 ctrl->max_inter_ref_lcu.down + ctrl->max_inter_ref_lcu.right)
156 {
157 return false;
158 }
159 }
160
161 if (ctrl->cfg.mv_constraint == KVZ_MV_CONSTRAIN_NONE) {
162 return true;
163 }
164
165 // Margin as luma quater pixels.
166 int margin = 0;
167 if (ctrl->cfg.mv_constraint == KVZ_MV_CONSTRAIN_FRAME_AND_TILE_MARGIN) {
168 if (is_frac_luma) {
169 margin = 4 << 2;
170 } else if (is_frac_chroma) {
171 margin = 2 << 2;
172 }
173 }
174
175 // TODO implement KVZ_MV_CONSTRAIN_FRAM and KVZ_MV_CONSTRAIN_TILE.
176 const vector2d_t abs_mv = {
177 info->origin.x * 4 + x,
178 info->origin.y * 4 + y,
179 };
180
181 // Check that both margin constraints are satisfied.
182 const int from_right =
183 (info->state->tile->frame->width << 2) - (abs_mv.x + (info->width << 2));
184 const int from_bottom =
185 (info->state->tile->frame->height << 2) - (abs_mv.y + (info->height << 2));
186
187 return abs_mv.x >= margin &&
188 abs_mv.y >= margin &&
189 from_right >= margin &&
190 from_bottom >= margin;
191 }
192
193
194 /**
195 * \return True if referred block is within current tile.
196 */
intmv_within_tile(const inter_search_info_t * info,int x,int y)197 static INLINE bool intmv_within_tile(const inter_search_info_t *info, int x, int y)
198 {
199 return fracmv_within_tile(info, x * 4, y * 4);
200 }
201
202
203 /**
204 * \brief Calculate cost for an integer motion vector.
205 *
206 * Updates info->best_mv, info->best_cost and info->best_bitcost to the new
207 * motion vector if it yields a lower cost than the current one.
208 *
209 * If the motion vector violates the MV constraints for tiles or WPP, the
210 * cost is not set.
211 *
212 * \return true if info->best_mv was changed, false otherwise
213 */
check_mv_cost(inter_search_info_t * info,int x,int y)214 static bool check_mv_cost(inter_search_info_t *info, int x, int y)
215 {
216 if (!intmv_within_tile(info, x, y)) return false;
217
218 uint32_t bitcost = 0;
219 uint32_t cost = kvz_image_calc_sad(
220 info->pic,
221 info->ref,
222 info->origin.x,
223 info->origin.y,
224 info->state->tile->offset_x + info->origin.x + x,
225 info->state->tile->offset_y + info->origin.y + y,
226 info->width,
227 info->height,
228 info->optimized_sad
229 );
230
231 if (cost >= info->best_cost) return false;
232
233 cost += info->mvd_cost_func(
234 info->state,
235 x, y, 2,
236 info->mv_cand,
237 info->merge_cand,
238 info->num_merge_cand,
239 info->ref_idx,
240 &bitcost
241 );
242
243 if (cost >= info->best_cost) return false;
244
245 // Set to motion vector in quarter pixel precision.
246 info->best_mv.x = x * 4;
247 info->best_mv.y = y * 4;
248 info->best_cost = cost;
249 info->best_bitcost = bitcost;
250
251 return true;
252 }
253
254
/**
 * \brief Estimate the number of bins needed to code a symbol.
 *
 * Calculates 2 * floor(log2(symbol + 2)) with a cascade of comparisons.
 *
 * \param symbol  value to be coded; symbol + 2 must not wrap around, i.e.
 *                symbol must be at most UINT_MAX - 2
 * \return estimated number of bins
 */
static unsigned get_ep_ex_golomb_bitcost(unsigned symbol)
{
  unsigned bins = 0;
  symbol += 2;

  // Each step handles the upper half of the bits that can still be set.
  // Without the 16-bit step the result would be too small for values of
  // 2^16 and above, because the last step adds at most one more bit.
  if (symbol >= 1u << 16) { bins += 32; symbol >>= 16; }
  if (symbol >= 1u << 8)  { bins += 16; symbol >>= 8; }
  if (symbol >= 1u << 4)  { bins += 8;  symbol >>= 4; }
  if (symbol >= 1u << 2)  { bins += 4;  symbol >>= 2; }
  if (symbol >= 1u << 1)  { bins += 2; }

  // TODO: It might be a good idea to put a small slope on this function to
  // make sure any search function that follows the gradient heads towards
  // a smaller MVD, but that would require fractional costs and bits being
  // used everywhere in inter search.
  // return num_bins + 0.001 * symbol;

  return bins;
}
274
275
276 /**
277 * \brief Checks if mv is one of the merge candidates.
278 * \return true if found else return false
279 */
mv_in_merge(const inter_search_info_t * info,vector2d_t mv)280 static bool mv_in_merge(const inter_search_info_t *info, vector2d_t mv)
281 {
282 for (int i = 0; i < info->num_merge_cand; ++i) {
283 if (info->merge_cand[i].dir == 3) continue;
284 const vector2d_t merge_mv = {
285 (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][0] + 2) >> 2,
286 (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][1] + 2) >> 2
287 };
288 if (merge_mv.x == mv.x && merge_mv.y == mv.y) {
289 return true;
290 }
291 }
292 return false;
293 }
294
295
296 /**
297 * \brief Select starting point for integer motion estimation search.
298 *
299 * Checks the zero vector, extra_mv and merge candidates and updates
300 * info->best_mv to the best one.
301 */
select_starting_point(inter_search_info_t * info,vector2d_t extra_mv)302 static void select_starting_point(inter_search_info_t *info, vector2d_t extra_mv)
303 {
304 // Check the 0-vector, so we can ignore all 0-vectors in the merge cand list.
305 check_mv_cost(info, 0, 0);
306
307 // Change to integer precision.
308 extra_mv.x >>= 2;
309 extra_mv.y >>= 2;
310
311 // Check mv_in if it's not one of the merge candidates.
312 if ((extra_mv.x != 0 || extra_mv.y != 0) && !mv_in_merge(info, extra_mv)) {
313 check_mv_cost(info, extra_mv.x, extra_mv.y);
314 }
315
316 // Go through candidates
317 for (unsigned i = 0; i < info->num_merge_cand; ++i) {
318 if (info->merge_cand[i].dir == 3) continue;
319
320 int x = (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][0] + 2) >> 2;
321 int y = (info->merge_cand[i].mv[info->merge_cand[i].dir - 1][1] + 2) >> 2;
322
323 if (x == 0 && y == 0) continue;
324
325 check_mv_cost(info, x, y);
326 }
327 }
328
329
get_mvd_coding_cost(const encoder_state_t * state,const cabac_data_t * cabac,const int32_t mvd_hor,const int32_t mvd_ver)330 static uint32_t get_mvd_coding_cost(const encoder_state_t *state,
331 const cabac_data_t* cabac,
332 const int32_t mvd_hor,
333 const int32_t mvd_ver)
334 {
335 unsigned bitcost = 0;
336 const vector2d_t abs_mvd = { abs(mvd_hor), abs(mvd_ver) };
337
338 bitcost += get_ep_ex_golomb_bitcost(abs_mvd.x) << CTX_FRAC_BITS;
339 bitcost += get_ep_ex_golomb_bitcost(abs_mvd.y) << CTX_FRAC_BITS;
340
341 // Round and shift back to integer bits.
342 return (bitcost + CTX_FRAC_HALF_BIT) >> CTX_FRAC_BITS;
343 }
344
345
select_mv_cand(const encoder_state_t * state,int16_t mv_cand[2][2],int32_t mv_x,int32_t mv_y,uint32_t * cost_out)346 static int select_mv_cand(const encoder_state_t *state,
347 int16_t mv_cand[2][2],
348 int32_t mv_x,
349 int32_t mv_y,
350 uint32_t *cost_out)
351 {
352 const bool same_cand =
353 (mv_cand[0][0] == mv_cand[1][0] && mv_cand[0][1] == mv_cand[1][1]);
354
355 if (same_cand && !cost_out) {
356 // Pick the first one if both candidates are the same.
357 return 0;
358 }
359
360 uint32_t (*mvd_coding_cost)(const encoder_state_t * const state,
361 const cabac_data_t*,
362 int32_t, int32_t);
363 if (state->encoder_control->cfg.mv_rdo) {
364 mvd_coding_cost = kvz_get_mvd_coding_cost_cabac;
365 } else {
366 mvd_coding_cost = get_mvd_coding_cost;
367 }
368
369 uint32_t cand1_cost = mvd_coding_cost(
370 state, &state->cabac,
371 mv_x - mv_cand[0][0],
372 mv_y - mv_cand[0][1]);
373
374 uint32_t cand2_cost;
375 if (same_cand) {
376 cand2_cost = cand1_cost;
377 } else {
378 cand2_cost = mvd_coding_cost(
379 state, &state->cabac,
380 mv_x - mv_cand[1][0],
381 mv_y - mv_cand[1][1]);
382 }
383
384 if (cost_out) {
385 *cost_out = MIN(cand1_cost, cand2_cost);
386 }
387
388 // Pick the second candidate if it has lower cost.
389 return cand2_cost < cand1_cost ? 1 : 0;
390 }
391
392
/**
 * \brief Calculate the cost of coding a motion vector.
 *
 * If the vector equals one of the merge candidates and that candidate
 * refers to ref_idx, the bit cost is the index of the candidate. Otherwise
 * the bit cost is the MVD cost returned by select_mv_cand.
 *
 * \param state       encoder state
 * \param x           horizontal component of the MV
 * \param y           vertical component of the MV
 * \param mv_shift    the components are multiplied by 1 << mv_shift before
 *                    they are compared to the candidates
 * \param mv_cand     motion vector predictor candidates
 * \param merge_cand  merge candidates
 * \param num_cand    number of merge candidates
 * \param ref_idx     index of the reference frame
 * \param bitcost     receives the bit cost
 * \return bit cost multiplied by state->lambda_sqrt rounded to an integer
 */
static uint32_t calc_mvd_cost(const encoder_state_t *state,
                              int x,
                              int y,
                              int mv_shift,
                              int16_t mv_cand[2][2],
                              inter_merge_cand_t merge_cand[MRG_MAX_NUM_CANDS],
                              int16_t num_cand,
                              int32_t ref_idx,
                              uint32_t *bitcost)
{
  x *= 1 << mv_shift;
  y *= 1 << mv_shift;

  uint32_t bits = 0;
  bool merged = false;

  // Look for a merge candidate that matches both the vector and the
  // reference frame.
  for (uint32_t idx = 0; idx < (uint32_t)num_cand; idx++) {
    const inter_merge_cand_t *cand = &merge_cand[idx];
    if (cand->dir == 3) continue;

    const int list = cand->dir - 1;
    if (cand->mv[list][0] == x &&
        cand->mv[list][1] == y &&
        state->frame->ref_LX[list][cand->ref[list]] == ref_idx)
    {
      bits += idx;
      merged = true;
      break;
    }
  }

  // The MVD has to be coded only if the vector could not be merged.
  if (!merged) {
    uint32_t mvd_cost = 0;
    select_mv_cand(state, mv_cand, x, y, &mvd_cost);
    bits += mvd_cost;
  }

  *bitcost = bits;
  return bits * (int32_t)(state->lambda_sqrt + 0.5);
}
433
434
/**
 * \brief Decide whether the motion search can be stopped early.
 *
 * Does at most two rounds of a small search around info->best_mv. The first
 * round checks all four neighbours, the second one three of them, selected
 * from the direction of the previous move. A round that does not bring
 * info->best_cost below the threshold ends the search. The threshold is the
 * best cost before the round, or 95 % of it when cfg.me_early_termination
 * is KVZ_ME_EARLY_TERMINATION_SENSITIVE.
 *
 * \param info  search info, updated in place by check_mv_cost
 * \return true if the search should be stopped, false otherwise
 */
static bool early_terminate(inter_search_info_t *info)
{
  // The first two offsets are repeated after the fourth one so that three
  // consecutive entries can be read from any start index between 0 and 3.
  // The last entry is the center, used when nothing better was found.
  static const vector2d_t small_hexbs[7] = {
      { 0, -1 }, { -1, 0 }, { 0, 1 }, { 1, 0 },
      { 0, -1 }, { -1, 0 }, { 0, 0 },
  };
  enum { CENTER_INDEX = 6, NUM_ROUNDS = 2 };

  const bool sensitive =
    info->state->encoder_control->cfg.me_early_termination ==
    KVZ_ME_EARLY_TERMINATION_SENSITIVE;

  vector2d_t center = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  int first = 0;
  int last = 3;

  for (int round = 0; round < NUM_ROUNDS; ++round) {
    const double threshold =
      sensitive ? info->best_cost * 0.95 : info->best_cost;

    int best = CENTER_INDEX;
    for (int i = first; i <= last; i++) {
      if (check_mv_cost(info,
                        center.x + small_hexbs[i].x,
                        center.y + small_hexbs[i].y))
      {
        best = i;
      }
    }

    // Move the center to the best point of this round.
    center.x += small_hexbs[best].x;
    center.y += small_hexbs[best].y;

    // If best match is not better than threshold, we stop the search.
    if (info->best_cost >= threshold) {
      return true;
    }

    first = (best + 3) % 4;
    last = first + 2;
  }

  return false;
}
481
482
/**
 * \brief Check the points of one search pattern around a motion vector.
 *
 * Every point of the selected pattern is evaluated with check_mv_cost,
 * which updates the best motion vector in info.
 *
 * \param info          search info, updated in place
 * \param pattern_type  0 = diamond, 1 = square, 2 = octagon, 3 = hexagon
 * \param iDist         distance of the pattern points from the center
 * \param mv            center of the pattern in full pixel precision
 * \param best_dist     set to iDist if any of the points improved the best
 *                      motion vector, left untouched otherwise
 */
void kvz_tz_pattern_search(inter_search_info_t *info,
                           unsigned pattern_type,
                           const int iDist,
                           vector2d_t mv,
                           int *best_dist)
{
  assert(pattern_type < 4);

  //implemented search patterns
  const vector2d_t pattern[4][8] = {
      //diamond (8 points)
      //[ ][ ][ ][ ][1][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][8][ ][ ][ ][5][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[4][ ][ ][ ][o][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][7][ ][ ][ ][6][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][3][ ][ ][ ][ ]
      {
        { 0, iDist }, { iDist, 0 }, { 0, -iDist }, { -iDist, 0 },
        { iDist / 2, iDist / 2 }, { iDist / 2, -iDist / 2 },
        { -iDist / 2, -iDist / 2 }, { -iDist / 2, iDist / 2 }
      },

      //square (8 points)
      //[8][ ][ ][ ][1][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[7][ ][ ][ ][o][ ][ ][ ][3]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[6][ ][ ][ ][5][ ][ ][ ][4]
      {
        { 0, iDist }, { iDist, iDist }, { iDist, 0 }, { iDist, -iDist },
        { 0, -iDist }, { -iDist, -iDist }, { -iDist, 0 }, { -iDist, iDist }
      },

      //octagon (8 points)
      //[ ][ ][5][ ][ ][ ][1][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][2]
      //[4][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][o][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[8][ ][ ][ ][ ][ ][ ][ ][6]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][7][ ][ ][ ][3][ ][ ]
      {
        { iDist / 2, iDist }, { iDist, iDist / 2 }, { iDist / 2, -iDist }, { -iDist, iDist / 2 },
        { -iDist / 2, iDist }, { iDist, -iDist / 2 }, { -iDist / 2, -iDist }, { -iDist, -iDist / 2 }
      },

      //hexagon (6 points)
      //[ ][ ][5][ ][ ][ ][1][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[4][ ][ ][ ][o][ ][ ][ ][2]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][ ][ ][ ][ ][ ][ ][ ]
      //[ ][ ][6][ ][ ][ ][3][ ][ ]
      {
        { iDist / 2, iDist }, { iDist, 0 }, { iDist / 2, -iDist }, { -iDist, 0 },
        // Point 5 is the upper-left vertex. It must not repeat point 1,
        // otherwise that vertex of the hexagon is never evaluated.
        { -iDist / 2, iDist }, { -iDist / 2, -iDist }, { 0, 0 }, { 0, 0 }
      }
  };

  // Set the number of points to be checked.
  int n_points;
  if (iDist == 1) {
    // With a distance of one the half-distance offsets are zero, so only
    // the first four points of these patterns are distinct.
    switch (pattern_type) {
      case 0:
      case 2:
      case 3:
        n_points = 4;
        break;
      default:
        n_points = 8;
        break;
    }
  } else {
    switch (pattern_type) {
      case 3:
        n_points = 6;
        break;
      default:
        n_points = 8;
        break;
    }
  }

  // Compute SAD values for all chosen points.
  bool improved = false;
  for (int i = 0; i < n_points; i++) {
    const vector2d_t offset = pattern[pattern_type][i];

    if (check_mv_cost(info, mv.x + offset.x, mv.y + offset.y)) {
      improved = true;
    }
  }

  if (improved) {
    *best_dist = iDist;
  }
}
598
599
/**
 * \brief Check a regular grid of points around the current best vector.
 *
 * The grid is centered on info->best_mv as it is when the function is
 * called and covers [-iSearchRange, iSearchRange] in both directions with
 * a spacing of iRaster. Every point is evaluated with check_mv_cost.
 *
 * \param info          search info, updated in place
 * \param iSearchRange  maximum distance from the center in full pixels
 * \param iRaster       distance between two grid points; nothing is done
 *                      if this is not positive
 */
void kvz_tz_raster_search(inter_search_info_t *info,
                          int iSearchRange,
                          int iRaster)
{
  // A non-positive step would never advance the loops below.
  if (iRaster <= 0) {
    return;
  }

  const vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  //compute SAD values for every point in the iRaster downsampled version of the current search area
  for (int y = iSearchRange; y >= -iSearchRange; y -= iRaster) {
    for (int x = -iSearchRange; x <= iSearchRange; x += iRaster) {
      check_mv_cost(info, mv.x + x, mv.y + y);
    }
  }
}
613
614
/**
 * \brief Run the pattern search with doubling distances around one point.
 *
 * The distance starts from one and is doubled after every round while it
 * is at most max_dist. The loop is left early once three rounds in total
 * have ended with *best_dist different from the distance of the round.
 * The count is not reset by a successful round.
 *
 * \param info          search info, updated in place
 * \param pattern_type  pattern passed to kvz_tz_pattern_search
 * \param max_dist      largest distance to be used
 * \param start         center of the patterns in full pixel precision
 * \param best_dist     distance of the last improvement, updated in place
 */
static void tz_search_grid_stage(inter_search_info_t *info,
                                 unsigned pattern_type,
                                 int max_dist,
                                 vector2d_t start,
                                 int *best_dist)
{
  int rounds_without_improvement = 0;
  for (int dist = 1; dist <= max_dist; dist *= 2) {
    kvz_tz_pattern_search(info, pattern_type, dist, start, best_dist);

    if (*best_dist != dist) rounds_without_improvement++;
    if (rounds_without_improvement >= 3) break;
  }
}


/**
 * \brief Do motion search using the TZ algorithm.
 *
 * The search consists of the selection of a starting point, an optional
 * early termination check, a grid search around the starting point and
 * around the zero vector (step 2), an optional raster scan (step 3) and a
 * refinement around the best vector found (step 4). The parameters of the
 * steps are constants defined at the top of the function.
 *
 * \param info      search info; best_mv, best_cost and best_bitcost are
 *                  updated in place
 * \param extra_mv  extra motion vector to check, in quarter pixel precision
 */
static void tz_search(inter_search_info_t *info, vector2d_t extra_mv)
{
  //TZ parameters
  const int iSearchRange = 96;               // search range for each stage
  const int iRaster = 5;                     // search distance limit and downsampling factor for step 3
  const unsigned step2_type = 0;             // search patterns for steps 2 and 4
  const unsigned step4_type = 0;
  const bool use_raster_scan = false;        // enable step 3
  const bool use_raster_refinement = false;  // enable step 4 mode 1
  const bool use_star_refinement = true;     // enable step 4 mode 2 (only one mode will be executed)

  int best_dist = 0;
  info->best_cost = UINT32_MAX;

  // Select starting point from among merge candidates. These should
  // include both mv_cand vectors and (0, 0).
  select_starting_point(info, extra_mv);

  // Check if we should stop search
  if (info->state->encoder_control->cfg.me_early_termination &&
      early_terminate(info))
  {
    return;
  }

  vector2d_t start = { info->best_mv.x >> 2, info->best_mv.y >> 2 };

  // step 2, grid search
  tz_search_grid_stage(info, step2_type, iSearchRange, start, &best_dist);

  if (start.x != 0 || start.y != 0) {
    // repeat step 2 starting from the zero MV with half of the range
    start.x = 0;
    start.y = 0;
    tz_search_grid_stage(info, step2_type, iSearchRange / 2, start, &best_dist);
  }

  //step 3, raster scan
  if (use_raster_scan && best_dist > iRaster) {
    best_dist = iRaster;
    kvz_tz_raster_search(info, iSearchRange, iRaster);
  }

  //step 4

  //raster refinement
  if (use_raster_refinement && best_dist > 0) {
    for (int dist = best_dist >> 1; dist > 0; dist >>= 1) {
      start.x = info->best_mv.x >> 2;
      start.y = info->best_mv.y >> 2;
      kvz_tz_pattern_search(info, step4_type, dist, start, &best_dist);
    }
  }

  //star refinement (repeat step 2 for the current starting point)
  if (use_star_refinement) {
    // Repeat for as long as the previous pass improved the best vector.
    while (best_dist > 0) {
      best_dist = 0;
      start.x = info->best_mv.x >> 2;
      start.y = info->best_mv.y >> 2;
      for (int dist = 1; dist <= iSearchRange; dist *= 2) {
        kvz_tz_pattern_search(info, step4_type, dist, start, &best_dist);
      }
    }
  }
}
692
693
694 /**
695 * \brief Do motion search using the HEXBS algorithm.
696 *
697 * \param info search info
698 * \param extra_mv extra motion vector to check
699 * \param steps how many steps are done at maximum before exiting, does not affect the final step
700 *
701 * Motion vector is searched by first searching iteratively with the large
702 * hexagon pattern until the best match is at the center of the hexagon.
703 * As a final step a smaller hexagon is used to check the adjacent pixels.
704 *
705 * If a non 0,0 predicted motion vector predictor is given as extra_mv,
706 * the 0,0 vector is also tried. This is hoped to help in the case where
707 * the predicted motion vector is way off. In the future even more additional
708 * points like 0,0 might be used, such as vectors from top or left.
709 */
hexagon_search(inter_search_info_t * info,vector2d_t extra_mv,uint32_t steps)710 static void hexagon_search(inter_search_info_t *info, vector2d_t extra_mv, uint32_t steps)
711 {
712 // The start of the hexagonal pattern has been repeated at the end so that
713 // the indices between 1-6 can be used as the start of a 3-point list of new
714 // points to search.
715 // 6--1,7
716 // / \ =)
717 // 5 0 2,8
718 // \ /
719 // 4---3
720 static const vector2d_t large_hexbs[9] = {
721 { 0, 0 },
722 { 1, -2 }, { 2, 0 }, { 1, 2 }, { -1, 2 }, { -2, 0 }, { -1, -2 },
723 { 1, -2 }, { 2, 0 }
724 };
725 // This is used as the last step of the hexagon search.
726 // 1
727 // 2 0 3
728 // 4
729 static const vector2d_t small_hexbs[9] = {
730 { 0, 0 },
731 { 0, -1 }, { -1, 0 }, { 1, 0 }, { 0, 1 },
732 { -1, -1 }, { 1, -1 }, { -1, 1 }, { 1, 1 }
733 };
734
735 info->best_cost = UINT32_MAX;
736
737 // Select starting point from among merge candidates. These should
738 // include both mv_cand vectors and (0, 0).
739 select_starting_point(info, extra_mv);
740
741 // Check if we should stop search
742 if (info->state->encoder_control->cfg.me_early_termination &&
743 early_terminate(info))
744 {
745 return;
746 }
747
748 vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
749
750 // Current best index, either to merge_cands, large_hexbs or small_hexbs.
751 int best_index = 0;
752
753 // Search the initial 7 points of the hexagon.
754 for (int i = 1; i < 7; ++i) {
755 if (check_mv_cost(info, mv.x + large_hexbs[i].x, mv.y + large_hexbs[i].y)) {
756 best_index = i;
757 }
758 }
759
760 // Iteratively search the 3 new points around the best match, until the best
761 // match is in the center.
762 while (best_index != 0 && steps != 0) {
763 // decrement count if enabled
764 if (steps > 0) steps -= 1;
765
766 // Starting point of the 3 offsets to be searched.
767 unsigned start;
768 if (best_index == 1) {
769 start = 6;
770 } else if (best_index == 8) {
771 start = 1;
772 } else {
773 start = best_index - 1;
774 }
775
776 // Move the center to the best match.
777 mv.x += large_hexbs[best_index].x;
778 mv.y += large_hexbs[best_index].y;
779 best_index = 0;
780
781 // Iterate through the next 3 points.
782 for (int i = 0; i < 3; ++i) {
783 vector2d_t offset = large_hexbs[start + i];
784 if (check_mv_cost(info, mv.x + offset.x, mv.y + offset.y)) {
785 best_index = start + i;
786 }
787 }
788 }
789
790 // Move the center to the best match.
791 //mv.x += large_hexbs[best_index].x;
792 //mv.y += large_hexbs[best_index].y;
793
794 // Do the final step of the search with a small pattern.
795 for (int i = 1; i < 9; ++i) {
796 check_mv_cost(info, mv.x + small_hexbs[i].x, mv.y + small_hexbs[i].y);
797 }
798 }
799
800 /**
801 * \brief Do motion search using the diamond algorithm.
802 *
803 * \param info search info
804 * \param extra_mv extra motion vector to check
805 * \param steps how many steps are done at maximum before exiting
806 *
807 * Motion vector is searched by searching iteratively with a diamond-shaped
808 * pattern. We take care of not checking the direction we came from, but
809 * further checking for avoiding visits to already visited points is not done.
810 *
811 * If a non 0,0 predicted motion vector predictor is given as extra_mv,
812 * the 0,0 vector is also tried. This is hoped to help in the case where
813 * the predicted motion vector is way off. In the future even more additional
814 * points like 0,0 might be used, such as vectors from top or left.
815 **/
diamond_search(inter_search_info_t * info,vector2d_t extra_mv,uint32_t steps)816 static void diamond_search(inter_search_info_t *info, vector2d_t extra_mv, uint32_t steps)
817 {
818 enum diapos {
819 DIA_UP = 0,
820 DIA_RIGHT = 1,
821 DIA_LEFT = 2,
822 DIA_DOWN = 3,
823 DIA_CENTER = 4,
824 };
825
826 // a diamond shape with the center included
827 // 0
828 // 2 4 1
829 // 3
830 static const vector2d_t diamond[5] = {
831 {0, -1}, {1, 0}, {0, 1}, {-1, 0},
832 {0, 0}
833 };
834
835 info->best_cost = UINT32_MAX;
836
837 // Select starting point from among merge candidates. These should
838 // include both mv_cand vectors and (0, 0).
839 select_starting_point(info, extra_mv);
840
841 // Check if we should stop search
842 if (info->state->encoder_control->cfg.me_early_termination &&
843 early_terminate(info))
844 {
845 return;
846 }
847
848 // current motion vector
849 vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
850
851 // current best index
852 enum diapos best_index = DIA_CENTER;
853
854 // initial search of the points of the diamond
855 for (int i = 0; i < 5; ++i) {
856 if (check_mv_cost(info, mv.x + diamond[i].x, mv.y + diamond[i].y)) {
857 best_index = i;
858 }
859 }
860
861 if (best_index == DIA_CENTER) {
862 // the center point was the best in initial check
863 return;
864 }
865
866 // Move the center to the best match.
867 mv.x += diamond[best_index].x;
868 mv.y += diamond[best_index].y;
869
870 // the arrival direction, the index of the diamond member that will be excluded
871 enum diapos from_dir = DIA_CENTER;
872
873 // whether we found a better candidate this iteration
874 uint8_t better_found;
875
876 do {
877 better_found = 0;
878 // decrement count if enabled
879 if (steps > 0) steps -= 1;
880
881 // search the points of the diamond
882 for (int i = 0; i < 4; ++i) {
883 // this is where we came from so it's checked already
884 if (i == from_dir) continue;
885
886 if (check_mv_cost(info, mv.x + diamond[i].x, mv.y + diamond[i].y)) {
887 best_index = i;
888 better_found = 1;
889 }
890 }
891
892 if (better_found) {
893 // Move the center to the best match.
894 mv.x += diamond[best_index].x;
895 mv.y += diamond[best_index].y;
896
897 // record where we came from to the next iteration
898 // the xor operation flips the orientation
899 from_dir = best_index ^ 0x3;
900 }
901 } while (better_found && steps != 0);
902 // and we're done
903 }
904
905
/**
 * \brief Check every point of a square window with check_mv_cost.
 *
 * The rows are gone through from top to bottom and the points of a row
 * from left to right.
 *
 * \param info          search info, updated in place
 * \param center_x      horizontal center of the window in full pixels
 * \param center_y      vertical center of the window in full pixels
 * \param search_range  the window reaches this far from the center in
 *                      every direction
 */
static void full_search_window(inter_search_info_t *info,
                               int center_x,
                               int center_y,
                               int32_t search_range)
{
  for (int dy = -search_range; dy <= search_range; dy++) {
    for (int dx = -search_range; dx <= search_range; dx++) {
      check_mv_cost(info, center_x + dx, center_y + dy);
    }
  }
}


/**
 * \brief Do motion search by checking all points of the search windows.
 *
 * A window of (2 * search_range + 1)^2 points is searched around the zero
 * vector, around extra_mv if it is not one of the merge candidates, and
 * around every merge candidate. Points of a merge candidate window that are
 * covered by the window of the zero vector or of an earlier merge candidate
 * are skipped.
 *
 * \param info          search info, updated in place
 * \param search_range  size of the windows in full pixels
 * \param extra_mv      extra motion vector in quarter pixel precision
 */
static void search_mv_full(inter_search_info_t *info,
                           int32_t search_range,
                           vector2d_t extra_mv)
{
  // Search around the 0-vector.
  full_search_window(info, 0, 0, search_range);

  // Change to integer precision.
  extra_mv.x >>= 2;
  extra_mv.y >>= 2;

  // Check around extra_mv if it's not one of the merge candidates.
  if (!mv_in_merge(info, extra_mv)) {
    full_search_window(info, extra_mv.x, extra_mv.y, search_range);
  }

  // Search around the merge candidates. These should include both mv_cand
  // vectors and (0, 0).
  for (int i = 0; i < info->num_merge_cand; ++i) {
    if (info->merge_cand[i].dir == 3) continue;

    const int list = info->merge_cand[i].dir - 1;
    const vector2d_t center = {
      .x = info->merge_cand[i].mv[list][0] >> 2,
      .y = info->merge_cand[i].mv[list][1] >> 2,
    };

    // Ignore 0-vector because it has already been checked.
    if (center.x == 0 && center.y == 0) continue;

    const vector2d_t first = { center.x - search_range, center.y - search_range };
    const vector2d_t last  = { center.x + search_range, center.y + search_range };

    for (int y = first.y; y <= last.y; ++y) {
      for (int x = first.x; x <= last.x; ++x) {
        if (!intmv_within_tile(info, x, y)) {
          continue;
        }

        // Avoid calculating the same points over and over again. Index -1
        // stands for the window around the 0-vector, the other indices for
        // the windows of the earlier merge candidates.
        bool covered = false;
        for (int j = -1; j < i; ++j) {
          int prev_x = 0;
          int prev_y = 0;
          if (j >= 0) {
            if (info->merge_cand[j].dir == 3) continue;
            const int prev_list = info->merge_cand[j].dir - 1;
            prev_x = info->merge_cand[j].mv[prev_list][0] >> 2;
            prev_y = info->merge_cand[j].mv[prev_list][1] >> 2;
          }

          const bool inside_x =
            x >= prev_x - search_range && x <= prev_x + search_range;
          const bool inside_y =
            y >= prev_y - search_range && y <= prev_y + search_range;

          if (inside_x && inside_y) {
            covered = true;
            // Jump to the right edge of the earlier window. The increment
            // of the loop over x then moves to the first point after it.
            x = prev_x + search_range;
            break;
          }
        }
        if (covered) continue;

        check_mv_cost(info, x, y);
      }
    }
  }
}
977
978
979 /**
980 * \brief Do fractional motion estimation
981 *
 * Algorithm first searches 1/2-pel positions around integer mv and after best match is found,
 * refines the search by searching best 1/4-pel position around best 1/2-pel position.
984 */
search_frac(inter_search_info_t * info)985 static void search_frac(inter_search_info_t *info)
986 {
987 // Map indexes to relative coordinates in the following way:
988 // 5 3 6
989 // 1 0 2
990 // 7 4 8
991 static const vector2d_t square[9] = {
992 { 0, 0 }, { -1, 0 }, { 1, 0 },
993 { 0, -1 }, { 0, 1 }, { -1, -1 },
994 { 1, -1 }, { -1, 1 }, { 1, 1 }
995 };
996
997 // Set mv to pixel precision
998 vector2d_t mv = { info->best_mv.x >> 2, info->best_mv.y >> 2 };
999
1000 unsigned best_cost = UINT32_MAX;
1001 uint32_t best_bitcost = 0;
1002 uint32_t bitcosts[4] = { 0 };
1003 unsigned best_index = 0;
1004
1005 unsigned costs[4] = { 0 };
1006
1007 ALIGNED(64) kvz_pixel filtered[4][LCU_LUMA_SIZE];
1008
1009 // Storage buffers for intermediate horizontally filtered results.
1010 // Have the first columns in contiguous memory for vectorization.
1011 ALIGNED(64) int16_t intermediate[5][KVZ_IPOL_MAX_IM_SIZE_LUMA_SIMD];
1012 int16_t hor_first_cols[5][KVZ_EXT_BLOCK_W_LUMA + 1];
1013
1014 const kvz_picture *ref = info->ref;
1015 const kvz_picture *pic = info->pic;
1016 vector2d_t orig = info->origin;
1017 const int width = info->width;
1018 const int height = info->height;
1019 const int internal_width = ((width + 7) >> 3) << 3; // Round up to closest 8
1020 const int internal_height = ((height + 7) >> 3) << 3;
1021
1022 const encoder_state_t *state = info->state;
1023 int fme_level = state->encoder_control->cfg.fme_level;
1024 int8_t sample_off_x = 0;
1025 int8_t sample_off_y = 0;
1026
1027 // Space for (possibly) extrapolated pixels and the part from the picture
1028 // One extra row and column compared to normal interpolation and some extra for AVX2.
1029 // The extrapolation function will set the pointers and stride.
1030 kvz_pixel ext_buffer[KVZ_FME_MAX_INPUT_SIZE_SIMD];
1031 kvz_pixel *ext = NULL;
1032 kvz_pixel *ext_origin = NULL;
1033 int ext_s = 0;
1034 kvz_epol_args epol_args = {
1035 .src = ref->y,
1036 .src_w = ref->width,
1037 .src_h = ref->height,
1038 .src_s = ref->stride,
1039 .blk_x = state->tile->offset_x + orig.x + mv.x - 1,
1040 .blk_y = state->tile->offset_y + orig.y + mv.y - 1,
1041 .blk_w = internal_width + 1, // TODO: real width
1042 .blk_h = internal_height + 1, // TODO: real height
1043 .pad_l = KVZ_LUMA_FILTER_OFFSET,
1044 .pad_r = KVZ_EXT_PADDING_LUMA - KVZ_LUMA_FILTER_OFFSET,
1045 .pad_t = KVZ_LUMA_FILTER_OFFSET,
1046 .pad_b = KVZ_EXT_PADDING_LUMA - KVZ_LUMA_FILTER_OFFSET,
1047 .pad_b_simd = 0 // AVX2 padding unnecessary because of blk_h
1048 };
1049
1050 // Initialize separately. Gets rid of warning
1051 // about using nonstandard extension.
1052 epol_args.buf = ext_buffer;
1053 epol_args.ext = &ext;
1054 epol_args.ext_origin = &ext_origin;
1055 epol_args.ext_s = &ext_s;
1056
1057 kvz_get_extended_block(&epol_args);
1058
1059 kvz_pixel *tmp_pic = pic->y + orig.y * pic->stride + orig.x;
1060 int tmp_stride = pic->stride;
1061
1062 // Search integer position
1063 costs[0] = kvz_satd_any_size(width, height,
1064 tmp_pic, tmp_stride,
1065 ext_origin + ext_s + 1, ext_s);
1066
1067 costs[0] += info->mvd_cost_func(state,
1068 mv.x, mv.y, 2,
1069 info->mv_cand,
1070 info->merge_cand,
1071 info->num_merge_cand,
1072 info->ref_idx,
1073 &bitcosts[0]);
1074 best_cost = costs[0];
1075 best_bitcost = bitcosts[0];
1076
1077 //Set mv to half-pixel precision
1078 mv.x *= 2;
1079 mv.y *= 2;
1080
1081 ipol_blocks_func * filter_steps[4] = {
1082 kvz_filter_hpel_blocks_hor_ver_luma,
1083 kvz_filter_hpel_blocks_diag_luma,
1084 kvz_filter_qpel_blocks_hor_ver_luma,
1085 kvz_filter_qpel_blocks_diag_luma,
1086 };
1087
1088 // Search halfpel positions around best integer mv
1089 int i = 1;
1090 for (int step = 0; step < fme_level; ++step){
1091
1092 const int mv_shift = (step < 2) ? 1 : 0;
1093
1094 filter_steps[step](state->encoder_control,
1095 ext_origin,
1096 ext_s,
1097 internal_width,
1098 internal_height,
1099 filtered,
1100 intermediate,
1101 fme_level,
1102 hor_first_cols,
1103 sample_off_x,
1104 sample_off_y);
1105
1106 const vector2d_t *pattern[4] = { &square[i], &square[i + 1], &square[i + 2], &square[i + 3] };
1107
1108 int8_t within_tile[4];
1109 for (int j = 0; j < 4; j++) {
1110 within_tile[j] =
1111 fracmv_within_tile(info, (mv.x + pattern[j]->x) * (1 << mv_shift), (mv.y + pattern[j]->y) * (1 << mv_shift));
1112 };
1113
1114 kvz_pixel *filtered_pos[4] = { 0 };
1115 filtered_pos[0] = &filtered[0][0];
1116 filtered_pos[1] = &filtered[1][0];
1117 filtered_pos[2] = &filtered[2][0];
1118 filtered_pos[3] = &filtered[3][0];
1119
1120 kvz_satd_any_size_quad(width, height, (const kvz_pixel **)filtered_pos, LCU_WIDTH, tmp_pic, tmp_stride, 4, costs, within_tile);
1121
1122 for (int j = 0; j < 4; j++) {
1123 if (within_tile[j]) {
1124 costs[j] += info->mvd_cost_func(
1125 state,
1126 mv.x + pattern[j]->x,
1127 mv.y + pattern[j]->y,
1128 mv_shift,
1129 info->mv_cand,
1130 info->merge_cand,
1131 info->num_merge_cand,
1132 info->ref_idx,
1133 &bitcosts[j]
1134 );
1135 }
1136 }
1137
1138 for (int j = 0; j < 4; ++j) {
1139 if (within_tile[j] && costs[j] < best_cost) {
1140 best_cost = costs[j];
1141 best_bitcost = bitcosts[j];
1142 best_index = i + j;
1143 }
1144 }
1145
1146 i += 4;
1147
1148 // Update mv for the best position on current precision
1149 if (step == 1 || step == fme_level - 1) {
1150 // Move search to best_index
1151 mv.x += square[best_index].x;
1152 mv.y += square[best_index].y;
1153
1154 // On last hpel step...
1155 if (step == MIN(fme_level - 1, 1)) {
1156 //Set mv to quarterpel precision
1157 mv.x *= 2;
1158 mv.y *= 2;
1159 sample_off_x = square[best_index].x;
1160 sample_off_y = square[best_index].y;
1161 best_index = 0;
1162 i = 1;
1163 }
1164 }
1165 }
1166
1167 info->best_mv = mv;
1168 info->best_cost = best_cost;
1169 info->best_bitcost = best_bitcost;
1170 }
1171
1172 /**
1173 * \brief Calculate the scaled MV
1174 */
get_scaled_mv(int16_t mv,int scale)1175 static INLINE int16_t get_scaled_mv(int16_t mv, int scale)
1176 {
1177 int32_t scaled = scale * mv;
1178 return CLIP(-32768, 32767, (scaled + 127 + (scaled < 0)) >> 8);
1179 }
1180 /**
1181 * \brief Scale the MV according to the POC difference
1182 *
1183 * \param current_poc POC of current frame
1184 * \param current_ref_poc POC of reference frame
1185 * \param neighbor_poc POC of neighbor frame
1186 * \param neighbor_ref_poc POC of neighbors reference frame
1187 * \param mv_cand MV candidates to scale
1188 */
apply_mv_scaling(int32_t current_poc,int32_t current_ref_poc,int32_t neighbor_poc,int32_t neighbor_ref_poc,vector2d_t * mv_cand)1189 static void apply_mv_scaling(int32_t current_poc,
1190 int32_t current_ref_poc,
1191 int32_t neighbor_poc,
1192 int32_t neighbor_ref_poc,
1193 vector2d_t* mv_cand)
1194 {
1195 int32_t diff_current = current_poc - current_ref_poc;
1196 int32_t diff_neighbor = neighbor_poc - neighbor_ref_poc;
1197
1198 if (diff_current == diff_neighbor) return;
1199 if (diff_neighbor == 0) return;
1200
1201 diff_current = CLIP(-128, 127, diff_current);
1202 diff_neighbor = CLIP(-128, 127, diff_neighbor);
1203
1204 int scale = CLIP(-4096, 4095,
1205 (diff_current * ((0x4000 + (abs(diff_neighbor) >> 1)) / diff_neighbor) + 32) >> 6);
1206
1207 mv_cand->x = get_scaled_mv(mv_cand->x, scale);
1208 mv_cand->y = get_scaled_mv(mv_cand->y, scale);
1209 }
1210
1211
1212 /**
1213 * \brief Perform inter search for a single reference frame.
1214 */
search_pu_inter_ref(inter_search_info_t * info,int depth,lcu_t * lcu,cu_info_t * cur_cu,double * inter_cost,uint32_t * inter_bitcost,double * best_LX_cost,cu_info_t * unipred_LX)1215 static void search_pu_inter_ref(inter_search_info_t *info,
1216 int depth,
1217 lcu_t *lcu, cu_info_t *cur_cu,
1218 double *inter_cost,
1219 uint32_t *inter_bitcost,
1220 double *best_LX_cost,
1221 cu_info_t *unipred_LX)
1222 {
1223 const kvz_config *cfg = &info->state->encoder_control->cfg;
1224
1225 // which list, L0 or L1, ref_idx is in and in what index
1226 int8_t ref_list = -1;
1227 // the index of the ref_idx in L0 or L1 list
1228 int8_t LX_idx;
1229 // max value of LX_idx plus one
1230 const int8_t LX_IDX_MAX_PLUS_1 = MAX(info->state->frame->ref_LX_size[0],
1231 info->state->frame->ref_LX_size[1]);
1232
1233 for (LX_idx = 0; LX_idx < LX_IDX_MAX_PLUS_1; LX_idx++)
1234 {
1235 // check if ref_idx is in L0
1236 if (LX_idx < info->state->frame->ref_LX_size[0] &&
1237 info->state->frame->ref_LX[0][LX_idx] == info->ref_idx) {
1238 ref_list = 0;
1239 break;
1240 }
1241
1242 // check if ref_idx is in L1
1243 if (LX_idx < info->state->frame->ref_LX_size[1] &&
1244 info->state->frame->ref_LX[1][LX_idx] == info->ref_idx) {
1245 ref_list = 1;
1246 break;
1247 }
1248 }
1249 // ref_idx has to be found in either L0 or L1
1250 assert(LX_idx < LX_IDX_MAX_PLUS_1);
1251
1252 // store temp values to be stored back later
1253 int8_t temp_ref_idx = cur_cu->inter.mv_ref[ref_list];
1254
1255 // Get MV candidates
1256 cur_cu->inter.mv_ref[ref_list] = LX_idx;
1257
1258 kvz_inter_get_mv_cand(info->state,
1259 info->origin.x,
1260 info->origin.y,
1261 info->width,
1262 info->height,
1263 info->mv_cand,
1264 cur_cu,
1265 lcu,
1266 ref_list);
1267
1268 // store old values back
1269 cur_cu->inter.mv_ref[ref_list] = temp_ref_idx;
1270
1271 vector2d_t mv = { 0, 0 };
1272
1273 // Take starting point for MV search from previous frame.
1274 // When temporal motion vector candidates are added, there is probably
1275 // no point to this anymore, but for now it helps.
1276 const int mid_x = info->state->tile->offset_x + info->origin.x + (info->width >> 1);
1277 const int mid_y = info->state->tile->offset_y + info->origin.y + (info->height >> 1);
1278 const cu_array_t* ref_array = info->state->frame->ref->cu_arrays[info->ref_idx];
1279 const cu_info_t* ref_cu = kvz_cu_array_at_const(ref_array, mid_x, mid_y);
1280 if (ref_cu->type == CU_INTER) {
1281 vector2d_t mv_previous = { 0, 0 };
1282 if (ref_cu->inter.mv_dir & 1) {
1283 mv_previous.x = ref_cu->inter.mv[0][0];
1284 mv_previous.y = ref_cu->inter.mv[0][1];
1285 }
1286 else {
1287 mv_previous.x = ref_cu->inter.mv[1][0];
1288 mv_previous.y = ref_cu->inter.mv[1][1];
1289 }
1290 // Apply mv scaling if neighbor poc is available
1291 if (info->state->frame->ref_LX_size[ref_list] > 0) {
1292 // When there are reference pictures from the future (POC > current POC)
1293 // in L0 or L1, the primary list for the colocated PU is the inverse of
1294 // collocated_from_l0_flag. Otherwise it is equal to reflist.
1295 //
1296 // Kvazaar always sets collocated_from_l0_flag so the list is L1 when
1297 // there are future references.
1298 int col_list = ref_list;
1299 for (int i = 0; i < info->state->frame->ref->used_size; i++) {
1300 if (info->state->frame->ref->pocs[i] > info->state->frame->poc) {
1301 col_list = 1;
1302 break;
1303 }
1304 }
1305 if ((ref_cu->inter.mv_dir & (col_list + 1)) == 0) {
1306 // Use the other list if the colocated PU does not have a MV for the
1307 // primary list.
1308 col_list = 1 - col_list;
1309 }
1310
1311 uint8_t neighbor_poc_index = info->state->frame->ref_LX[ref_list][LX_idx];
1312 // Scaling takes current POC, reference POC, neighbor POC and neighbor reference POC as argument
1313 apply_mv_scaling(
1314 info->state->frame->poc,
1315 info->state->frame->ref->pocs[info->state->frame->ref_LX[ref_list][LX_idx]],
1316 info->state->frame->ref->pocs[neighbor_poc_index],
1317 info->state->frame->ref->images[neighbor_poc_index]->ref_pocs[
1318 info->state->frame->ref->ref_LXs[neighbor_poc_index]
1319 [col_list]
1320 [ref_cu->inter.mv_ref[col_list]]
1321 ],
1322 &mv_previous
1323 );
1324 }
1325
1326 // Check if the mv is valid after scaling
1327 if (fracmv_within_tile(info, mv_previous.x, mv_previous.y)) {
1328 mv = mv_previous;
1329 }
1330 }
1331
1332 int search_range = 32;
1333 switch (cfg->ime_algorithm) {
1334 case KVZ_IME_FULL64: search_range = 64; break;
1335 case KVZ_IME_FULL32: search_range = 32; break;
1336 case KVZ_IME_FULL16: search_range = 16; break;
1337 case KVZ_IME_FULL8: search_range = 8; break;
1338 default: break;
1339 }
1340
1341 info->best_cost = UINT32_MAX;
1342
1343 switch (cfg->ime_algorithm) {
1344 case KVZ_IME_TZ:
1345 tz_search(info, mv);
1346 break;
1347
1348 case KVZ_IME_FULL64:
1349 case KVZ_IME_FULL32:
1350 case KVZ_IME_FULL16:
1351 case KVZ_IME_FULL8:
1352 case KVZ_IME_FULL:
1353 search_mv_full(info, search_range, mv);
1354 break;
1355
1356 case KVZ_IME_DIA:
1357 diamond_search(info, mv, info->state->encoder_control->cfg.me_max_steps);
1358 break;
1359
1360 default:
1361 hexagon_search(info, mv, info->state->encoder_control->cfg.me_max_steps);
1362 break;
1363 }
1364
1365 if (cfg->fme_level > 0 && info->best_cost < *inter_cost) {
1366 search_frac(info);
1367
1368 } else if (info->best_cost < UINT32_MAX) {
1369 // Recalculate inter cost with SATD.
1370 info->best_cost = kvz_image_calc_satd(
1371 info->state->tile->frame->source,
1372 info->ref,
1373 info->origin.x,
1374 info->origin.y,
1375 info->state->tile->offset_x + info->origin.x + (info->best_mv.x >> 2),
1376 info->state->tile->offset_y + info->origin.y + (info->best_mv.y >> 2),
1377 info->width,
1378 info->height);
1379 info->best_cost += info->best_bitcost * (int)(info->state->lambda_sqrt + 0.5);
1380 }
1381
1382 mv = info->best_mv;
1383
1384 int merged = 0;
1385 int merge_idx = 0;
1386 // Check every candidate to find a match
1387 for (merge_idx = 0; merge_idx < info->num_merge_cand; merge_idx++) {
1388 if (info->merge_cand[merge_idx].dir != 3 &&
1389 info->merge_cand[merge_idx].mv[info->merge_cand[merge_idx].dir - 1][0] == mv.x &&
1390 info->merge_cand[merge_idx].mv[info->merge_cand[merge_idx].dir - 1][1] == mv.y &&
1391 (uint32_t)info->state->frame->ref_LX[info->merge_cand[merge_idx].dir - 1][
1392 info->merge_cand[merge_idx].ref[info->merge_cand[merge_idx].dir - 1]] == info->ref_idx)
1393 {
1394 merged = 1;
1395 break;
1396 }
1397 }
1398
1399 // Only check when candidates are different
1400 int cu_mv_cand = 0;
1401 if (!merged) {
1402 cu_mv_cand =
1403 select_mv_cand(info->state, info->mv_cand, mv.x, mv.y, NULL);
1404 }
1405
1406 if (info->best_cost < *inter_cost) {
1407 // Map reference index to L0/L1 pictures
1408 cur_cu->inter.mv_dir = ref_list+1;
1409 uint8_t mv_ref_coded = LX_idx;
1410
1411 cur_cu->merged = merged;
1412 cur_cu->merge_idx = merge_idx;
1413 cur_cu->inter.mv_ref[ref_list] = LX_idx;
1414 cur_cu->inter.mv[ref_list][0] = (int16_t)mv.x;
1415 cur_cu->inter.mv[ref_list][1] = (int16_t)mv.y;
1416
1417 CU_SET_MV_CAND(cur_cu, ref_list, cu_mv_cand);
1418
1419 *inter_cost = info->best_cost;
1420 *inter_bitcost = info->best_bitcost + cur_cu->inter.mv_dir - 1 + mv_ref_coded;
1421 }
1422
1423
1424 // Update best unipreds for biprediction
1425 if (info->best_cost < best_LX_cost[ref_list]) {
1426 bool valid_mv = fracmv_within_tile(info, mv.x, mv.y);
1427 if (valid_mv) {
1428 // Map reference index to L0/L1 pictures
1429 unipred_LX[ref_list].inter.mv_dir = ref_list + 1;
1430 unipred_LX[ref_list].inter.mv_ref[ref_list] = LX_idx;
1431 unipred_LX[ref_list].inter.mv[ref_list][0] = (int16_t)mv.x;
1432 unipred_LX[ref_list].inter.mv[ref_list][1] = (int16_t)mv.y;
1433
1434 CU_SET_MV_CAND(&unipred_LX[ref_list], ref_list, cu_mv_cand);
1435
1436 best_LX_cost[ref_list] = info->best_cost;
1437 }
1438 }
1439 }
1440
1441
1442 /**
1443 * \brief Search bipred modes for a PU.
1444 */
search_pu_inter_bipred(inter_search_info_t * info,int depth,lcu_t * lcu,cu_info_t * cur_cu,double * inter_cost,uint32_t * inter_bitcost)1445 static void search_pu_inter_bipred(inter_search_info_t *info,
1446 int depth,
1447 lcu_t *lcu, cu_info_t *cur_cu,
1448 double *inter_cost,
1449 uint32_t *inter_bitcost)
1450 {
1451 const image_list_t *const ref = info->state->frame->ref;
1452 uint8_t (*ref_LX)[16] = info->state->frame->ref_LX;
1453 const videoframe_t * const frame = info->state->tile->frame;
1454 const int x = info->origin.x;
1455 const int y = info->origin.y;
1456 const int width = info->width;
1457 const int height = info->height;
1458
1459 static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 };
1460 static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 };
1461 const unsigned num_cand_pairs =
1462 MIN(info->num_merge_cand * (info->num_merge_cand - 1), 12);
1463
1464 inter_merge_cand_t *merge_cand = info->merge_cand;
1465
1466 for (int32_t idx = 0; idx < num_cand_pairs; idx++) {
1467 uint8_t i = priorityList0[idx];
1468 uint8_t j = priorityList1[idx];
1469 if (i >= info->num_merge_cand || j >= info->num_merge_cand) break;
1470
1471 // Find one L0 and L1 candidate according to the priority list
1472 if (!(merge_cand[i].dir & 0x1) || !(merge_cand[j].dir & 0x2)) continue;
1473
1474 if (ref_LX[0][merge_cand[i].ref[0]] == ref_LX[1][merge_cand[j].ref[1]] &&
1475 merge_cand[i].mv[0][0] == merge_cand[j].mv[1][0] &&
1476 merge_cand[i].mv[0][1] == merge_cand[j].mv[1][1])
1477 {
1478 continue;
1479 }
1480
1481 int16_t mv[2][2];
1482 mv[0][0] = merge_cand[i].mv[0][0];
1483 mv[0][1] = merge_cand[i].mv[0][1];
1484 mv[1][0] = merge_cand[j].mv[1][0];
1485 mv[1][1] = merge_cand[j].mv[1][1];
1486
1487 // Don't try merge candidates that don't satisfy mv constraints.
1488 if (!fracmv_within_tile(info, mv[0][0], mv[0][1]) ||
1489 !fracmv_within_tile(info, mv[1][0], mv[1][1]))
1490 {
1491 continue;
1492 }
1493
1494 kvz_inter_recon_bipred(info->state,
1495 ref->images[ref_LX[0][merge_cand[i].ref[0]]],
1496 ref->images[ref_LX[1][merge_cand[j].ref[1]]],
1497 x, y,
1498 width,
1499 height,
1500 mv,
1501 lcu,
1502 true,
1503 false);
1504
1505 const kvz_pixel *rec = &lcu->rec.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1506 const kvz_pixel *src = &frame->source->y[x + y * frame->source->width];
1507 uint32_t cost =
1508 kvz_satd_any_size(width, height, rec, LCU_WIDTH, src, frame->source->width);
1509
1510 uint32_t bitcost[2] = { 0, 0 };
1511
1512 cost += info->mvd_cost_func(info->state,
1513 merge_cand[i].mv[0][0],
1514 merge_cand[i].mv[0][1],
1515 0,
1516 info->mv_cand,
1517 NULL, 0, 0,
1518 &bitcost[0]);
1519 cost += info->mvd_cost_func(info->state,
1520 merge_cand[i].mv[1][0],
1521 merge_cand[i].mv[1][1],
1522 0,
1523 info->mv_cand,
1524 NULL, 0, 0,
1525 &bitcost[1]);
1526
1527 const uint8_t mv_ref_coded[2] = {
1528 merge_cand[i].ref[0],
1529 merge_cand[j].ref[1]
1530 };
1531 const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */;
1532 cost += info->state->lambda_sqrt * extra_bits + 0.5;
1533
1534 if (cost < *inter_cost) {
1535 cur_cu->inter.mv_dir = 3;
1536
1537 cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0];
1538 cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1];
1539
1540 cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0];
1541 cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1];
1542 cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0];
1543 cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1];
1544 cur_cu->merged = 0;
1545
1546 // Check every candidate to find a match
1547 for (int merge_idx = 0; merge_idx < info->num_merge_cand; merge_idx++) {
1548 if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] &&
1549 merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] &&
1550 merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] &&
1551 merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] &&
1552 merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] &&
1553 merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1])
1554 {
1555 cur_cu->merged = 1;
1556 cur_cu->merge_idx = merge_idx;
1557 break;
1558 }
1559 }
1560
1561 // Each motion vector has its own candidate
1562 for (int reflist = 0; reflist < 2; reflist++) {
1563 kvz_inter_get_mv_cand(info->state, x, y, width, height, info->mv_cand, cur_cu, lcu, reflist);
1564 int cu_mv_cand = select_mv_cand(
1565 info->state,
1566 info->mv_cand,
1567 cur_cu->inter.mv[reflist][0],
1568 cur_cu->inter.mv[reflist][1],
1569 NULL);
1570 CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand);
1571 }
1572
1573 *inter_cost = cost;
1574 *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits;
1575 }
1576 }
1577 }
1578
1579 /**
1580 * \brief Check if an identical merge candidate exists in a list
1581 *
1582 * \param all_cand Full list of available merge candidates
1583 * \param cand_to_add Merge candidate to be checked for duplicates
1584 * \param added_idx_list List of indices of unique merge candidates
1585 * \param list_size Size of the list
1586 *
1587 * \return Does an identical candidate exist in list
1588 */
merge_candidate_in_list(inter_merge_cand_t * all_cands,inter_merge_cand_t * cand_to_add,int8_t * added_idx_list,int list_size)1589 static bool merge_candidate_in_list(inter_merge_cand_t * all_cands,
1590 inter_merge_cand_t * cand_to_add,
1591 int8_t * added_idx_list,
1592 int list_size)
1593 {
1594 bool found = false;
1595 for (int i = 0; i < list_size && !found; ++i) {
1596 inter_merge_cand_t * list_cand = &all_cands[added_idx_list[i]];
1597
1598 found = cand_to_add->dir == list_cand->dir &&
1599 cand_to_add->ref[0] == list_cand->ref[0] &&
1600 cand_to_add->mv[0][0] == list_cand->mv[0][0] &&
1601 cand_to_add->mv[0][1] == list_cand->mv[0][1] &&
1602 cand_to_add->ref[1] == list_cand->ref[1] &&
1603 cand_to_add->mv[1][0] == list_cand->mv[1][0] &&
1604 cand_to_add->mv[1][1] == list_cand->mv[1][1];
1605 }
1606
1607 return found;
1608 }
1609
1610 /**
1611 * \brief Update PU to have best modes at this depth.
1612 *
1613 * \param state encoder state
1614 * \param x_cu x-coordinate of the containing CU
1615 * \param y_cu y-coordinate of the containing CU
1616 * \param depth depth of the CU in the quadtree
1617 * \param part_mode partition mode of the CU
1618 * \param i_pu index of the PU in the CU
1619 * \param lcu containing LCU
1620 *
1621 * \param inter_cost Return inter cost of the best mode
1622 * \param inter_bitcost Return inter bitcost of the best mode
1623 */
search_pu_inter(encoder_state_t * const state,int x_cu,int y_cu,int depth,part_mode_t part_mode,int i_pu,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1624 static void search_pu_inter(encoder_state_t * const state,
1625 int x_cu, int y_cu,
1626 int depth,
1627 part_mode_t part_mode,
1628 int i_pu,
1629 lcu_t *lcu,
1630 double *inter_cost,
1631 uint32_t *inter_bitcost)
1632 {
1633 *inter_cost = MAX_INT;
1634 *inter_bitcost = MAX_INT;
1635
1636 const kvz_config *cfg = &state->encoder_control->cfg;
1637 const videoframe_t * const frame = state->tile->frame;
1638 const int width_cu = LCU_WIDTH >> depth;
1639 const int x = PU_GET_X(part_mode, width_cu, x_cu, i_pu);
1640 const int y = PU_GET_Y(part_mode, width_cu, y_cu, i_pu);
1641 const int width = PU_GET_W(part_mode, width_cu, i_pu);
1642 const int height = PU_GET_H(part_mode, width_cu, i_pu);
1643
1644 // Merge candidate A1 may not be used for the second PU of Nx2N, nLx2N and
1645 // nRx2N partitions.
1646 const bool merge_a1 = i_pu == 0 || width >= height;
1647 // Merge candidate B1 may not be used for the second PU of 2NxN, 2NxnU and
1648 // 2NxnD partitions.
1649 const bool merge_b1 = i_pu == 0 || width <= height;
1650
1651 const int x_local = SUB_SCU(x);
1652 const int y_local = SUB_SCU(y);
1653 cu_info_t *cur_cu = LCU_GET_CU_AT_PX(lcu, x_local, y_local);
1654
1655 inter_search_info_t info = {
1656 .state = state,
1657 .pic = frame->source,
1658 .origin = { x, y },
1659 .width = width,
1660 .height = height,
1661 .mvd_cost_func = cfg->mv_rdo ? kvz_calc_mvd_cost_cabac : calc_mvd_cost,
1662 .optimized_sad = kvz_get_optimized_sad(width),
1663 };
1664
1665 // Search for merge mode candidates
1666 info.num_merge_cand = kvz_inter_get_merge_cand(
1667 state,
1668 x, y,
1669 width, height,
1670 merge_a1, merge_b1,
1671 info.merge_cand,
1672 lcu
1673 );
1674
1675 // Default to candidate 0
1676 CU_SET_MV_CAND(cur_cu, 0, 0);
1677 CU_SET_MV_CAND(cur_cu, 1, 0);
1678
1679 // Merge Analysis starts here
1680 int8_t mrg_cands[MRG_MAX_NUM_CANDS];
1681 double mrg_costs[MRG_MAX_NUM_CANDS];
1682 for (int i = 0; i < MRG_MAX_NUM_CANDS; ++i) {
1683 mrg_cands[i] = -1;
1684 mrg_costs[i] = MAX_DOUBLE;
1685 }
1686
1687 int num_rdo_cands = 0;
1688
1689 // Check motion vector constraints and perform rough search
1690 for (int merge_idx = 0; merge_idx < info.num_merge_cand; ++merge_idx) {
1691
1692 inter_merge_cand_t *cur_cand = &info.merge_cand[merge_idx];
1693 cur_cu->inter.mv_dir = cur_cand->dir;
1694 cur_cu->inter.mv_ref[0] = cur_cand->ref[0];
1695 cur_cu->inter.mv_ref[1] = cur_cand->ref[1];
1696 cur_cu->inter.mv[0][0] = cur_cand->mv[0][0];
1697 cur_cu->inter.mv[0][1] = cur_cand->mv[0][1];
1698 cur_cu->inter.mv[1][0] = cur_cand->mv[1][0];
1699 cur_cu->inter.mv[1][1] = cur_cand->mv[1][1];
1700
1701 // If bipred is not enabled, do not try candidates with mv_dir == 3.
1702 // Bipred is also forbidden for 4x8 and 8x4 blocks by the standard.
1703 if (cur_cu->inter.mv_dir == 3 && !state->encoder_control->cfg.bipred) continue;
1704 if (cur_cu->inter.mv_dir == 3 && !(width + height > 12)) continue;
1705
1706 bool is_duplicate = merge_candidate_in_list(info.merge_cand, cur_cand,
1707 mrg_cands,
1708 num_rdo_cands);
1709
1710 // Don't try merge candidates that don't satisfy mv constraints.
1711 // Don't add duplicates to list
1712 if (!fracmv_within_tile(&info, cur_cu->inter.mv[0][0], cur_cu->inter.mv[0][1]) ||
1713 !fracmv_within_tile(&info, cur_cu->inter.mv[1][0], cur_cu->inter.mv[1][1]) ||
1714 is_duplicate)
1715 {
1716 continue;
1717 }
1718
1719 kvz_inter_pred_pu(state, lcu, x_cu, y_cu, width_cu, true, false, i_pu);
1720 mrg_costs[num_rdo_cands] = kvz_satd_any_size(width, height,
1721 lcu->rec.y + y_local * LCU_WIDTH + x_local, LCU_WIDTH,
1722 lcu->ref.y + y_local * LCU_WIDTH + x_local, LCU_WIDTH);
1723
1724 // Add cost of coding the merge index
1725 mrg_costs[num_rdo_cands] += merge_idx * info.state->lambda_sqrt;
1726
1727 mrg_cands[num_rdo_cands] = merge_idx;
1728 num_rdo_cands++;
1729 }
1730
1731 // Sort candidates by cost
1732 kvz_sort_modes(mrg_cands, mrg_costs, num_rdo_cands);
1733
1734 // Limit by availability
1735 // TODO: Do not limit to just 1
1736 num_rdo_cands = MIN(1, num_rdo_cands);
1737
1738 // Early Skip Mode Decision
1739 bool has_chroma = state->encoder_control->chroma_format != KVZ_CSP_400;
1740 if (cfg->early_skip && cur_cu->part_size == SIZE_2Nx2N) {
1741 for (int merge_rdo_idx = 0; merge_rdo_idx < num_rdo_cands; ++merge_rdo_idx) {
1742
1743 // Reconstruct blocks with merge candidate.
1744 // Check luma CBF. Then, check chroma CBFs if luma CBF is not set
1745 // and chroma exists.
1746 // Early terminate if merge candidate with zero CBF is found.
1747 int merge_idx = mrg_cands[merge_rdo_idx];
1748 cur_cu->inter.mv_dir = info.merge_cand[merge_idx].dir;
1749 cur_cu->inter.mv_ref[0] = info.merge_cand[merge_idx].ref[0];
1750 cur_cu->inter.mv_ref[1] = info.merge_cand[merge_idx].ref[1];
1751 cur_cu->inter.mv[0][0] = info.merge_cand[merge_idx].mv[0][0];
1752 cur_cu->inter.mv[0][1] = info.merge_cand[merge_idx].mv[0][1];
1753 cur_cu->inter.mv[1][0] = info.merge_cand[merge_idx].mv[1][0];
1754 cur_cu->inter.mv[1][1] = info.merge_cand[merge_idx].mv[1][1];
1755 kvz_lcu_fill_trdepth(lcu, x, y, depth, MAX(1, depth));
1756 kvz_inter_recon_cu(state, lcu, x, y, width, true, false);
1757 kvz_quantize_lcu_residual(state, true, false, x, y, depth, cur_cu, lcu, true);
1758
1759 if (cbf_is_set(cur_cu->cbf, depth, COLOR_Y)) {
1760 continue;
1761 }
1762 else if (has_chroma) {
1763 kvz_inter_recon_cu(state, lcu, x, y, width, false, has_chroma);
1764 kvz_quantize_lcu_residual(state, false, has_chroma, x, y, depth, cur_cu, lcu, true);
1765 if (!cbf_is_set_any(cur_cu->cbf, depth)) {
1766 cur_cu->type = CU_INTER;
1767 cur_cu->merge_idx = merge_idx;
1768 cur_cu->skipped = true;
1769 *inter_cost = 0.0; // TODO: Check this
1770 *inter_bitcost = merge_idx; // TODO: Check this
1771 return;
1772 }
1773 }
1774 }
1775 }
1776
1777 // AMVP search starts here
1778
1779 // Store unipred information of L0 and L1 for biprediction
1780 // Best cost will be left at MAX_DOUBLE if no valid CU is found
1781 double best_cost_LX[2] = { MAX_DOUBLE, MAX_DOUBLE };
1782 cu_info_t unipreds[2];
1783
1784 for (int ref_idx = 0; ref_idx < state->frame->ref->used_size; ref_idx++) {
1785 info.ref_idx = ref_idx;
1786 info.ref = state->frame->ref->images[ref_idx];
1787
1788 search_pu_inter_ref(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost, best_cost_LX, unipreds);
1789 }
1790
1791 // Search bi-pred positions
1792 bool can_use_bipred = state->frame->slicetype == KVZ_SLICE_B
1793 && cfg->bipred
1794 && width + height >= 16; // 4x8 and 8x4 PBs are restricted to unipred
1795
1796 if (can_use_bipred) {
1797
1798 // Try biprediction from valid acquired unipreds.
1799 if (best_cost_LX[0] != MAX_DOUBLE && best_cost_LX[1] != MAX_DOUBLE) {
1800
1801 // TODO: logic is copy paste from search_pu_inter_bipred.
1802 // Get rid of duplicate code asap.
1803 const image_list_t *const ref = info.state->frame->ref;
1804 uint8_t(*ref_LX)[16] = info.state->frame->ref_LX;
1805
1806 inter_merge_cand_t *merge_cand = info.merge_cand;
1807
1808 int16_t mv[2][2];
1809 mv[0][0] = unipreds[0].inter.mv[0][0];
1810 mv[0][1] = unipreds[0].inter.mv[0][1];
1811 mv[1][0] = unipreds[1].inter.mv[1][0];
1812 mv[1][1] = unipreds[1].inter.mv[1][1];
1813
1814 kvz_inter_recon_bipred(info.state,
1815 ref->images[ref_LX[0][unipreds[0].inter.mv_ref[0]]],
1816 ref->images[ref_LX[1][unipreds[1].inter.mv_ref[1]]],
1817 x, y,
1818 width,
1819 height,
1820 mv,
1821 lcu,
1822 true,
1823 false);
1824
1825 const kvz_pixel *rec = &lcu->rec.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1826 const kvz_pixel *src = &lcu->ref.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)];
1827 uint32_t cost =
1828 kvz_satd_any_size(width, height, rec, LCU_WIDTH, src, LCU_WIDTH);
1829
1830 uint32_t bitcost[2] = { 0, 0 };
1831
1832 cost += info.mvd_cost_func(info.state,
1833 unipreds[0].inter.mv[0][0],
1834 unipreds[0].inter.mv[0][1],
1835 0,
1836 info.mv_cand,
1837 NULL, 0, 0,
1838 &bitcost[0]);
1839 cost += info.mvd_cost_func(info.state,
1840 unipreds[1].inter.mv[1][0],
1841 unipreds[1].inter.mv[1][1],
1842 0,
1843 info.mv_cand,
1844 NULL, 0, 0,
1845 &bitcost[1]);
1846
1847 const uint8_t mv_ref_coded[2] = {
1848 unipreds[0].inter.mv_ref[0],
1849 unipreds[1].inter.mv_ref[1]
1850 };
1851 const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */;
1852 cost += info.state->lambda_sqrt * extra_bits + 0.5;
1853
1854 if (cost < *inter_cost) {
1855 cur_cu->inter.mv_dir = 3;
1856
1857 cur_cu->inter.mv_ref[0] = unipreds[0].inter.mv_ref[0];
1858 cur_cu->inter.mv_ref[1] = unipreds[1].inter.mv_ref[1];
1859
1860 cur_cu->inter.mv[0][0] = unipreds[0].inter.mv[0][0];
1861 cur_cu->inter.mv[0][1] = unipreds[0].inter.mv[0][1];
1862 cur_cu->inter.mv[1][0] = unipreds[1].inter.mv[1][0];
1863 cur_cu->inter.mv[1][1] = unipreds[1].inter.mv[1][1];
1864 cur_cu->merged = 0;
1865
1866 // Check every candidate to find a match
1867 for (int merge_idx = 0; merge_idx < info.num_merge_cand; merge_idx++) {
1868 if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] &&
1869 merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] &&
1870 merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] &&
1871 merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] &&
1872 merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] &&
1873 merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1])
1874 {
1875 cur_cu->merged = 1;
1876 cur_cu->merge_idx = merge_idx;
1877 break;
1878 }
1879 }
1880
1881 // Each motion vector has its own candidate
1882 for (int reflist = 0; reflist < 2; reflist++) {
1883 kvz_inter_get_mv_cand(info.state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist);
1884 int cu_mv_cand = select_mv_cand(
1885 info.state,
1886 info.mv_cand,
1887 cur_cu->inter.mv[reflist][0],
1888 cur_cu->inter.mv[reflist][1],
1889 NULL);
1890 CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand);
1891 }
1892
1893 *inter_cost = cost;
1894 *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits;
1895 }
1896 }
1897
1898 // TODO: this probably should have a separate command line option
1899 if (cfg->rdo == 3) {
1900 search_pu_inter_bipred(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost);
1901 }
1902 }
1903
1904 // Compare best merge cost to amvp cost
1905 if (mrg_costs[0] < *inter_cost) {
1906 *inter_cost = mrg_costs[0];
1907 *inter_bitcost = 0; // TODO: Check this
1908 int merge_idx = mrg_cands[0];
1909 cur_cu->type = CU_INTER;
1910 cur_cu->merge_idx = merge_idx;
1911 cur_cu->inter.mv_dir = info.merge_cand[merge_idx].dir;
1912 cur_cu->inter.mv_ref[0] = info.merge_cand[merge_idx].ref[0];
1913 cur_cu->inter.mv_ref[1] = info.merge_cand[merge_idx].ref[1];
1914 cur_cu->inter.mv[0][0] = info.merge_cand[merge_idx].mv[0][0];
1915 cur_cu->inter.mv[0][1] = info.merge_cand[merge_idx].mv[0][1];
1916 cur_cu->inter.mv[1][0] = info.merge_cand[merge_idx].mv[1][0];
1917 cur_cu->inter.mv[1][1] = info.merge_cand[merge_idx].mv[1][1];
1918 cur_cu->merged = true;
1919 cur_cu->skipped = false;
1920 }
1921
1922 if (*inter_cost < INT_MAX && cur_cu->inter.mv_dir == 1) {
1923 assert(fracmv_within_tile(&info, cur_cu->inter.mv[0][0], cur_cu->inter.mv[0][1]));
1924 }
1925 }
1926
1927 /**
1928 * \brief Calculate inter coding cost for luma and chroma CBs (--rd=2 accuracy).
1929 *
1930 * Calculate inter coding cost of each CB. This should match the intra coding cost
1931 * calculation that is used on this RDO accuracy, since CU type decision is based
1932 * on this.
1933 *
1934 * The cost includes SSD distortion, transform unit tree bits and motion vector bits
1935 * for both luma and chroma if enabled.
1936 *
1937 * \param state encoder state
1938 * \param x x-coordinate of the CU
1939 * \param y y-coordinate of the CU
1940 * \param depth depth of the CU in the quadtree
1941 * \param lcu containing LCU
1942 *
1943 * \param inter_cost Return inter cost
1944 * \param inter_bitcost Return inter bitcost
1945 */
kvz_cu_cost_inter_rd2(encoder_state_t * const state,int x,int y,int depth,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1946 void kvz_cu_cost_inter_rd2(encoder_state_t * const state,
1947 int x, int y, int depth,
1948 lcu_t *lcu,
1949 double *inter_cost,
1950 uint32_t *inter_bitcost){
1951
1952 cu_info_t *cur_cu = LCU_GET_CU_AT_PX(lcu, SUB_SCU(x), SUB_SCU(y));
1953 int tr_depth = MAX(1, depth);
1954 if (cur_cu->part_size != SIZE_2Nx2N) {
1955 tr_depth = depth + 1;
1956 }
1957 kvz_lcu_fill_trdepth(lcu, x, y, depth, tr_depth);
1958
1959 const bool reconstruct_chroma = state->encoder_control->chroma_format != KVZ_CSP_400;
1960 kvz_inter_recon_cu(state, lcu, x, y, CU_WIDTH_FROM_DEPTH(depth), true, reconstruct_chroma);
1961 kvz_quantize_lcu_residual(state, true, reconstruct_chroma,
1962 x, y, depth,
1963 NULL,
1964 lcu,
1965 false);
1966
1967 *inter_cost = kvz_cu_rd_cost_luma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu);
1968 if (reconstruct_chroma) {
1969 *inter_cost += kvz_cu_rd_cost_chroma(state, SUB_SCU(x), SUB_SCU(y), depth, cur_cu, lcu);
1970 }
1971
1972 *inter_cost += *inter_bitcost * state->lambda;
1973 }
1974
1975
1976 /**
1977 * \brief Update CU to have best modes at this depth.
1978 *
1979 * Only searches the 2Nx2N partition mode.
1980 *
1981 * \param state encoder state
1982 * \param x x-coordinate of the CU
1983 * \param y y-coordinate of the CU
1984 * \param depth depth of the CU in the quadtree
1985 * \param lcu containing LCU
1986 *
1987 * \param inter_cost Return inter cost
1988 * \param inter_bitcost Return inter bitcost
1989 */
kvz_search_cu_inter(encoder_state_t * const state,int x,int y,int depth,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)1990 void kvz_search_cu_inter(encoder_state_t * const state,
1991 int x, int y, int depth,
1992 lcu_t *lcu,
1993 double *inter_cost,
1994 uint32_t *inter_bitcost)
1995 {
1996 search_pu_inter(state,
1997 x, y, depth,
1998 SIZE_2Nx2N, 0,
1999 lcu,
2000 inter_cost,
2001 inter_bitcost);
2002
2003 // Calculate more accurate cost when needed
2004 if (state->encoder_control->cfg.rdo >= 2) {
2005 kvz_cu_cost_inter_rd2(state,
2006 x, y, depth,
2007 lcu,
2008 inter_cost,
2009 inter_bitcost);
2010 }
2011 }
2012
2013
2014 /**
2015 * \brief Update CU to have best modes at this depth.
2016 *
2017 * Only searches the given partition mode.
2018 *
2019 * \param state encoder state
2020 * \param x x-coordinate of the CU
2021 * \param y y-coordinate of the CU
2022 * \param depth depth of the CU in the quadtree
2023 * \param part_mode partition mode to search
2024 * \param lcu containing LCU
2025 *
2026 * \param inter_cost Return inter cost
2027 * \param inter_bitcost Return inter bitcost
2028 */
kvz_search_cu_smp(encoder_state_t * const state,int x,int y,int depth,part_mode_t part_mode,lcu_t * lcu,double * inter_cost,uint32_t * inter_bitcost)2029 void kvz_search_cu_smp(encoder_state_t * const state,
2030 int x, int y,
2031 int depth,
2032 part_mode_t part_mode,
2033 lcu_t *lcu,
2034 double *inter_cost,
2035 uint32_t *inter_bitcost)
2036 {
2037 const int num_pu = kvz_part_mode_num_parts[part_mode];
2038 const int width = LCU_WIDTH >> depth;
2039 const int y_local = SUB_SCU(y);
2040 const int x_local = SUB_SCU(x);
2041
2042 *inter_cost = 0;
2043 *inter_bitcost = 0;
2044
2045 for (int i = 0; i < num_pu; ++i) {
2046 const int x_pu = PU_GET_X(part_mode, width, x_local, i);
2047 const int y_pu = PU_GET_Y(part_mode, width, y_local, i);
2048 const int width_pu = PU_GET_W(part_mode, width, i);
2049 const int height_pu = PU_GET_H(part_mode, width, i);
2050 cu_info_t *cur_pu = LCU_GET_CU_AT_PX(lcu, x_pu, y_pu);
2051
2052 cur_pu->type = CU_INTER;
2053 cur_pu->part_size = part_mode;
2054 cur_pu->depth = depth;
2055 cur_pu->qp = state->qp;
2056
2057 double cost = MAX_INT;
2058 uint32_t bitcost = MAX_INT;
2059
2060 search_pu_inter(state, x, y, depth, part_mode, i, lcu, &cost, &bitcost);
2061
2062 if (cost >= MAX_INT) {
2063 // Could not find any motion vector.
2064 *inter_cost = MAX_INT;
2065 *inter_bitcost = MAX_INT;
2066 return;
2067 }
2068
2069 *inter_cost += cost;
2070 *inter_bitcost += bitcost;
2071
2072 for (int y = y_pu; y < y_pu + height_pu; y += SCU_WIDTH) {
2073 for (int x = x_pu; x < x_pu + width_pu; x += SCU_WIDTH) {
2074 cu_info_t *scu = LCU_GET_CU_AT_PX(lcu, x, y);
2075 scu->type = CU_INTER;
2076 scu->inter = cur_pu->inter;
2077 }
2078 }
2079 }
2080
2081 // Calculate more accurate cost when needed
2082 if (state->encoder_control->cfg.rdo >= 2) {
2083 kvz_cu_cost_inter_rd2(state,
2084 x, y, depth,
2085 lcu,
2086 inter_cost,
2087 inter_bitcost);
2088 }
2089
2090 // Count bits spent for coding the partition mode.
2091 int smp_extra_bits = 1; // horizontal or vertical
2092 if (state->encoder_control->cfg.amp_enable) {
2093 smp_extra_bits += 1; // symmetric or asymmetric
2094 if (part_mode != SIZE_2NxN && part_mode != SIZE_Nx2N) {
2095 smp_extra_bits += 1; // U,L or D,R
2096 }
2097 }
2098 // The transform is split for SMP and AMP blocks so we need more bits for
2099 // coding the CBF.
2100 smp_extra_bits += 6;
2101
2102 *inter_cost += (state->encoder_control->cfg.rdo >= 2 ? state->lambda : state->lambda_sqrt) * smp_extra_bits;
2103 *inter_bitcost += smp_extra_bits;
2104 }
2105