1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>  // SSE2
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom_dsp/x86/synonyms.h"
19 #include "aom_ports/mem.h"
20 
21 #include "av1/common/filter.h"
22 #include "av1/common/reconinter.h"
23 
// Signature shared by the fixed-size high bit-depth variance kernels declared
// below: a kernel takes two 16-bit sample buffers with their strides and
// writes its results through `sse` and `sum`. The callers visible in this file
// discard the uint32_t return value and use only the two out-parameters.
typedef uint32_t (*high_variance_fn_t)(const uint16_t *src, int src_stride,
                                       const uint16_t *ref, int ref_stride,
                                       uint32_t *sse, int *sum);
27 
// 8x8 kernel with the high_variance_fn_t signature. Only the prototype is
// visible here; the definition is presumably provided by a separate SSE2
// implementation -- confirm against the build files.
uint32_t aom_highbd_calc8x8var_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride,
                                    uint32_t *sse, int *sum);
31 
// 16x16 kernel with the high_variance_fn_t signature. Only the prototype is
// visible here; the definition is presumably provided by a separate SSE2
// implementation -- confirm against the build files.
uint32_t aom_highbd_calc16x16var_sse2(const uint16_t *src, int src_stride,
                                      const uint16_t *ref, int ref_stride,
                                      uint32_t *sse, int *sum);
35 
// Tiles a w x h region with block_size x block_size calls to `var_fn` and
// stores the plain (unscaled) totals of the per-tile results in *sse and
// *sum. Both outputs are 0 when the region is empty.
static void highbd_8_variance_sse2(const uint16_t *src, int src_stride,
                                   const uint16_t *ref, int ref_stride, int w,
                                   int h, uint32_t *sse, int *sum,
                                   high_variance_fn_t var_fn, int block_size) {
  uint32_t total_sse = 0;
  int total_sum = 0;

  for (int row = 0; row < h; row += block_size) {
    const uint16_t *const src_row = src + src_stride * row;
    const uint16_t *const ref_row = ref + ref_stride * row;
    for (int col = 0; col < w; col += block_size) {
      unsigned int tile_sse;
      int tile_sum;
      var_fn(src_row + col, src_stride, ref_row + col, ref_stride, &tile_sse,
             &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }

  *sse = total_sse;
  *sum = total_sum;
}
56 
// 10-bit flavour of the tiled accumulation: per-tile results from `var_fn`
// are summed (the SSE in 64 bits), then scaled down with
// ROUND_POWER_OF_TWO -- the sum by 2 bits and the SSE by 4 bits -- before
// being stored in *sum and *sse.
static void highbd_10_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t sse_acc = 0;
  int32_t sum_acc = 0;

  for (int row = 0; row < h; row += block_size) {
    const uint16_t *const src_row = src + src_stride * row;
    const uint16_t *const ref_row = ref + ref_stride * row;
    for (int col = 0; col < w; col += block_size) {
      unsigned int tile_sse;
      int tile_sum;
      var_fn(src_row + col, src_stride, ref_row + col, ref_stride, &tile_sse,
             &tile_sum);
      sse_acc += tile_sse;
      sum_acc += tile_sum;
    }
  }

  *sum = ROUND_POWER_OF_TWO(sum_acc, 2);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_acc, 4);
}
78 
// 12-bit flavour of the tiled accumulation: per-tile results from `var_fn`
// are summed (the SSE in 64 bits), then scaled down with
// ROUND_POWER_OF_TWO -- the sum by 4 bits and the SSE by 8 bits -- before
// being stored in *sum and *sse.
static void highbd_12_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t sse_acc = 0;
  int32_t sum_acc = 0;

  for (int row = 0; row < h; row += block_size) {
    const uint16_t *const src_row = src + src_stride * row;
    const uint16_t *const ref_row = ref + ref_stride * row;
    for (int col = 0; col < w; col += block_size) {
      unsigned int tile_sse;
      int tile_sum;
      var_fn(src_row + col, src_stride, ref_row + col, ref_stride, &tile_sse,
             &tile_sum);
      sse_acc += tile_sse;
      sum_acc += tile_sum;
    }
  }

  *sum = ROUND_POWER_OF_TWO(sum_acc, 4);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_acc, 8);
}
100 
// HIGH_GET_VAR(S) defines three public SxS entry points that forward to
// aom_highbd_calc<S>x<S>var_sse2 after converting the uint8_t-typed pointers
// with CONVERT_TO_SHORTPTR:
//   aom_highbd_get<S>x<S>var_sse2     - results stored as produced.
//   aom_highbd_10_get<S>x<S>var_sse2  - *sum scaled by 2 bits, *sse by 4 bits.
//   aom_highbd_12_get<S>x<S>var_sse2  - *sum scaled by 4 bits, *sse by 8 bits.
// Scaling uses ROUND_POWER_OF_TWO. The macro expands to complete function
// definitions, so it is invoked without a trailing ';' (a lone ';' at file
// scope is not valid ISO C and trips -Wpedantic / -Wextra-semi).
#define HIGH_GET_VAR(S)                                                       \
  void aom_highbd_get##S##x##S##var_sse2(const uint8_t *src8, int src_stride, \
                                         const uint8_t *ref8, int ref_stride, \
                                         uint32_t *sse, int *sum) {           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
  }                                                                           \
                                                                              \
  void aom_highbd_10_get##S##x##S##var_sse2(                                  \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
      int ref_stride, uint32_t *sse, int *sum) {                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
    *sum = ROUND_POWER_OF_TWO(*sum, 2);                                       \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                       \
  }                                                                           \
                                                                              \
  void aom_highbd_12_get##S##x##S##var_sse2(                                  \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
      int ref_stride, uint32_t *sse, int *sum) {                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    aom_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
    *sum = ROUND_POWER_OF_TWO(*sum, 4);                                       \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                       \
  }

HIGH_GET_VAR(16)
HIGH_GET_VAR(8)

#undef HIGH_GET_VAR
137 
138 #define VAR_FN(w, h, block_size, shift)                                    \
139   uint32_t aom_highbd_8_variance##w##x##h##_sse2(                          \
140       const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
141       int ref_stride, uint32_t *sse) {                                     \
142     int sum;                                                               \
143     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
144     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
145     highbd_8_variance_sse2(                                                \
146         src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
147         aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
148     return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);               \
149   }                                                                        \
150                                                                            \
151   uint32_t aom_highbd_10_variance##w##x##h##_sse2(                         \
152       const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
153       int ref_stride, uint32_t *sse) {                                     \
154     int sum;                                                               \
155     int64_t var;                                                           \
156     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
157     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
158     highbd_10_variance_sse2(                                               \
159         src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
160         aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
161     var = (int64_t)(*sse) - (((int64_t)sum * sum) >> shift);               \
162     return (var >= 0) ? (uint32_t)var : 0;                                 \
163   }                                                                        \
164                                                                            \
165   uint32_t aom_highbd_12_variance##w##x##h##_sse2(                         \
166       const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
167       int ref_stride, uint32_t *sse) {                                     \
168     int sum;                                                               \
169     int64_t var;                                                           \
170     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
171     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
172     highbd_12_variance_sse2(                                               \
173         src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
174         aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
175     var = (int64_t)(*sse) - (((int64_t)sum * sum) >> shift);               \
176     return (var >= 0) ? (uint32_t)var : 0;                                 \
177   }
178 
179 VAR_FN(128, 128, 16, 14);
180 VAR_FN(128, 64, 16, 13);
181 VAR_FN(64, 128, 16, 13);
182 VAR_FN(64, 64, 16, 12);
183 VAR_FN(64, 32, 16, 11);
184 VAR_FN(32, 64, 16, 11);
185 VAR_FN(32, 32, 16, 10);
186 VAR_FN(32, 16, 16, 9);
187 VAR_FN(16, 32, 16, 9);
188 VAR_FN(16, 16, 16, 8);
189 VAR_FN(16, 8, 8, 7);
190 VAR_FN(8, 16, 8, 7);
191 VAR_FN(8, 8, 8, 6);
192 VAR_FN(8, 32, 8, 8);
193 VAR_FN(32, 8, 8, 8);
194 VAR_FN(16, 64, 16, 10);
195 VAR_FN(64, 16, 16, 10);
196 
197 #undef VAR_FN
198 
aom_highbd_8_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)199 unsigned int aom_highbd_8_mse16x16_sse2(const uint8_t *src8, int src_stride,
200                                         const uint8_t *ref8, int ref_stride,
201                                         unsigned int *sse) {
202   int sum;
203   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
204   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
205   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
206                          aom_highbd_calc16x16var_sse2, 16);
207   return *sse;
208 }
209 
aom_highbd_10_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)210 unsigned int aom_highbd_10_mse16x16_sse2(const uint8_t *src8, int src_stride,
211                                          const uint8_t *ref8, int ref_stride,
212                                          unsigned int *sse) {
213   int sum;
214   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
215   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
216   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
217                           aom_highbd_calc16x16var_sse2, 16);
218   return *sse;
219 }
220 
aom_highbd_12_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)221 unsigned int aom_highbd_12_mse16x16_sse2(const uint8_t *src8, int src_stride,
222                                          const uint8_t *ref8, int ref_stride,
223                                          unsigned int *sse) {
224   int sum;
225   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
226   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
227   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
228                           aom_highbd_calc16x16var_sse2, 16);
229   return *sse;
230 }
231 
aom_highbd_8_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)232 unsigned int aom_highbd_8_mse8x8_sse2(const uint8_t *src8, int src_stride,
233                                       const uint8_t *ref8, int ref_stride,
234                                       unsigned int *sse) {
235   int sum;
236   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
237   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
238   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
239                          aom_highbd_calc8x8var_sse2, 8);
240   return *sse;
241 }
242 
aom_highbd_10_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)243 unsigned int aom_highbd_10_mse8x8_sse2(const uint8_t *src8, int src_stride,
244                                        const uint8_t *ref8, int ref_stride,
245                                        unsigned int *sse) {
246   int sum;
247   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
248   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
249   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
250                           aom_highbd_calc8x8var_sse2, 8);
251   return *sse;
252 }
253 
aom_highbd_12_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)254 unsigned int aom_highbd_12_mse8x8_sse2(const uint8_t *src8, int src_stride,
255                                        const uint8_t *ref8, int ref_stride,
256                                        unsigned int *sse) {
257   int sum;
258   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
259   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
260   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
261                           aom_highbd_calc8x8var_sse2, 8);
262   return *sse;
263 }
264 
// Prototypes for the strip kernels defined in
// highbd_subpel_variance_impl_sse2.asm. Each kernel handles a strip that is
// 8 or 16 samples wide and `height` rows tall, returns an int (used by the
// callers below as the sum term) and writes the SSE through `sse`.
// The 2 trailing unused parameters are placeholders for PIC-enabled builds.
//
// DECL deliberately does not end in ';' -- the terminator is supplied at the
// point of use, so no empty declaration (a stray ';' at file scope, which ISO
// C does not allow) is produced.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_variance##w##xh_##opt(                            \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, int height,                 \
      unsigned int *sse, void *unused0, void *unused)
#define DECLS(opt) \
  DECL(8, opt);    \
  DECL(16, opt)

DECLS(sse2);

#undef DECLS
#undef DECL
281 
// FN(w, h, wf, wlog2, hlog2, opt, cast) defines the public w x h sub-pixel
// variance functions for 8-, 10- and 12-bit input on top of the wf-wide strip
// kernels aom_highbd_sub_pixel_variance<wf>xh_<opt>.
//
// Parameters of the macro:
//   w, h          block dimensions.
//   wf            width of the strip kernel to use (8 or 16 below).
//   wlog2, hlog2  log2 of w and h in every instantiation below; their sum is
//                 the shift applied to the squared sum term.
//   opt           ISA suffix of the generated/used functions.
//   cast          cast applied to `se` before squaring (always (int64_t)).
//
// Each generated function converts the uint8_t-typed pointers with
// CONVERT_TO_SHORTPTR, walks the block left to right in wf-wide strips,
// accumulates the kernel's return value into `se` and its SSE output into the
// running SSE, stores the (possibly scaled) SSE in *sse_ptr and returns
// SSE - (se * se >> (wlog2 + hlog2)).
//   8-bit : totals used as-is; 32-bit difference returned directly.
//   10-bit: se scaled by 2 bits and SSE by 4 bits (ROUND_POWER_OF_TWO);
//           negative result clamped to 0.
//   12-bit: additionally processes at most 16 rows per kernel call
//           (presumably to bound the kernel's internal accumulators --
//           confirm against the .asm); se scaled by 4 bits, SSE by 8 bits;
//           negative result clamped to 0.
//
// The strip walk is a plain column loop. For every size instantiated below it
// issues exactly the same kernel calls, in the same order, as a hand-unrolled
// 1/2/4-strip (x2 for w == 128) sequence would, and it stays correct for any
// w that is a multiple of wf.
//
// The macros expand to complete function definitions, so they are invoked
// without trailing ';' (a lone ';' at file scope is not valid ISO C and trips
// -Wpedantic / -Wextra-semi).
#define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
  uint32_t aom_highbd_8_sub_pixel_variance##w##x##h##_##opt(                   \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                           \
    const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                           \
    int se = 0;                                                                \
    unsigned int sse = 0;                                                      \
    for (int col = 0; col < (w); col += (wf)) {                                \
      unsigned int strip_sse;                                                  \
      se += aom_highbd_sub_pixel_variance##wf##xh_##opt(                       \
          src + col, src_stride, x_offset, y_offset, dst + col, dst_stride,    \
          h, &strip_sse, NULL, NULL);                                          \
      sse += strip_sse;                                                        \
    }                                                                          \
    *sse_ptr = sse;                                                            \
    return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_10_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    int64_t var;                                                               \
    uint32_t sse;                                                              \
    uint64_t long_sse = 0;                                                     \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                           \
    const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                           \
    int se = 0;                                                                \
    for (int col = 0; col < (w); col += (wf)) {                                \
      uint32_t strip_sse;                                                      \
      se += aom_highbd_sub_pixel_variance##wf##xh_##opt(                       \
          src + col, src_stride, x_offset, y_offset, dst + col, dst_stride,    \
          h, &strip_sse, NULL, NULL);                                          \
      long_sse += strip_sse;                                                   \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 2);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 4);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_12_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    uint32_t sse;                                                              \
    int se = 0;                                                                \
    int64_t var;                                                               \
    uint64_t long_sse = 0;                                                     \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                           \
    const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                           \
    for (int start_row = 0; start_row < (h); start_row += 16) {                \
      const int height = (h) - start_row < 16 ? (h) - start_row : 16;          \
      const uint16_t *src_rows = src + (start_row * src_stride);               \
      const uint16_t *dst_rows = dst + (start_row * dst_stride);               \
      for (int col = 0; col < (w); col += (wf)) {                              \
        uint32_t strip_sse;                                                    \
        se += aom_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src_rows + col, src_stride, x_offset, y_offset, dst_rows + col,    \
            dst_stride, height, &strip_sse, NULL, NULL);                       \
        long_sse += strip_sse;                                                 \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 4);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }

#define FNS(opt)                         \
  FN(128, 128, 16, 7, 7, opt, (int64_t)) \
  FN(128, 64, 16, 7, 6, opt, (int64_t))  \
  FN(64, 128, 16, 6, 7, opt, (int64_t))  \
  FN(64, 64, 16, 6, 6, opt, (int64_t))   \
  FN(64, 32, 16, 6, 5, opt, (int64_t))   \
  FN(32, 64, 16, 5, 6, opt, (int64_t))   \
  FN(32, 32, 16, 5, 5, opt, (int64_t))   \
  FN(32, 16, 16, 5, 4, opt, (int64_t))   \
  FN(16, 32, 16, 4, 5, opt, (int64_t))   \
  FN(16, 16, 16, 4, 4, opt, (int64_t))   \
  FN(16, 8, 16, 4, 3, opt, (int64_t))    \
  FN(8, 16, 8, 3, 4, opt, (int64_t))     \
  FN(8, 8, 8, 3, 3, opt, (int64_t))      \
  FN(8, 4, 8, 3, 2, opt, (int64_t))      \
  FN(16, 4, 16, 4, 2, opt, (int64_t))    \
  FN(8, 32, 8, 3, 5, opt, (int64_t))     \
  FN(32, 8, 16, 5, 3, opt, (int64_t))    \
  FN(16, 64, 16, 4, 6, opt, (int64_t))   \
  FN(64, 16, 16, 6, 4, opt, (int64_t))

FNS(sse2)

#undef FNS
#undef FN
447 
// Prototypes for the 8- and 16-wide sub-pixel "avg" variance strip kernels.
// Compared with the plain sub-pixel kernels they take an additional `sec`
// buffer with its own stride. Each returns an int (used by the callers as the
// sum term) and writes the SSE through `sse`.
// The 2 trailing unused parameters are placeholders for PIC-enabled builds.
//
// DECL deliberately does not end in ';' -- the terminator is supplied at the
// point of use, so no empty declaration (a stray ';' at file scope, which ISO
// C does not allow) is produced.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_avg_variance##w##xh_##opt(                        \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, const uint16_t *sec,        \
      ptrdiff_t sec_stride, int height, unsigned int *sse, void *unused0,    \
      void *unused)
#define DECLS(opt) \
  DECL(16, opt);   \
  DECL(8, opt)

DECLS(sse2);
#undef DECL
#undef DECLS
462 
463 #define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
464   uint32_t aom_highbd_8_sub_pixel_avg_variance##w##x##h##_##opt(               \
465       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
466       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
467       const uint8_t *sec8) {                                                   \
468     uint32_t sse;                                                              \
469     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
470     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
471     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
472     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
473         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
474         NULL, NULL);                                                           \
475     if (w > wf) {                                                              \
476       uint32_t sse2;                                                           \
477       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
478           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
479           sec + wf, w, h, &sse2, NULL, NULL);                                  \
480       se += se2;                                                               \
481       sse += sse2;                                                             \
482       if (w > wf * 2) {                                                        \
483         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
484             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
485             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
486         se += se2;                                                             \
487         sse += sse2;                                                           \
488         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
489             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
490             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
491         se += se2;                                                             \
492         sse += sse2;                                                           \
493       }                                                                        \
494     }                                                                          \
495     *sse_ptr = sse;                                                            \
496     return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
497   }                                                                            \
498                                                                                \
499   uint32_t aom_highbd_10_sub_pixel_avg_variance##w##x##h##_##opt(              \
500       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
501       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
502       const uint8_t *sec8) {                                                   \
503     int64_t var;                                                               \
504     uint32_t sse;                                                              \
505     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
506     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
507     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
508     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
509         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
510         NULL, NULL);                                                           \
511     if (w > wf) {                                                              \
512       uint32_t sse2;                                                           \
513       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
514           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
515           sec + wf, w, h, &sse2, NULL, NULL);                                  \
516       se += se2;                                                               \
517       sse += sse2;                                                             \
518       if (w > wf * 2) {                                                        \
519         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
520             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
521             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
522         se += se2;                                                             \
523         sse += sse2;                                                           \
524         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
525             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
526             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
527         se += se2;                                                             \
528         sse += sse2;                                                           \
529       }                                                                        \
530     }                                                                          \
531     se = ROUND_POWER_OF_TWO(se, 2);                                            \
532     sse = ROUND_POWER_OF_TWO(sse, 4);                                          \
533     *sse_ptr = sse;                                                            \
534     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
535     return (var >= 0) ? (uint32_t)var : 0;                                     \
536   }                                                                            \
537                                                                                \
538   uint32_t aom_highbd_12_sub_pixel_avg_variance##w##x##h##_##opt(              \
539       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
540       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
541       const uint8_t *sec8) {                                                   \
542     int start_row;                                                             \
543     int64_t var;                                                               \
544     uint32_t sse;                                                              \
545     int se = 0;                                                                \
546     uint64_t long_sse = 0;                                                     \
547     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
548     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
549     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
550     for (start_row = 0; start_row < h; start_row += 16) {                      \
551       uint32_t sse2;                                                           \
552       int height = h - start_row < 16 ? h - start_row : 16;                    \
553       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
554           src + (start_row * src_stride), src_stride, x_offset, y_offset,      \
555           dst + (start_row * dst_stride), dst_stride, sec + (start_row * w),   \
556           w, height, &sse2, NULL, NULL);                                       \
557       se += se2;                                                               \
558       long_sse += sse2;                                                        \
559       if (w > wf) {                                                            \
560         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
561             src + wf + (start_row * src_stride), src_stride, x_offset,         \
562             y_offset, dst + wf + (start_row * dst_stride), dst_stride,         \
563             sec + wf + (start_row * w), w, height, &sse2, NULL, NULL);         \
564         se += se2;                                                             \
565         long_sse += sse2;                                                      \
566         if (w > wf * 2) {                                                      \
567           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
568               src + 2 * wf + (start_row * src_stride), src_stride, x_offset,   \
569               y_offset, dst + 2 * wf + (start_row * dst_stride), dst_stride,   \
570               sec + 2 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
571           se += se2;                                                           \
572           long_sse += sse2;                                                    \
573           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
574               src + 3 * wf + (start_row * src_stride), src_stride, x_offset,   \
575               y_offset, dst + 3 * wf + (start_row * dst_stride), dst_stride,   \
576               sec + 3 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
577           se += se2;                                                           \
578           long_sse += sse2;                                                    \
579         }                                                                      \
580       }                                                                        \
581     }                                                                          \
582     se = ROUND_POWER_OF_TWO(se, 4);                                            \
583     sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
584     *sse_ptr = sse;                                                            \
585     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
586     return (var >= 0) ? (uint32_t)var : 0;                                     \
587   }
588 
589 #define FNS(opt)                        \
590   FN(64, 64, 16, 6, 6, opt, (int64_t)); \
591   FN(64, 32, 16, 6, 5, opt, (int64_t)); \
592   FN(32, 64, 16, 5, 6, opt, (int64_t)); \
593   FN(32, 32, 16, 5, 5, opt, (int64_t)); \
594   FN(32, 16, 16, 5, 4, opt, (int64_t)); \
595   FN(16, 32, 16, 4, 5, opt, (int64_t)); \
596   FN(16, 16, 16, 4, 4, opt, (int64_t)); \
597   FN(16, 8, 16, 4, 3, opt, (int64_t));  \
598   FN(8, 16, 8, 3, 4, opt, (int64_t));   \
599   FN(8, 8, 8, 3, 3, opt, (int64_t));    \
600   FN(8, 4, 8, 3, 2, opt, (int64_t));    \
601   FN(16, 4, 16, 4, 2, opt, (int64_t));  \
602   FN(8, 32, 8, 3, 5, opt, (int64_t));   \
603   FN(32, 8, 16, 5, 3, opt, (int64_t));  \
604   FN(16, 64, 16, 4, 6, opt, (int64_t)); \
605   FN(64, 16, 16, 6, 4, opt, (int64_t));
606 
607 FNS(sse2);
608 
609 #undef FNS
610 #undef FN
611 
highbd_compute_dist_wtd_comp_avg(__m128i * p0,__m128i * p1,const __m128i * w0,const __m128i * w1,const __m128i * r,void * const result)612 static INLINE void highbd_compute_dist_wtd_comp_avg(__m128i *p0, __m128i *p1,
613                                                     const __m128i *w0,
614                                                     const __m128i *w1,
615                                                     const __m128i *r,
616                                                     void *const result) {
617   assert(DIST_PRECISION_BITS <= 4);
618   __m128i mult0 = _mm_mullo_epi16(*p0, *w0);
619   __m128i mult1 = _mm_mullo_epi16(*p1, *w1);
620   __m128i sum = _mm_adds_epu16(mult0, mult1);
621   __m128i round = _mm_adds_epu16(sum, *r);
622   __m128i shift = _mm_srli_epi16(round, DIST_PRECISION_BITS);
623 
624   xx_storeu_128(result, shift);
625 }
626 
aom_highbd_dist_wtd_comp_avg_pred_sse2(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const DIST_WTD_COMP_PARAMS * jcp_param)627 void aom_highbd_dist_wtd_comp_avg_pred_sse2(
628     uint8_t *comp_pred8, const uint8_t *pred8, int width, int height,
629     const uint8_t *ref8, int ref_stride,
630     const DIST_WTD_COMP_PARAMS *jcp_param) {
631   int i;
632   const uint16_t wt0 = (uint16_t)jcp_param->fwd_offset;
633   const uint16_t wt1 = (uint16_t)jcp_param->bck_offset;
634   const __m128i w0 = _mm_set_epi16(wt0, wt0, wt0, wt0, wt0, wt0, wt0, wt0);
635   const __m128i w1 = _mm_set_epi16(wt1, wt1, wt1, wt1, wt1, wt1, wt1, wt1);
636   const uint16_t round = ((1 << DIST_PRECISION_BITS) >> 1);
637   const __m128i r =
638       _mm_set_epi16(round, round, round, round, round, round, round, round);
639   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
640   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
641   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
642 
643   if (width >= 8) {
644     // Read 8 pixels one row at a time
645     assert(!(width & 7));
646     for (i = 0; i < height; ++i) {
647       int j;
648       for (j = 0; j < width; j += 8) {
649         __m128i p0 = xx_loadu_128(ref);
650         __m128i p1 = xx_loadu_128(pred);
651 
652         highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
653 
654         comp_pred += 8;
655         pred += 8;
656         ref += 8;
657       }
658       ref += ref_stride - width;
659     }
660   } else {
661     // Read 4 pixels two rows at a time
662     assert(!(width & 3));
663     for (i = 0; i < height; i += 2) {
664       __m128i p0_0 = xx_loadl_64(ref + 0 * ref_stride);
665       __m128i p0_1 = xx_loadl_64(ref + 1 * ref_stride);
666       __m128i p0 = _mm_unpacklo_epi64(p0_0, p0_1);
667       __m128i p1 = xx_loadu_128(pred);
668 
669       highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
670 
671       comp_pred += 8;
672       pred += 8;
673       ref += 2 * ref_stride;
674     }
675   }
676 }
677 
// Returns the sum of squared differences between two blocks that are four
// 16-bit samples wide and `h` rows tall.
//
// `dstride` / `sstride` are the row pitches of `dst` / `src` in samples.
// Rows are consumed in pairs (rows i and i + 1 are loaded on every
// iteration), so `h` is expected to be even. The per-sample difference is
// formed in 16 bits and the squares are accumulated in 64-bit lanes.
uint64_t aom_mse_4xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();
  uint64_t total = 0;

  for (int row = 0; row < h; row += 2) {
    const uint16_t *const dst_row0 = &dst[(row + 0) * dstride];
    const uint16_t *const dst_row1 = &dst[(row + 1) * dstride];
    const uint16_t *const src_row0 = &src[(row + 0) * sstride];
    const uint16_t *const src_row1 = &src[(row + 1) * sstride];

    // Pack two 4-sample rows of each block into one 8-lane vector.
    const __m128i dst_8x16 =
        _mm_unpacklo_epi64(_mm_loadl_epi64((__m128i const *)dst_row0),
                           _mm_loadl_epi64((__m128i const *)dst_row1));
    const __m128i src_8x16 =
        _mm_unpacklo_epi64(_mm_loadl_epi64((__m128i const *)src_row0),
                           _mm_loadl_epi64((__m128i const *)src_row1));
    const __m128i diff_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);

    // Interleave with zero so every 32-bit lane holds the pair (diff, 0);
    // madd of that vector with itself then yields diff * diff per lane.
    const __m128i diff_lo = _mm_unpacklo_epi16(diff_8x16, zero);
    const __m128i diff_hi = _mm_unpackhi_epi16(diff_8x16, zero);
    const __m128i sq_lo_4x32 = _mm_madd_epi16(diff_lo, diff_lo);
    const __m128i sq_hi_4x32 = _mm_madd_epi16(diff_hi, diff_hi);

    // Zero-extend the 32-bit squares to 64 bits before accumulating.
    const __m128i sq0_2x64 = _mm_unpacklo_epi32(sq_lo_4x32, zero);
    const __m128i sq1_2x64 = _mm_unpackhi_epi32(sq_lo_4x32, zero);
    const __m128i sq2_2x64 = _mm_unpacklo_epi32(sq_hi_4x32, zero);
    const __m128i sq3_2x64 = _mm_unpackhi_epi32(sq_hi_4x32, zero);

    const __m128i pair_sum_2x64 =
        _mm_add_epi64(_mm_add_epi64(sq0_2x64, sq1_2x64),
                      _mm_add_epi64(sq2_2x64, sq3_2x64));
    acc_2x64 = _mm_add_epi64(acc_2x64, pair_sum_2x64);
  }

  // Fold the upper 64-bit lane onto the lower one and store the low lane.
  acc_2x64 = _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  xx_storel_64(&total, acc_2x64);
  return total;
}
722 
// Returns the sum of squared differences between two blocks that are eight
// 16-bit samples wide and `h` rows tall.
//
// `dstride` / `sstride` are the row pitches of `dst` / `src` in samples. One
// row (a single unaligned 128-bit load per block) is processed per iteration.
// The per-sample difference is formed in 16 bits and the squares are
// accumulated in 64-bit lanes.
uint64_t aom_mse_8xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();
  uint64_t total = 0;

  for (int row = 0; row < h; row++) {
    const __m128i dst_8x16 = _mm_loadu_si128((__m128i *)&dst[row * dstride]);
    const __m128i src_8x16 = _mm_loadu_si128((__m128i *)&src[row * sstride]);
    const __m128i diff_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);

    // Interleave with zero so every 32-bit lane holds the pair (diff, 0);
    // madd of that vector with itself then yields diff * diff per lane.
    const __m128i diff_lo = _mm_unpacklo_epi16(diff_8x16, zero);
    const __m128i diff_hi = _mm_unpackhi_epi16(diff_8x16, zero);
    const __m128i sq_lo_4x32 = _mm_madd_epi16(diff_lo, diff_lo);
    const __m128i sq_hi_4x32 = _mm_madd_epi16(diff_hi, diff_hi);

    // Zero-extend the 32-bit squares to 64 bits before accumulating.
    const __m128i sq0_2x64 = _mm_unpacklo_epi32(sq_lo_4x32, zero);
    const __m128i sq1_2x64 = _mm_unpackhi_epi32(sq_lo_4x32, zero);
    const __m128i sq2_2x64 = _mm_unpacklo_epi32(sq_hi_4x32, zero);
    const __m128i sq3_2x64 = _mm_unpackhi_epi32(sq_hi_4x32, zero);

    const __m128i row_sum_2x64 =
        _mm_add_epi64(_mm_add_epi64(sq0_2x64, sq1_2x64),
                      _mm_add_epi64(sq2_2x64, sq3_2x64));
    acc_2x64 = _mm_add_epi64(acc_2x64, row_sum_2x64);
  }

  // Fold the upper 64-bit lane onto the lower one and store the low lane.
  acc_2x64 = _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  xx_storel_64(&total, acc_2x64);
  return total;
}
762 
// Width dispatcher for the 16-bit sum-of-squared-differences kernels.
//
// Forwards to the 4-wide or 8-wide kernel according to `w`. Only w and h
// values of 4 or 8 are accepted (checked by assert). For any other width the
// function asserts and, when asserts are compiled out, returns -1 converted
// to uint64_t.
uint64_t aom_mse_wxh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int w,
                                       int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must satisfy");
  if (w == 4) {
    return aom_mse_4xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  }
  if (w == 8) {
    return aom_mse_8xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  }
  assert(0 && "unsupported width");
  return -1;
}
774