1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <stdlib.h>
13 #include <memory.h>
14 #include <math.h>
15 #include <assert.h>
16 
17 #include <smmintrin.h>
18 
19 #include "config/av1_rtcd.h"
20 
21 #include "aom_ports/mem.h"
22 #include "aom_ports/system_state.h"
23 #include "av1/encoder/corner_match.h"
24 
25 DECLARE_ALIGNED(16, static const uint8_t,
26                 byte_mask[16]) = { 255, 255, 255, 255, 255, 255, 255, 255,
27                                    255, 255, 255, 255, 255, 0,   0,   0 };
28 #if MATCH_SZ != 13
29 #error "Need to change byte_mask in corner_match_sse4.c if MATCH_SZ != 13"
30 #endif
31 
32 /* Compute corr(im1, im2) * MATCH_SZ * stddev(im1), where the
33    correlation/standard deviation are taken over MATCH_SZ by MATCH_SZ windows
34    of each image, centered at (x1, y1) and (x2, y2) respectively.
35 */
av1_compute_cross_correlation_sse4_1(unsigned char * im1,int stride1,int x1,int y1,unsigned char * im2,int stride2,int x2,int y2)36 double av1_compute_cross_correlation_sse4_1(unsigned char *im1, int stride1,
37                                             int x1, int y1, unsigned char *im2,
38                                             int stride2, int x2, int y2) {
39   int i;
40   // 2 16-bit partial sums in lanes 0, 4 (== 2 32-bit partial sums in lanes 0,
41   // 2)
42   __m128i sum1_vec = _mm_setzero_si128();
43   __m128i sum2_vec = _mm_setzero_si128();
44   // 4 32-bit partial sums of squares
45   __m128i sumsq2_vec = _mm_setzero_si128();
46   __m128i cross_vec = _mm_setzero_si128();
47 
48   const __m128i mask = _mm_load_si128((__m128i *)byte_mask);
49   const __m128i zero = _mm_setzero_si128();
50 
51   im1 += (y1 - MATCH_SZ_BY2) * stride1 + (x1 - MATCH_SZ_BY2);
52   im2 += (y2 - MATCH_SZ_BY2) * stride2 + (x2 - MATCH_SZ_BY2);
53 
54   for (i = 0; i < MATCH_SZ; ++i) {
55     const __m128i v1 =
56         _mm_and_si128(_mm_loadu_si128((__m128i *)&im1[i * stride1]), mask);
57     const __m128i v2 =
58         _mm_and_si128(_mm_loadu_si128((__m128i *)&im2[i * stride2]), mask);
59 
60     // Using the 'sad' intrinsic here is a bit faster than adding
61     // v1_l + v1_r and v2_l + v2_r, plus it avoids the need for a 16->32 bit
62     // conversion step later, for a net speedup of ~10%
63     sum1_vec = _mm_add_epi16(sum1_vec, _mm_sad_epu8(v1, zero));
64     sum2_vec = _mm_add_epi16(sum2_vec, _mm_sad_epu8(v2, zero));
65 
66     const __m128i v1_l = _mm_cvtepu8_epi16(v1);
67     const __m128i v1_r = _mm_cvtepu8_epi16(_mm_srli_si128(v1, 8));
68     const __m128i v2_l = _mm_cvtepu8_epi16(v2);
69     const __m128i v2_r = _mm_cvtepu8_epi16(_mm_srli_si128(v2, 8));
70 
71     sumsq2_vec = _mm_add_epi32(
72         sumsq2_vec,
73         _mm_add_epi32(_mm_madd_epi16(v2_l, v2_l), _mm_madd_epi16(v2_r, v2_r)));
74     cross_vec = _mm_add_epi32(
75         cross_vec,
76         _mm_add_epi32(_mm_madd_epi16(v1_l, v2_l), _mm_madd_epi16(v1_r, v2_r)));
77   }
78 
79   // Now we can treat the four registers (sum1_vec, sum2_vec, sumsq2_vec,
80   // cross_vec)
81   // as holding 4 32-bit elements each, which we want to sum horizontally.
82   // We do this by transposing and then summing vertically.
83   __m128i tmp_0 = _mm_unpacklo_epi32(sum1_vec, sum2_vec);
84   __m128i tmp_1 = _mm_unpackhi_epi32(sum1_vec, sum2_vec);
85   __m128i tmp_2 = _mm_unpacklo_epi32(sumsq2_vec, cross_vec);
86   __m128i tmp_3 = _mm_unpackhi_epi32(sumsq2_vec, cross_vec);
87 
88   __m128i tmp_4 = _mm_unpacklo_epi64(tmp_0, tmp_2);
89   __m128i tmp_5 = _mm_unpackhi_epi64(tmp_0, tmp_2);
90   __m128i tmp_6 = _mm_unpacklo_epi64(tmp_1, tmp_3);
91   __m128i tmp_7 = _mm_unpackhi_epi64(tmp_1, tmp_3);
92 
93   __m128i res =
94       _mm_add_epi32(_mm_add_epi32(tmp_4, tmp_5), _mm_add_epi32(tmp_6, tmp_7));
95 
96   int sum1 = _mm_extract_epi32(res, 0);
97   int sum2 = _mm_extract_epi32(res, 1);
98   int sumsq2 = _mm_extract_epi32(res, 2);
99   int cross = _mm_extract_epi32(res, 3);
100 
101   int var2 = sumsq2 * MATCH_SZ_SQ - sum2 * sum2;
102   int cov = cross * MATCH_SZ_SQ - sum1 * sum2;
103   aom_clear_system_state();
104   return cov / sqrt((double)var2);
105 }
106