; Copyright (c) 2020, The rav1e contributors. All rights reserved
;
; This source code is subject to the terms of the BSD 2 Clause License and
; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
; was not distributed with this source code in the LICENSE file, you can
; obtain it at www.aomedia.org/license/software. If the Alliance for Open
; Media Patent License 1.0 was not distributed with this source code in the
; PATENTS file, you can obtain it at www.aomedia.org/license/patent.

%include "config.asm"
%include "ext/x86/x86inc.asm"

SECTION_RODATA 32
addsub: times 16 db 1, -1
rounding: times 4 dq 0x800

SECTION .text

%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX)

; Consolidate scaling and rounding in one place so that it is easier to change.

%macro SSE_SCALE_4X4 0
    ; Multiply and shift using scalar code.
    mov scaled, [scaleq]
    imul rax, scaleq
    add rax, 0x800
    shr rax, 12
%endmacro

; 1 is the input and output register.
; 2-3 are tmp registers.
%macro SSE_SCALE 2-3
    ; Reduce 32-bit sums to 64-bit sums.
    pshufd m%2, m%1, q3311
    paddd m%1, m%2

    LOAD_SCALES %2, %3

    ; Multiply and shift with rounding.
    pmuludq m%1, m%2
    ; TODO: Alter the Rust source so that rounding can always be done at the
    ; end (i.e. only do it once).
    mova m%2, [rounding]
    paddq m%1, m%2
    psrlq m%1, 12
%endmacro

%macro LOAD_SCALES_4X8 2
    ; Load 1 scale from each of the 2 rows.
    movd m%1, [scaleq]
    movd m%2, [scaleq+scale_strideq]
    ; 64-bit unpack since our loads have only one value each.
    punpcklqdq m%1, m%2
%endmacro

; 2 is unused
%macro LOAD_SCALES_8X4 2
    ; Convert to 64-bits.
    ; It doesn't matter that the upper halves are full of garbage.
    movq m%1, [scaleq]
    pshufd m%1, m%1, q1100
%endmacro

; 2 is unused
%macro LOAD_SCALES_16X4 2
    pmovzxdq m%1, [scaleq]
%endmacro

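; Summary of the weighting applied by the scale macros: the sum of squared
; differences for each 4x4 block is multiplied by a 32-bit scale loaded from
; the scale buffer (one u32 per 4x4 block), then rounded and shifted:
;     weighted_sse = (sse * scale + 0x800) >> 12
; so the scale is presumably a Q12 fixed-point weight. For example, a scale of
; 0x1000 leaves the block's SSE unchanged, while a scale of 0x800 roughly
; halves it.
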
; Separate from other scale macros, since it uses 2 inputs.
; 1-2 are input regs and 1 is the output reg.
; 3-4 are tmp registers.
%macro SSE_SCALE_32X4 4
    pshufd m%3, m%1, q3311
    paddd m%1, m%3
    pshufd m%3, m%2, q3311
    paddd m%2, m%3

    ; Load scale for 4x4 blocks and convert to 64-bits.
    ; It doesn't matter if the upper halves are full of garbage.
    ; raw load:    0, 1, 2, 3 | 4, 5, 6, 7
    ; unpack low:  0, 1       | 4, 5
    ; unpack high: 2, 3       | 6, 7
    mova m%4, [scaleq]
    punpckldq m%3, m%4, m%4
    punpckhdq m%4, m%4

    pmuludq m%1, m%3
    pmuludq m%2, m%4
    mova m%3, [rounding]
    paddq m%1, m%3
    paddq m%2, m%3
    psrlq m%1, 12
    psrlq m%2, 12
    paddq m%1, m%2
%endmacro

INIT_XMM ssse3
; Use scale_stride's register to store src_stride3.
cglobal weighted_sse_4x4, 6, 7, 5, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
    lea src_stride3q, [src_strideq*3]
    lea dst_stride3q, [dst_strideq*3]
    movq m0, [addsub]
    movd m1, [srcq]
    movd m2, [dstq]
    punpcklbw m1, m2
    movd m2, [srcq+src_strideq]
    movd m3, [dstq+dst_strideq]
    punpcklbw m2, m3
    pmaddubsw m1, m0
    pmaddubsw m2, m0
    pmaddwd m1, m1
    pmaddwd m2, m2
    paddd m1, m2
    movd m2, [srcq+src_strideq*2]
    movd m3, [dstq+dst_strideq*2]
    punpcklbw m2, m3
    movd m3, [srcq+src_stride3q]
    movd m4, [dstq+dst_stride3q]
    punpcklbw m3, m4
    pmaddubsw m2, m0
    pmaddubsw m3, m0
    pmaddwd m2, m2
    pmaddwd m3, m3
    paddd m2, m3
    paddd m1, m2

    pshuflw m0, m1, q3232
    paddd m0, m1
    movd eax, m0

    ; Multiply and shift using scalar code.
    SSE_SCALE_4X4
    RET

%macro WEIGHTED_SSE_4X8_KERNEL 0
    movd m1, [srcq]
    movd m2, [srcq+src_strideq*4]
    punpckldq m1, m2
    movd m2, [dstq]
    movd m3, [dstq+dst_strideq*4]
    add srcq, src_strideq
    add dstq, dst_strideq
    punpckldq m2, m3
    punpcklbw m1, m2
    movd m2, [srcq]
    movd m3, [srcq+src_strideq*4]
    punpckldq m2, m3
    movd m3, [dstq]
    movd m4, [dstq+dst_strideq*4]
    add srcq, src_strideq
    add dstq, dst_strideq
    punpckldq m3, m4
    punpcklbw m2, m3
    pmaddubsw m1, m0
    pmaddubsw m2, m0
    pmaddwd m1, m1
    pmaddwd m2, m2
    paddd m1, m2
    movd m2, [srcq]
    movd m3, [srcq+src_strideq*4]
    punpckldq m2, m3
    movd m3, [dstq]
    movd m4, [dstq+dst_strideq*4]
    add srcq, src_strideq
    add dstq, dst_strideq
    punpckldq m3, m4
    punpcklbw m2, m3
    movd m3, [srcq]
    movd m4, [srcq+src_strideq*4]
    punpckldq m3, m4
    movd m4, [dstq]
    movd m5, [dstq+dst_strideq*4]
    punpckldq m4, m5
    punpcklbw m3, m4
    pmaddubsw m2, m0
    pmaddubsw m3, m0
    pmaddwd m2, m2
    pmaddwd m3, m3
    paddd m2, m3
    paddd m1, m2

    %define LOAD_SCALES LOAD_SCALES_4X8
    SSE_SCALE 1, 2, 3
%endmacro

INIT_XMM ssse3
cglobal weighted_sse_4x8, 6, 6, 6, \
        src, src_stride, dst, dst_stride, scale, scale_stride
    mova m0, [addsub]
    WEIGHTED_SSE_4X8_KERNEL

    pshufd m0, m1, q3232
    paddq m1, m0
    movq rax, m1
    RET

INIT_XMM ssse3
cglobal weighted_sse_4x16, 6, 6, 7, \
        src, src_stride, dst, dst_stride, scale, scale_stride
    mova m0, [addsub]

    WEIGHTED_SSE_4X8_KERNEL
    ; Swap so this use of the kernel keeps its result in m6, freeing m1 for
    ; the next use.
    SWAP 1, 6

    lea scaleq, [scaleq+scale_strideq*2]
    ; src/dst have already been incremented by their strides 3 times, so move
    ; down 5 more rows to reach row 8.
    add srcq, src_strideq
    add dstq, dst_strideq
    lea srcq, [srcq+src_strideq*4]
    lea dstq, [dstq+dst_strideq*4]
    WEIGHTED_SSE_4X8_KERNEL
    paddq m6, m1

    pshufd m0, m6, q3232
    paddq m6, m0
    movq rax, m6
    RET

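; The fixed-width kernels below (and the 4-wide code above) compute per-pixel
; differences with pmaddubsw: src and dst bytes are interleaved with
; punpcklbw/punpckhbw and multiplied against the alternating 1, -1 pattern in
; [addsub], yielding signed 16-bit (src - dst) values. pmaddwd then squares
; each difference and horizontally adds adjacent pairs into 32-bit sums.
; (The 16-wide kernel instead widens with pmovzxbw and subtracts directly.)
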
%macro WEIGHTED_SSE_8X4_KERNEL 0
    movq m1, [srcq]
    movq m2, [dstq]
    punpcklbw m1, m2
    movq m2, [srcq+src_strideq]
    movq m3, [dstq+dst_strideq]
    punpcklbw m2, m3
    pmaddubsw m1, m0
    pmaddubsw m2, m0
    pmaddwd m1, m1
    pmaddwd m2, m2
    paddd m1, m2
    movq m2, [srcq+src_strideq*2]
    movq m3, [dstq+dst_strideq*2]
    punpcklbw m2, m3
    movq m3, [srcq+src_stride3q]
    movq m4, [dstq+dst_stride3q]
    punpcklbw m3, m4
    pmaddubsw m2, m0
    pmaddubsw m3, m0
    pmaddwd m2, m2
    pmaddwd m3, m3
    paddd m2, m3
    paddd m1, m2

    %define LOAD_SCALES LOAD_SCALES_8X4
    SSE_SCALE 1, 2
%endmacro

%macro WEIGHTED_SSE_16X4_KERNEL 0
    pmovzxbw m0, [srcq]
    pmovzxbw m1, [dstq]
    psubw m0, m1
    pmaddwd m0, m0
    pmovzxbw m1, [srcq+src_strideq]
    pmovzxbw m2, [dstq+dst_strideq]
    psubw m1, m2
    pmaddwd m1, m1
    paddd m0, m1
    pmovzxbw m1, [srcq+src_strideq*2]
    pmovzxbw m2, [dstq+dst_strideq*2]
    psubw m1, m2
    pmaddwd m1, m1
    pmovzxbw m2, [srcq+src_stride3q]
    pmovzxbw m3, [dstq+dst_stride3q]
    psubw m2, m3
    pmaddwd m2, m2
    paddd m1, m2
    paddd m1, m0

    %define LOAD_SCALES LOAD_SCALES_16X4
    SSE_SCALE 1, 2
%endmacro

%macro WEIGHTED_SSE_32X4_KERNEL 0
    ; Unpacking high and low results in sums that are 8 samples apart. To
    ; correctly apply weights, two separate registers are needed to accumulate.
    mova m2, [srcq]
    mova m3, [dstq]
    punpcklbw m1, m2, m3
    punpckhbw m2, m3
    mova m4, [srcq+src_strideq]
    mova m5, [dstq+dst_strideq]
    punpcklbw m3, m4, m5
    punpckhbw m4, m5
    pmaddubsw m1, m0
    pmaddubsw m2, m0
    pmaddubsw m3, m0
    pmaddubsw m4, m0
    pmaddwd m1, m1
    pmaddwd m2, m2
    pmaddwd m3, m3
    pmaddwd m4, m4
    ; Accumulate.
    paddd m1, m3
    paddd m2, m4
    mova m4, [srcq+src_strideq*2]
    mova m5, [dstq+dst_strideq*2]
    punpcklbw m3, m4, m5
    punpckhbw m4, m5
    mova m6, [srcq+src_stride3q]
    mova m7, [dstq+dst_stride3q]
    punpcklbw m5, m6, m7
    punpckhbw m6, m7
    pmaddubsw m3, m0
    pmaddubsw m4, m0
    pmaddubsw m5, m0
    pmaddubsw m6, m0
    pmaddwd m3, m3
    pmaddwd m4, m4
    pmaddwd m5, m5
    pmaddwd m6, m6
    paddd m3, m5
    paddd m4, m6
    paddd m1, m3
    paddd m2, m4

    SSE_SCALE_32X4 1, 2, 3, 4
%endmacro

%macro WEIGHTED_SSE 2 ; w, h
%if %1 == 8
%if %2 == 4
; Use scale_stride's register to store src_stride3.
cglobal weighted_sse_%1x%2, 6, 7, 5, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
%else
cglobal weighted_sse_%1x%2, 6, 9, 6, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%endif
%elif %1 == 16
%if %2 == 4
; Use scale_stride's register to store src_stride3.
cglobal weighted_sse_%1x%2, 6, 7, 4, \
        src, src_stride, dst, dst_stride, scale, \
        src_stride3, dst_stride3
%else
cglobal weighted_sse_%1x%2, 6, 9, 5, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%endif
%elif %1 == 32
cglobal weighted_sse_%1x%2, 6, 9, 9, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h
%else ; > 32
cglobal weighted_sse_%1x%2, 6, 10, 9, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3, h, w
%endif
; === Setup ===
; kernel_width/kernel_height: dimensions of the block processed by one
;     invocation of the kernel macro.
; m0: except for when w == 16, m0 holds the constant 1, -1, ... vector used to
;     diff the two sources.
; sum: the kernel leaves its result in m1. The last vector register is used as
;     the accumulator unless only one kernel invocation is done.
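; Loop structure: blocks wider than the kernel width iterate .loop_horiz over
; 32-column chunks, advancing src/dst by 32 pixels and scale by one u32 per 4
; columns; blocks taller than 4 iterate .loop over 4-row groups, stepping
; src/dst by 4 rows and scale by scale_stride bytes. Partial results are
; accumulated as 64-bit lanes in the sum register and reduced to a single
; 64-bit value in rax at the end.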

; Default the kernel width to the width of this function.
%define kernel_width %1
%define kernel_height 4
%if %1 == 8
    mova m0, [addsub]
%endif

%if %1 >= 32
    mova m0, [addsub]
    ; Iterate multiple times when w > 32.
    %define kernel_width 32
%endif

%if %1 > kernel_width || %2 > kernel_height
    ; Add onto the last used vector register.
    %assign sum xmm_regs_used-1
%else
    ; Use the result from the kernel.
    %define sum 1
%endif

    lea src_stride3q, [src_strideq*3]
    lea dst_stride3q, [dst_strideq*3]
%if %1 > kernel_width || %2 > kernel_height
    pxor m%[sum], m%[sum]
%endif
%if %2 > kernel_height
    mov hd, %2/kernel_height-1
.loop:
%endif

%if %1 > kernel_width
    mov wd, %1/kernel_width-1
.loop_horiz:
%endif

    WEIGHTED_SSE_%[kernel_width]X%[kernel_height]_KERNEL
%if %2 > kernel_height || %1 > kernel_width
    paddq m%[sum], m1
%endif

%if %1 > kernel_width
    add scaleq, kernel_width*4/4
    add srcq, kernel_width
    add dstq, kernel_width
    dec wq
    jge .loop_horiz
%endif

%if %2 > kernel_height
    ; Move down 4 rows.
%if %1 > kernel_width
    ; src/dst are incremented by the width when processing multi-iteration
    ; rows. Reduce the offset by the width of the row.
    lea srcq, [srcq+src_strideq*4 - %1]
    lea dstq, [dstq+dst_strideq*4 - %1]
    ; The behaviour for scale is similar.
    lea scaleq, [scaleq+scale_strideq - %1*4/4]
%else
    lea srcq, [srcq+src_strideq*4]
    lea dstq, [dstq+dst_strideq*4]
    add scaleq, scale_strideq
%endif
    dec hq
    jge .loop
%endif

%if mmsize == 16
    pshufd m2, m%[sum], q3232
    paddq m%[sum], m2
    movq rax, m%[sum]
%elif mmsize == 32
    vextracti128 xm2, m%[sum], 1
    paddq xm%[sum], xm2
    pshufd xm2, xm%[sum], q3232
    paddq xm%[sum], xm2
    movq rax, xm%[sum]
%endif
    RET

    %undef sum, kernel_width, res
%endmacro

INIT_XMM ssse3
WEIGHTED_SSE 8, 4
%if ARCH_X86_64
WEIGHTED_SSE 8, 8
WEIGHTED_SSE 8, 16
WEIGHTED_SSE 8, 32
%endif ; ARCH_X86_64

INIT_YMM avx2
WEIGHTED_SSE 16, 4
%if ARCH_X86_64
WEIGHTED_SSE 16, 8
WEIGHTED_SSE 16, 16
WEIGHTED_SSE 16, 32
WEIGHTED_SSE 16, 64

WEIGHTED_SSE 32, 8
WEIGHTED_SSE 32, 16
WEIGHTED_SSE 32, 32
WEIGHTED_SSE 32, 64

WEIGHTED_SSE 64, 16
WEIGHTED_SSE 64, 32
WEIGHTED_SSE 64, 64
WEIGHTED_SSE 64, 128

WEIGHTED_SSE 128, 64
WEIGHTED_SSE 128, 128
%endif ; ARCH_X86_64

INIT_XMM sse2
cglobal weighted_sse_4x4_hbd, 6, 8, 4, \
        src, src_stride, dst, dst_stride, scale, scale_stride, \
        src_stride3, dst_stride3
    lea src_stride3q, [src_strideq*3]
    lea dst_stride3q, [dst_strideq*3]
    movq m0, [srcq]
    movq m1, [dstq]
    psubw m0, m1
    pmaddwd m0, m0
    movq m1, [srcq+src_strideq]
    movq m2, [dstq+dst_strideq]
    psubw m1, m2
    pmaddwd m1, m1
    paddd m0, m1
    movq m1, [srcq+src_strideq*2]
    movq m2, [dstq+dst_strideq*2]
    psubw m1, m2
    pmaddwd m1, m1
    movq m2, [srcq+src_stride3q]
    movq m3, [dstq+dst_stride3q]
    psubw m2, m3
    pmaddwd m2, m2
    paddd m1, m2
    paddd m0, m1

    pshuflw m1, m0, q3232
    paddd m0, m1
    movd eax, m0

    ; Multiply and shift using scalar code.
    SSE_SCALE_4X4
    RET
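
; A rough sketch of the Rust-side extern declaration these kernels are assumed
; to match (names and exact types here are assumptions, not taken from this
; file). cglobal mangles each symbol as private_prefix + name + SUFFIX, e.g.
; rav1e_weighted_sse_4x4_ssse3 if private_prefix is rav1e:
;
;   extern "C" {
;       fn rav1e_weighted_sse_4x4_ssse3(
;           src: *const u8, src_stride: isize,
;           dst: *const u8, dst_stride: isize,
;           scale: *const u32, scale_stride: isize,
;       ) -> u64;
;   }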