1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 
sad32x32(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride)16 static unsigned int sad32x32(const uint8_t *src_ptr, int src_stride,
17                              const uint8_t *ref_ptr, int ref_stride) {
18   __m256i s1, s2, r1, r2;
19   __m256i sum = _mm256_setzero_si256();
20   __m128i sum_i128;
21   int i;
22 
23   for (i = 0; i < 16; ++i) {
24     r1 = _mm256_loadu_si256((__m256i const *)ref_ptr);
25     r2 = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride));
26     s1 = _mm256_sad_epu8(r1, _mm256_loadu_si256((__m256i const *)src_ptr));
27     s2 = _mm256_sad_epu8(
28         r2, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));
29     sum = _mm256_add_epi32(sum, _mm256_add_epi32(s1, s2));
30     ref_ptr += ref_stride << 1;
31     src_ptr += src_stride << 1;
32   }
33 
34   sum = _mm256_add_epi32(sum, _mm256_srli_si256(sum, 8));
35   sum_i128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 1),
36                            _mm256_castsi256_si128(sum));
37   return _mm_cvtsi128_si32(sum_i128);
38 }
39 
// SAD of a 64x32 block, computed as the sum of its left and right 32x32
// halves (the right half starts 32 bytes into each row).
static unsigned int sad64x32(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  const uint32_t left = sad32x32(src_ptr, src_stride, ref_ptr, ref_stride);
  const uint32_t right =
      sad32x32(src_ptr + 32, src_stride, ref_ptr + 32, ref_stride);
  return left + right;
}
49 
// SAD of a 64x64 block, computed as the sum of its top and bottom 64x32
// halves. The bottom half starts 32 rows down. That offset is computed with
// multiplication instead of `stride << 5`, which is undefined behavior for a
// negative stride. Callers may pass pre-scaled strides (e.g. 2 * stride), in
// which case "rows" here are rows of that scaled stride.
static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride * 32;
  ref_ptr += ref_stride * 32;
  sum += sad64x32(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
58 
// SAD of a 128x64 block, computed as two side-by-side 64x64 blocks: the left
// one at column 0 and the right one starting 64 bytes into each row.
unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  const uint32_t left = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  const uint32_t right =
      sad64x64(src_ptr + 64, src_stride, ref_ptr + 64, ref_stride);
  return left + right;
}
68 
// SAD of a 64x128 block, computed as the sum of its top and bottom 64x64
// halves. The 64-row offset uses multiplication rather than `stride << 6`,
// which is undefined behavior for a negative stride.
unsigned int aom_sad64x128_avx2(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride * 64;
  ref_ptr += ref_stride * 64;
  sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
77 
// SAD of a 128x128 block, computed as the sum of its top and bottom 128x64
// halves. The 64-row offset uses multiplication rather than `stride << 6`,
// which is undefined behavior for a negative stride.
unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride,
                                 const uint8_t *ref_ptr, int ref_stride) {
  uint32_t sum = aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  src_ptr += src_stride * 64;
  ref_ptr += ref_stride * 64;
  sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride);
  return sum;
}
86 
// "Skip" SAD for a 128x64 block. Only every other row is compared, because
// both strides are doubled. That gives 32 sampled rows, covered by two
// side-by-side 64x32 SADs. The sampled total is multiplied by 2 to stand in
// for the full-height SAD.
unsigned int aom_sad_skip_128x64_avx2(const uint8_t *src_ptr, int src_stride,
                                      const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = src_stride * 2;
  const int ref_step = ref_stride * 2;
  const uint32_t left = sad64x32(src_ptr, src_step, ref_ptr, ref_step);
  const uint32_t right =
      sad64x32(src_ptr + 64, src_step, ref_ptr + 64, ref_step);
  return (left + right) * 2;
}
96 
// "Skip" SAD for a 64x128 block. With both strides doubled, a 64x64 SAD visits
// every other row of the 128-row block; the result is multiplied by 2 to
// stand in for the full-height SAD.
unsigned int aom_sad_skip_64x128_avx2(const uint8_t *src_ptr, int src_stride,
                                      const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = 2 * src_stride;
  const int ref_step = 2 * ref_stride;
  return 2 * sad64x64(src_ptr, src_step, ref_ptr, ref_step);
}
103 
// "Skip" SAD for a 128x128 block. With both strides doubled, a 128x64 SAD
// visits every other row of the 128-row block; the result is multiplied by 2
// to stand in for the full-height SAD.
unsigned int aom_sad_skip_128x128_avx2(const uint8_t *src_ptr, int src_stride,
                                       const uint8_t *ref_ptr, int ref_stride) {
  const int src_step = 2 * src_stride;
  const int ref_step = 2 * ref_stride;
  return 2 * aom_sad128x64_avx2(src_ptr, src_step, ref_ptr, ref_step);
}
110 
sad_w64_avg_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,const int h,const uint8_t * second_pred,const int second_pred_stride)111 static unsigned int sad_w64_avg_avx2(const uint8_t *src_ptr, int src_stride,
112                                      const uint8_t *ref_ptr, int ref_stride,
113                                      const int h, const uint8_t *second_pred,
114                                      const int second_pred_stride) {
115   int i, res;
116   __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
117   __m256i sum_sad = _mm256_setzero_si256();
118   __m256i sum_sad_h;
119   __m128i sum_sad128;
120   for (i = 0; i < h; i++) {
121     ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
122     ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));
123     ref1_reg = _mm256_avg_epu8(
124         ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));
125     ref2_reg = _mm256_avg_epu8(
126         ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32)));
127     sad1_reg =
128         _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
129     sad2_reg = _mm256_sad_epu8(
130         ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));
131     sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
132     ref_ptr += ref_stride;
133     src_ptr += src_stride;
134     second_pred += second_pred_stride;
135   }
136   sum_sad_h = _mm256_srli_si256(sum_sad, 8);
137   sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
138   sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
139   sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
140   res = _mm_cvtsi128_si32(sum_sad128);
141 
142   return res;
143 }
144 
// SAD between a 64x128 source block and the rounded average of ref and
// second_pred, computed as the sum of the top and bottom 64x64 halves.
// second_pred is read with a stride equal to the block width (64), so its
// bottom half begins 64 rows * 64 bytes in. Row offsets for src/ref use
// multiplication rather than `stride << 6`, which is undefined behavior for
// a negative stride.
unsigned int aom_sad64x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  const int pred_stride = 64;
  const int half_height = 64;
  uint32_t sum = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride,
                                  half_height, second_pred, pred_stride);
  src_ptr += src_stride * half_height;
  ref_ptr += ref_stride * half_height;
  second_pred += pred_stride * half_height;
  sum += sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride,
                          half_height, second_pred, pred_stride);
  return sum;
}
157 
// SAD between a 128x64 source block and the rounded average of ref and
// second_pred, computed as two side-by-side 64x64 halves. second_pred is read
// with a stride of 128 (the block width); the right half of every buffer
// starts 64 bytes into each row.
unsigned int aom_sad128x64_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    const uint8_t *second_pred) {
  const uint32_t left = sad_w64_avg_avx2(src_ptr, src_stride, ref_ptr,
                                         ref_stride, 64, second_pred, 128);
  const uint32_t right =
      sad_w64_avg_avx2(src_ptr + 64, src_stride, ref_ptr + 64, ref_stride, 64,
                       second_pred + 64, 128);
  return left + right;
}
171 
// SAD between a 128x128 source block and the rounded average of ref and
// second_pred, computed as the sum of the top and bottom 128x64 halves.
// The 128x64 helper reads second_pred with a stride of 128, so its bottom
// half begins 64 rows * 128 bytes in. Row offsets for src/ref use
// multiplication rather than `stride << 6`, which is undefined behavior for
// a negative stride.
unsigned int aom_sad128x128_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *ref_ptr, int ref_stride,
                                     const uint8_t *second_pred) {
  uint32_t sum = aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr,
                                        ref_stride, second_pred);
  src_ptr += src_stride * 64;
  ref_ptr += ref_stride * 64;
  second_pred += 128 * 64;
  sum += aom_sad128x64_avg_avx2(src_ptr, src_stride, ref_ptr, ref_stride,
                                second_pred);
  return sum;
}
184