; Copyright (c) 2020, The rav1e contributors. All rights reserved
;
; This source code is subject to the terms of the BSD 2 Clause License and
; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
; was not distributed with this source code in the LICENSE file, you can
; obtain it at www.aomedia.org/license/software. If the Alliance for Open
; Media Patent License 1.0 was not distributed with this source code in the
; PATENTS file, you can obtain it at www.aomedia.org/license/patent.

%include "config.asm"
%include "ext/x86/x86inc.asm"

SECTION_RODATA 32
addsub: times 16 db 1, -1
rounding: times 4 dq 0x800
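; addsub is used as the second pmaddubsw operand below: the interleaved
; (1, -1) weights turn unpacked src/dst byte pairs into 16-bit differences.
; rounding is the 0x800 bias added before the >> 12 shift when applying the
; per-block scale.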

SECTION .text

%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX)

; Consolidate scaling and rounding to one place so that it is easier to change.
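;
; Every 4x4 block's sum of squared differences is scaled by a 32-bit weight
; and rounded, i.e. per 4x4 block:
;
;   out = (sse * scale + 0x800) >> 12
;
; SSE_SCALE_4X4 does this with scalar code; the vector macros below do the
; same with pmuludq/psrlq on packed 64-bit lanes.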

%macro SSE_SCALE_4X4 0
    ; Multiply and shift using scalar code
    mov             scaled, [scaleq]
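    ; Writing to scaled zero-extends into scaleq, so the imul below multiplies
    ; rax by the loaded 32-bit scale rather than by the pointer.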
    imul               rax, scaleq
    add                rax, 0x800
    shr                rax, 12
%endmacro

; 1 is the input and output register.
; 2-3 are tmp registers.
%macro SSE_SCALE 2-3
    ; Reduce 32-bit sums to 64-bit sums.
    pshufd             m%2, m%1, q3311
    paddd              m%1, m%2

    LOAD_SCALES %2, %3

    ; Multiply and shift with rounding.
    pmuludq            m%1, m%2
    ; TODO: Alter the Rust source so that rounding can always be done at the
    ; end (i.e. only do it once).
    mova               m%2, [rounding]
    paddq              m%1, m%2
    psrlq              m%1, 12
%endmacro
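
; After SSE_SCALE, each 64-bit lane of the result holds an independently
; scaled partial sum; callers reduce the lanes with paddq before returning
; the total in rax.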

%macro LOAD_SCALES_4X8 2
    ; Load 1 scale from each of the 2 rows.
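    ; The 4x8 kernel keeps the top 4x4 block's sums in the low 64-bit lane and
    ; the bottom block's in the high lane, so one scale comes from each of the
    ; two scale rows.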
    movd               m%1, [scaleq]
    movd               m%2, [scaleq+scale_strideq]
    ; 64-bit unpack since our loads have only one value each.
    punpcklqdq         m%1, m%2
%endmacro

; 2 is unused
%macro LOAD_SCALES_8X4 2
    ; Convert to 64-bits.
    ; It doesn't matter that the upper halves are full of garbage.
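    ; pmuludq only reads the low dword of each 64-bit lane. The two 4x4 blocks
    ; of an 8-wide row sit side by side in one scale row, so two consecutive
    ; dword scales are loaded here.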
    movq               m%1, [scaleq]
    pshufd             m%1, m%1, q1100
%endmacro

; 2 is unused
%macro LOAD_SCALES_16X4 2
    pmovzxdq           m%1, [scaleq]
%endmacro

; Separate from other scale macros, since it uses 2 inputs.
; 1-2 are input regs and 1 is the output reg.
; 3-4 are tmp registers
%macro SSE_SCALE_32X4 4
    pshufd             m%3, m%1, q3311
    paddd              m%1, m%3
    pshufd             m%3, m%2, q3311
    paddd              m%2, m%3

    ; Load scale for 4x4 blocks and convert to 64-bits.
    ; It doesn't matter if the upper halves are full of garbage.
    ; raw load:    0, 1, 2, 3 | 4, 5, 6, 7
    ; unpack low:  0,    1    | 4,    5
    ; unpack high: 2,    3    | 6,    7
    mova               m%4, [scaleq]
    punpckldq          m%3, m%4, m%4
    punpckhdq          m%4, m%4

    pmuludq            m%1, m%3
    pmuludq            m%2, m%4
    mova               m%3, [rounding]
    paddq              m%1, m%3
    paddq              m%2, m%3
    psrlq              m%1, 12
    psrlq              m%2, 12
    paddq              m%1, m%2
%endmacro

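; Each function below takes (src, src_stride, dst, dst_stride, scale,
; scale_stride), per the argument lists in the cglobal declarations, and
; returns the accumulated weighted SSE in rax. A rough sketch of a matching
; Rust-side extern declaration (names and exact integer types here are
; assumptions, not taken from this file):
;
;   extern { fn weighted_sse_4x4(
;       src: *const u8, src_stride: isize,
;       dst: *const u8, dst_stride: isize,
;       scale: *const u32, scale_stride: isize) -> u64; }
;
; The *_hbd variant at the end of the file takes 16-bit src/dst samples.
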
INIT_XMM ssse3
; Use scale_stride's register to store src_stride3
cglobal weighted_sse_4x4, 6, 7, 5, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
    lea       src_stride3q, [src_strideq*3]
    lea       dst_stride3q, [dst_strideq*3]
    movq                m0, [addsub]
    movd                m1, [srcq]
    movd                m2, [dstq]
    punpcklbw           m1, m2
    movd                m2, [srcq+src_strideq]
    movd                m3, [dstq+dst_strideq]
    punpcklbw           m2, m3
    pmaddubsw           m1, m0
    pmaddubsw           m2, m0
    pmaddwd             m1, m1
    pmaddwd             m2, m2
    paddd               m1, m2
    movd                m2, [srcq+src_strideq*2]
    movd                m3, [dstq+dst_strideq*2]
    punpcklbw           m2, m3
    movd                m3, [srcq+src_stride3q]
    movd                m4, [dstq+dst_stride3q]
    punpcklbw           m3, m4
    pmaddubsw           m2, m0
    pmaddubsw           m3, m0
    pmaddwd             m2, m2
    pmaddwd             m3, m3
    paddd               m2, m3
    paddd               m1, m2

    pshuflw             m0, m1, q3232
    paddd               m0, m1
    movd               eax, m0

    ; Multiply and shift using scalar code.
    SSE_SCALE_4X4
    RET

%macro WEIGHTED_SSE_4X8_KERNEL 0
    movd                m1, [srcq]
    movd                m2, [srcq+src_strideq*4]
    punpckldq           m1, m2
    movd                m2, [dstq]
    movd                m3, [dstq+dst_strideq*4]
    add               srcq, src_strideq
    add               dstq, dst_strideq
    punpckldq           m2, m3
    punpcklbw           m1, m2
    movd                m2, [srcq]
    movd                m3, [srcq+src_strideq*4]
    punpckldq           m2, m3
    movd                m3, [dstq]
    movd                m4, [dstq+dst_strideq*4]
    add               srcq, src_strideq
    add               dstq, dst_strideq
    punpckldq           m3, m4
    punpcklbw           m2, m3
    pmaddubsw           m1, m0
    pmaddubsw           m2, m0
    pmaddwd             m1, m1
    pmaddwd             m2, m2
    paddd               m1, m2
    movd                m2, [srcq]
    movd                m3, [srcq+src_strideq*4]
    punpckldq           m2, m3
    movd                m3, [dstq]
    movd                m4, [dstq+dst_strideq*4]
    add               srcq, src_strideq
    add               dstq, dst_strideq
    punpckldq           m3, m4
    punpcklbw           m2, m3
    movd                m3, [srcq]
    movd                m4, [srcq+src_strideq*4]
    punpckldq           m3, m4
    movd                m4, [dstq]
    movd                m5, [dstq+dst_strideq*4]
    punpckldq           m4, m5
    punpcklbw           m3, m4
    pmaddubsw           m2, m0
    pmaddubsw           m3, m0
    pmaddwd             m2, m2
    pmaddwd             m3, m3
    paddd               m2, m3
    paddd               m1, m2

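    ; Dwords 0-1 of m1 now hold the top 4x4 block's sums and dwords 2-3 the
    ; bottom block's, matching the two scales loaded by LOAD_SCALES_4X8.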
    %define LOAD_SCALES LOAD_SCALES_4X8
    SSE_SCALE 1, 2, 3
%endmacro

INIT_XMM ssse3
cglobal weighted_sse_4x8, 6, 6, 6, \
        src, src_stride, dst, dst_stride, scale, scale_stride
    mova                m0, [addsub]
    WEIGHTED_SSE_4X8_KERNEL

    pshufd              m0, m1, q3232
    paddq               m1, m0
    movq               rax, m1
    RET

INIT_XMM ssse3
cglobal weighted_sse_4x16, 6, 6, 7, \
        src, src_stride, dst, dst_stride, scale, scale_stride
    mova                m0, [addsub]

    WEIGHTED_SSE_4X8_KERNEL
    ; Swap so this call's result lives in m6, out of the way of the next
    ; kernel call (which overwrites m1-m5).
    SWAP 1, 6

    lea             scaleq, [scaleq+scale_strideq*2]
    ; Already incremented by stride 3 times, but must go up 5 more to get to 8
    add               srcq, src_strideq
    add               dstq, dst_strideq
    lea               srcq, [srcq+src_strideq*4]
    lea               dstq, [dstq+dst_strideq*4]
    WEIGHTED_SSE_4X8_KERNEL
    paddq               m6, m1

    pshufd              m0, m6, q3232
    paddq               m6, m0
    movq               rax, m6
    RET

%macro WEIGHTED_SSE_8X4_KERNEL 0
    movq                m1, [srcq]
    movq                m2, [dstq]
    punpcklbw           m1, m2
    movq                m2, [srcq+src_strideq]
    movq                m3, [dstq+dst_strideq]
    punpcklbw           m2, m3
    pmaddubsw           m1, m0
    pmaddubsw           m2, m0
    pmaddwd             m1, m1
    pmaddwd             m2, m2
    paddd               m1, m2
    movq                m2, [srcq+src_strideq*2]
    movq                m3, [dstq+dst_strideq*2]
    punpcklbw           m2, m3
    movq                m3, [srcq+src_stride3q]
    movq                m4, [dstq+dst_stride3q]
    punpcklbw           m3, m4
    pmaddubsw           m2, m0
    pmaddubsw           m3, m0
    pmaddwd             m2, m2
    pmaddwd             m3, m3
    paddd               m2, m3
    paddd               m1, m2

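    ; Dwords 0-1 of m1 hold the left 4x4 block's sums and dwords 2-3 the right
    ; block's, matching the two consecutive scales loaded by LOAD_SCALES_8X4.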
    %define LOAD_SCALES LOAD_SCALES_8X4
    SSE_SCALE 1, 2
%endmacro

%macro WEIGHTED_SSE_16X4_KERNEL 0
    pmovzxbw            m0, [srcq]
    pmovzxbw            m1, [dstq]
    psubw               m0, m1
    pmaddwd             m0, m0
    pmovzxbw            m1, [srcq+src_strideq]
    pmovzxbw            m2, [dstq+dst_strideq]
    psubw               m1, m2
    pmaddwd             m1, m1
    paddd               m0, m1
    pmovzxbw            m1, [srcq+src_strideq*2]
    pmovzxbw            m2, [dstq+dst_strideq*2]
    psubw               m1, m2
    pmaddwd             m1, m1
    pmovzxbw            m2, [srcq+src_stride3q]
    pmovzxbw            m3, [dstq+dst_stride3q]
    psubw               m2, m3
    pmaddwd             m2, m2
    paddd               m1, m2
    paddd               m1, m0

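    ; Each pair of dwords in m1 corresponds to one 4x4 block across the row,
    ; matching the packed dword scales widened by LOAD_SCALES_16X4.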
    %define LOAD_SCALES LOAD_SCALES_16X4
    SSE_SCALE 1, 2
%endmacro

%macro WEIGHTED_SSE_32X4_KERNEL 0
    ; Unpacking high and low results in sums that are 8 samples apart. To
    ; correctly apply weights, two separate registers are needed to accumulate.
    mova                m2, [srcq]
    mova                m3, [dstq]
    punpcklbw           m1, m2, m3
    punpckhbw           m2, m3
    mova                m4, [srcq+src_strideq]
    mova                m5, [dstq+dst_strideq]
    punpcklbw           m3, m4, m5
    punpckhbw           m4, m5
    pmaddubsw           m1, m0
    pmaddubsw           m2, m0
    pmaddubsw           m3, m0
    pmaddubsw           m4, m0
    pmaddwd             m1, m1
    pmaddwd             m2, m2
    pmaddwd             m3, m3
    pmaddwd             m4, m4
    ; Accumulate
    paddd               m1, m3
    paddd               m2, m4
    mova                m4, [srcq+src_strideq*2]
    mova                m5, [dstq+dst_strideq*2]
    punpcklbw           m3, m4, m5
    punpckhbw           m4, m5
    mova                m6, [srcq+src_stride3q]
    mova                m7, [dstq+dst_stride3q]
    punpcklbw           m5, m6, m7
    punpckhbw           m6, m7
    pmaddubsw           m3, m0
    pmaddubsw           m4, m0
    pmaddubsw           m5, m0
    pmaddubsw           m6, m0
    pmaddwd             m3, m3
    pmaddwd             m4, m4
    pmaddwd             m5, m5
    pmaddwd             m6, m6
    paddd               m3, m5
    paddd               m4, m6
    paddd               m1, m3
    paddd               m2, m4

    SSE_SCALE_32X4 1, 2, 3, 4
%endmacro

%macro WEIGHTED_SSE 2 ; w, h
%if %1 == 8
%if %2 == 4
; Use scale_stride's register to store src_stride3
cglobal weighted_sse_%1x%2, 6, 7, 5, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
%else
cglobal weighted_sse_%1x%2, 6, 9, 6, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%endif
%elif %1 == 16
%if %2 == 4
; Use scale_stride's register to store src_stride3
cglobal weighted_sse_%1x%2, 6, 7, 4, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
%else
cglobal weighted_sse_%1x%2, 6, 9, 5, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%endif
%elif %1 == 32
cglobal weighted_sse_%1x%2, 6, 9, 9, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%else ; > 32
cglobal weighted_sse_%1x%2, 6, 10, 9, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h, w
%endif
; === Setup ===
; kernel_width/kernel_height: width and height of the block that one kernel
;     call processes.
; m0: except when w == 16, m0 holds a constant 1, -1, ... vector used for
;     diffing the two sources.
; sum: the kernel leaves its result in m1. The last vector register is used
;      as the accumulator unless only one kernel iteration is done.
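;
; For example, weighted_sse_64x64 clamps kernel_width to 32, so each row of
; kernel calls runs the horizontal loop twice and the vertical loop 16 times.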

; Default the kernel width to the width of this function.
%define kernel_width %1
%define kernel_height 4
%if %1 == 8
    mova                m0, [addsub]
%endif

%if %1 >= 32
    mova                m0, [addsub]
    ; Iterate multiple times when w > 32.
    %define kernel_width 32
%endif

%if %1 > kernel_width || %2 > kernel_height
    ; Add onto the last used vector register.
    %assign sum xmm_regs_used-1
%else
    ; Use the result from the kernel
    %define sum 1
%endif

    lea       src_stride3q, [src_strideq*3]
    lea       dst_stride3q, [dst_strideq*3]
%if %1 > kernel_width || %2 > kernel_height
    pxor           m%[sum], m%[sum]
%endif
%if %2 > kernel_height
    mov                 hd, %2/kernel_height-1
.loop:
%endif

%if %1 > kernel_width
    mov                 wd, %1/kernel_width-1
.loop_horiz:
%endif

    WEIGHTED_SSE_%[kernel_width]X%[kernel_height]_KERNEL
%if %2 > kernel_height || %1 > kernel_width
    paddq          m%[sum], m1
%endif

%if %1 > kernel_width
    add             scaleq, kernel_width*4/4
    add               srcq, kernel_width
    add               dstq, kernel_width
    dec                 wq
    jge .loop_horiz
%endif

%if %2 > kernel_height
    ; Move down 4 rows.
%if %1 > kernel_width
    ; src/dst is incremented by width when processing multi iteration rows.
    ; Reduce the offset by the width of the row.
    lea               srcq, [srcq+src_strideq*4 - %1]
    lea               dstq, [dstq+dst_strideq*4 - %1]
    ; The behaviour for scale is similar
    lea             scaleq, [scaleq+scale_strideq - %1*4/4]
%else
    lea               srcq, [srcq+src_strideq*4]
    lea               dstq, [dstq+dst_strideq*4]
    add             scaleq, scale_strideq
%endif
    dec                 hq
    jge .loop
%endif

%if mmsize == 16
    pshufd              m2, m%[sum], q3232
    paddq          m%[sum], m2
    movq               rax, m%[sum]
%elif mmsize == 32
    vextracti128       xm2, m%[sum], 1
    paddq         xm%[sum], xm2
    pshufd             xm2, xm%[sum], q3232
    paddq         xm%[sum], xm2
    movq               rax, xm%[sum]
%endif
    RET

    %undef sum, kernel_width, res
%endmacro

INIT_XMM ssse3
WEIGHTED_SSE 8, 4
%if ARCH_X86_64
WEIGHTED_SSE 8, 8
WEIGHTED_SSE 8, 16
WEIGHTED_SSE 8, 32
%endif ; ARCH_X86_64

INIT_YMM avx2
WEIGHTED_SSE 16, 4
%if ARCH_X86_64
WEIGHTED_SSE 16, 8
WEIGHTED_SSE 16, 16
WEIGHTED_SSE 16, 32
WEIGHTED_SSE 16, 64

WEIGHTED_SSE 32, 8
WEIGHTED_SSE 32, 16
WEIGHTED_SSE 32, 32
WEIGHTED_SSE 32, 64

WEIGHTED_SSE 64, 16
WEIGHTED_SSE 64, 32
WEIGHTED_SSE 64, 64
WEIGHTED_SSE 64, 128

WEIGHTED_SSE 128, 64
WEIGHTED_SSE 128, 128
%endif ; ARCH_X86_64

INIT_XMM sse2

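; High bit-depth variant: samples are 16-bit, so rows are diffed directly with
; psubw instead of the pmaddubsw/addsub trick used for 8-bit input.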
cglobal weighted_sse_4x4_hbd, 6, 8, 4, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3
    lea       src_stride3q, [src_strideq*3]
    lea       dst_stride3q, [dst_strideq*3]
    movq                m0, [srcq]
    movq                m1, [dstq]
    psubw               m0, m1
    pmaddwd             m0, m0
    movq                m1, [srcq+src_strideq]
    movq                m2, [dstq+dst_strideq]
    psubw               m1, m2
    pmaddwd             m1, m1
    paddd               m0, m1
    movq                m1, [srcq+src_strideq*2]
    movq                m2, [dstq+dst_strideq*2]
    psubw               m1, m2
    pmaddwd             m1, m1
    movq                m2, [srcq+src_stride3q]
    movq                m3, [dstq+dst_stride3q]
    psubw               m2, m3
    pmaddwd             m2, m2
    paddd               m1, m2
    paddd               m0, m1

    pshuflw             m1, m0, q3232
    paddd               m0, m1
    movd               eax, m0

    ; Multiply and shift using scalar code.
    SSE_SCALE_4X4
    RET