1/*
2 * Copyright © 2019, VideoLAN and dav1d authors
3 * Copyright © 2019, Martin Storsjo
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are met:
8 *
9 * 1. Redistributions of source code must retain the above copyright notice, this
10 *    list of conditions and the following disclaimer.
11 *
12 * 2. Redistributions in binary form must reproduce the above copyright notice,
13 *    this list of conditions and the following disclaimer in the documentation
14 *    and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "src/arm/asm.S"
29#include "util.S"
30
// Byte offsets of the fields of the decoder context passed in x0
// (`MsacContext *s` in the prototype comment further down), as they are
// accessed by the code in this file:
//   BUF_POS, BUF_END  - two adjacent 64-bit pointers, fetched with one ldp
//   DIF               - 64-bit window; its top halfword is read at DIF + 6
//   RNG, CNT          - two adjacent 32-bit fields, fetched with one ldp
//   ALLOW_UPDATE_CDF  - 32-bit flag, tested against zero
// NOTE(review): these have to match the C layout of the context structure,
// which is not visible from this file - confirm against its definition.
31#define BUF_POS 0
32#define BUF_END 8
33#define DIF 16
34#define RNG 24
35#define CNT 28
36#define ALLOW_UPDATE_CDF 32
37
// Per-lane additive term for the split points. The decoders read this table
// starting at byte offset 30 - 2 * n_symbols, so lane i receives
// 4 * (n_symbols - i) for i <= n_symbols and 0 for the lanes after that (the
// code comments call this term EC_MIN_PROB * (n_symbols - ret)).
// The second row of zeros keeps a 16-lane load that starts anywhere in the
// first row inside the table.
38const coeffs
39        .short 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
40        .short 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0, 0
41endconst
42
// One distinct bit per halfword lane (lane i holds 1 << i). ANDing a per-lane
// all-ones / all-zeros compare result with this table and summing the lanes
// yields a scalar bitmask in which bit i is set iff lane i compared true.
43const bits
44        .short   0x1,   0x2,   0x4,   0x8,   0x10,   0x20,   0x40,   0x80
45        .short 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000
46endconst
47
// ld1_n d0, d1, src, sz, n
// Vector load from [src]. For n above 8 the register pair {d0, d1} is
// loaded with a single ld1 (the pair form requires consecutively numbered
// registers); otherwise only d0 is loaded and d1 is left untouched.
// sz is the arrangement suffix pasted onto the register names.
48.macro ld1_n d0, d1, src, sz, n
49.if \n > 8
50        ld1             {\d0\sz, \d1\sz},  [\src]
51.else
52        ld1             {\d0\sz},  [\src]
53.endif
54.endm
55
// st1_n s0, s1, dst, sz, n
// Vector store to [dst], the counterpart of ld1_n. For n above 8 the
// register pair {s0, s1} is stored with a single st1 (consecutively numbered
// registers required); otherwise only s0 is stored.
// sz is the arrangement suffix pasted onto the register names.
56.macro st1_n s0, s1, dst, sz, n
57.if \n > 8
58        st1             {\s0\sz, \s1\sz},  [\dst]
59.else
60        st1             {\s0\sz},  [\dst]
61.endif
62.endm
63
// ushr_n d0, d1, s0, s1, shift, sz, n
// Unsigned shift right by an immediate: d0 = s0 >> shift per lane, and, only
// when n == 16, also d1 = s1 >> shift. sz is the arrangement suffix pasted
// onto every register name.
// NOTE(review): no invocation of this macro is visible in this file.
64.macro ushr_n d0, d1, s0, s1, shift, sz, n
65        ushr            \d0\sz,  \s0\sz,  \shift
66.if \n == 16
67        ushr            \d1\sz,  \s1\sz,  \shift
68.endif
69.endm
70
// add_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane add: d0 = s0 + s2, and, only when n == 16, also d1 = s1 + s3.
// sz is the arrangement suffix pasted onto every register name. The d0
// instruction is always emitted first.
71.macro add_n d0, d1, s0, s1, s2, s3, sz, n
72        add             \d0\sz,  \s0\sz,  \s2\sz
73.if \n == 16
74        add             \d1\sz,  \s1\sz,  \s3\sz
75.endif
76.endm
77
// sub_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane subtract: d0 = s0 - s2, and, only when n == 16, also
// d1 = s1 - s3. sz is the arrangement suffix pasted onto every register
// name. The d0 instruction is always emitted first.
78.macro sub_n d0, d1, s0, s1, s2, s3, sz, n
79        sub             \d0\sz,  \s0\sz,  \s2\sz
80.if \n == 16
81        sub             \d1\sz,  \s1\sz,  \s3\sz
82.endif
83.endm
84
// and_n d0, d1, s0, s1, s2, s3, sz, n
// Bitwise and: d0 = s0 & s2, and, only when n == 16, also d1 = s1 & s3.
// sz is pasted onto every register name; the callers in this file pass a
// byte arrangement (.8b / .16b) here. The d0 instruction is always emitted
// first.
85.macro and_n d0, d1, s0, s1, s2, s3, sz, n
86        and             \d0\sz,  \s0\sz,  \s2\sz
87.if \n == 16
88        and             \d1\sz,  \s1\sz,  \s3\sz
89.endif
90.endm
91
// cmhs_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane unsigned compare: d0 = (s0 >= s2) ? all ones : 0, and, only when
// n == 16, also d1 = (s1 >= s3) ? all ones : 0. sz is the arrangement suffix
// pasted onto every register name. The d0 instruction is always emitted
// first.
92.macro cmhs_n d0, d1, s0, s1, s2, s3, sz, n
93        cmhs            \d0\sz,  \s0\sz,  \s2\sz
94.if \n == 16
95        cmhs            \d1\sz,  \s1\sz,  \s3\sz
96.endif
97.endm
98
// urhadd_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane unsigned rounding halving add: d0 = (s0 + s2 + 1) >> 1, and, only
// when n == 16, also d1 = (s1 + s3 + 1) >> 1. sz is the arrangement suffix
// pasted onto every register name.
// The emission order matters: decode_update passes the same register (v5) as
// d1, s0 and s1, which only works because the d0 instruction reads s0 before
// d1 is overwritten.
99.macro urhadd_n d0, d1, s0, s1, s2, s3, sz, n
100        urhadd          \d0\sz,  \s0\sz,  \s2\sz
101.if \n == 16
102        urhadd          \d1\sz,  \s1\sz,  \s3\sz
103.endif
104.endm
105
// sshl_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane signed shift by a register amount: d0 = s0 shifted by s2, and,
// only when n == 16, also d1 = s1 shifted by s3. With sshl a negative amount
// shifts right; decode_update passes a vector holding -rate here. sz is the
// arrangement suffix pasted onto every register name. The d0 instruction is
// always emitted first.
106.macro sshl_n d0, d1, s0, s1, s2, s3, sz, n
107        sshl            \d0\sz,  \s0\sz,  \s2\sz
108.if \n == 16
109        sshl            \d1\sz,  \s1\sz,  \s3\sz
110.endif
111.endm
112
// sqdmulh_n d0, d1, s0, s1, s2, s3, sz, n
// Per-lane signed saturating doubling multiply, high half:
// d0 = sqdmulh(s0, s2), and, only when n == 16, also d1 = sqdmulh(s1, s3).
// sz is the arrangement suffix pasted onto every register name.
// The emission order matters: decode_update passes the same register (v7) as
// d1, s2 and s3, which only works because the d0 instruction reads s2 before
// d1 is overwritten.
113.macro sqdmulh_n d0, d1, s0, s1, s2, s3, sz, n
114        sqdmulh         \d0\sz,  \s0\sz,  \s2\sz
115.if \n == 16
116        sqdmulh         \d1\sz,  \s1\sz,  \s3\sz
117.endif
118.endm
119
// str_n idx0, idx1, dstreg, dstoff, n
// Store register idx0 at [dstreg, dstoff] and, only when n == 16, register
// idx1 16 bytes after it. dstoff is pasted textually in front of "+ 16", so
// it has to be an immediate operand; the only invocation in this file passes
// q registers and an offset of 16.
120.macro str_n            idx0, idx1, dstreg, dstoff, n
121        str             \idx0,  [\dstreg, \dstoff]
122.if \n == 16
123        str             \idx1,  [\dstreg, \dstoff + 16]
124.endif
125.endm
126
127// unsigned dav1d_msac_decode_symbol_adapt4_neon(MsacContext *s, uint16_t *cdf,
128//                                               size_t n_symbols);
129
// Adaptive-CDF symbol decoder, 4-lane entry point. Register arguments follow
// the prototype comment above: x0 = s, x1 = cdf, x2 = n_symbols. The decoded
// symbol index is returned in w0.
//
// This function body also hosts two things that the other entry points in
// this file share:
//   - the decode_update macro, expanded here for 4 lanes and again by the
//     adapt8 / adapt16 entry points, which then branch to L(renorm);
//   - the normalisation / refill tail. L(renorm) is entered with w15 = ret
//     and u / v[] stored in the stack frame. L(renorm2) is entered (also by
//     the bool decoders) with w4 = new rng before shifting, w5 = d,
//     w6 = cnt, x7 = ~dif and w15 = return value.
//     Both entry labels expect the 48-byte frame to be allocated already; it
//     is released just before the ret after label 9.
130function msac_decode_symbol_adapt4_neon, export=1
// decode_update sz, szb, n
//   sz  - halfword arrangement used for the per-lane arithmetic (.4h / .8h)
//   szb - matching byte arrangement used for the bitwise ops (.8b / .16b)
//   n   - lane count (4, 8 or 16); 16 makes the _n helpers work on pairs
// Leaves ret in w15 and u / v[] in the stack frame. When the
// ALLOW_UPDATE_CDF field is non-zero it also writes the adapted cdf and its
// counter back through x1; otherwise it branches straight to L(renorm).
131.macro decode_update sz, szb, n
132        sub             sp,  sp,  #48                             // frame: u at sp+14, v[] from sp+16
133        add             x8,  x0,  #RNG
134        ld1_n           v0,  v1,  x1,  \sz, \n                    // cdf
135        ld1r            {v4\sz},  [x8]                            // rng
136        movrel          x9,  coeffs, 30                           // presumably &coeffs + 30 (movrel is defined outside this file)
137        movi            v31\sz, #0x7f, lsl #8                     // 0x7f00
138        sub             x9,  x9,  x2, lsl #1                      // x9 -= 2 * n_symbols
139        mvni            v30\sz, #0x3f                             // 0xffc0
140        and             v7\szb, v4\szb, v31\szb                   // rng & 0x7f00
141        str             h4,  [sp, #14]                            // store original u = s->rng
142        and_n           v2,  v3,  v0,  v1,  v30, v30, \szb, \n    // cdf & 0xffc0
143
144        ld1_n           v4,  v5,  x9,  \sz, \n                    // EC_MIN_PROB * (n_symbols - ret)
145        sqdmulh_n       v6,  v7,  v2,  v3,  v7,  v7,  \sz, \n     // ((cdf >> EC_PROB_SHIFT) * (r - 128)) >> 1
146        add             x8,  x0,  #DIF + 6                        // address of the top halfword of dif (little-endian)
147
148        add_n           v4,  v5,  v2,  v3,  v4,  v5,  \sz, \n     // v = cdf + EC_MIN_PROB * (n_symbols - ret)
149        add_n           v4,  v5,  v6,  v7,  v4,  v5,  \sz, \n     // v = ((cdf >> EC_PROB_SHIFT) * r) >> 1 + EC_MIN_PROB * (n_symbols - ret)
150
151        ld1r            {v6.8h},  [x8]                            // dif >> (EC_WIN_SIZE - 16)
152        movrel          x8,  bits
153        str_n           q4,  q5,  sp, #16, \n                     // store v values to allow indexed access
154
155        ld1_n           v16, v17, x8,  .8h, \n                    // per-lane bit masks from the bits table
156
157        cmhs_n          v2,  v3,  v6,  v6,  v4,  v5,  .8h,  \n    // c >= v
158
159        and_n           v6,  v7,  v2,  v3,  v16, v17, .16b, \n    // One bit per halfword set in the mask
160.if \n == 16
161        add             v6.8h,  v6.8h,  v7.8h                     // merge the masks of both halves
162.endif
163        addv            h6,  v6.8h                                // Aggregate mask bits
164        ldr             w4,  [x0, #ALLOW_UPDATE_CDF]
165        umov            w3,  v6.h[0]
166        rbit            w3,  w3                                   // rbit + clz = index of the lowest set bit
167        clz             w15, w3                                   // ret
168
169        cbz             w4,  L(renorm)                            // adaptation disabled: skip update_cdf
170        // update_cdf
171        ldrh            w3,  [x1, x2, lsl #1]                     // count = cdf[n_symbols]
172        movi            v5\szb, #0xff
173.if \n == 16
174        mov             w4,  #-5
175.else
176        mvn             w14, w2
177        mov             w4,  #-4
178        cmn             w14, #3                                   // set C if n_symbols <= 2
179.endif
        // The vector op below does not write the flags, so the C flag from
        // the cmn above is still intact when the sbc consumes it.
180        urhadd_n        v4,  v5,  v5,  v5,  v2,  v3,  \sz, \n     // i >= val ? -1 : 32768
181.if \n == 16
182        sub             w4,  w4,  w3, lsr #4                      // -((count >> 4) + 5)
183.else
184        lsr             w14, w3,  #4                              // count >> 4
185        sbc             w4,  w4,  w14                             // -((count >> 4) + (n_symbols > 2) + 4)
186.endif
187        sub_n           v4,  v5,  v4,  v5,  v0,  v1,  \sz, \n     // (32768 - cdf[i]) or (-1 - cdf[i])
188        dup             v6\sz,    w4                              // -rate
189
190        sub             w3,  w3,  w3, lsr #5                      // count - (count == 32)
191        sub_n           v0,  v1,  v0,  v1,  v2,  v3,  \sz, \n     // cdf + (i >= val ? 1 : 0)
192        sshl_n          v4,  v5,  v4,  v5,  v6,  v6,  \sz, \n     // ({32768,-1} - cdf[i]) >> rate
193        add             w3,  w3,  #1                              // count + (count < 32)
194        add_n           v0,  v1,  v0,  v1,  v4,  v5,  \sz, \n     // cdf + (32768 - cdf[i]) >> rate
195        st1_n           v0,  v1,  x1,  \sz, \n                    // write back the adapted cdf
196        strh            w3,  [x1, x2, lsl #1]                     // cdf[n_symbols] = count
197.endm
198
199        decode_update   .4h, .8b, 4
200
// Shared tail, part 1: pick u and v for the decoded symbol out of the frame
// and derive the new range and the shift amount d.
201L(renorm):
202        add             x8,  sp,  #16
203        add             x8,  x8,  w15, uxtw #1 // x8 = &v[ret]
204        ldrh            w3,  [x8]              // v
205        ldurh           w4,  [x8, #-2]         // u (v[ret - 1], or the rng stored at sp+14 when ret == 0)
206        ldr             w6,  [x0, #CNT]
207        ldr             x7,  [x0, #DIF]
208        sub             w4,  w4,  w3           // rng = u - v
209        clz             w5,  w4                // clz(rng)
210        eor             w5,  w5,  #16          // d = clz(rng) ^ 16
211        mvn             x7,  x7                // ~dif
212        add             x7,  x7,  x3, lsl #48  // ~dif + (v << 48)
// Shared tail, part 2: shift rng and dif by d, store rng, and refill dif
// from the input buffer when cnt - d borrows.
213L(renorm2):
214        lsl             w4,  w4,  w5           // rng << d
215        subs            w6,  w6,  w5           // cnt -= d
216        lsl             x7,  x7,  x5           // (~dif + (v << 48)) << d
217        str             w4,  [x0, #RNG]
218        mvn             x7,  x7                // ~dif
219        b.hs            9f                     // no borrow from cnt - d: skip the refill
220
221        // refill
222        ldp             x3,  x4,  [x0]         // BUF_POS, BUF_END
223        add             x5,  x3,  #8           // buf_pos + 8
224        cmp             x5,  x4
225        b.gt            2f                     // fewer than 8 bytes left: byte-wise refill
226
227        ldr             x3,  [x3]              // next_bits
228        add             w8,  w6,  #23          // shift_bits = cnt + 23
229        add             w6,  w6,  #16          // cnt += 16
230        rev             x3,  x3                // next_bits = bswap(next_bits)
231        sub             x5,  x5,  x8, lsr #3   // buf_pos -= shift_bits >> 3
232        and             w8,  w8,  #24          // shift_bits &= 24
233        lsr             x3,  x3,  x8           // next_bits >>= shift_bits
234        sub             w8,  w8,  w6           // shift_bits -= 16 + cnt
235        str             x5,  [x0, #BUF_POS]
236        lsl             x3,  x3,  x8           // next_bits <<= shift_bits
237        mov             w4,  #48
238        sub             w6,  w4,  w8           // cnt = cnt + 64 - shift_bits
239        eor             x7,  x7,  x3           // dif ^= next_bits
240        b               9f
241
2422:      // refill_eob
243        mov             w14, #40
244        sub             w5,  w14, w6           // c = 40 - cnt
2453:
246        cmp             x3,  x4
247        b.ge            4f                     // buf_pos >= buf_end: no input left
248        ldrb            w8,  [x3], #1          // next byte, buf_pos++
249        lsl             x8,  x8,  x5           // byte << c
250        eor             x7,  x7,  x8           // dif ^= byte << c
251        subs            w5,  w5,  #8           // c -= 8
252        b.ge            3b                     // loop while c >= 0
253
2544:      // refill_eob_end
255        str             x3,  [x0, #BUF_POS]
256        sub             w6,  w14, w5           // cnt = 40 - c
257
// Common exit: write back cnt and dif, return w15, release the frame.
2589:
259        str             w6,  [x0, #CNT]
260        str             x7,  [x0, #DIF]
261
262        mov             w0,  w15
263        add             sp,  sp,  #48
264        ret
265endfunc
266
// 8-lane entry point: expands decode_update for one 8-halfword register, so
// the register arguments are used exactly as in that macro (x0 = context,
// x1 = cdf, x2 = n_symbols), then finishes in the shared L(renorm) tail of
// msac_decode_symbol_adapt4_neon, which returns the symbol in w0 and releases
// the stack frame allocated by the macro.
267function msac_decode_symbol_adapt8_neon, export=1
268        decode_update   .8h, .16b, 8
269        b               L(renorm)
270endfunc
271
// 16-lane entry point: expands decode_update with n = 16, which makes the _n
// helper macros operate on register pairs. The register arguments are used
// exactly as in that macro (x0 = context, x1 = cdf, x2 = n_symbols). Finishes
// in the shared L(renorm) tail of msac_decode_symbol_adapt4_neon, which
// returns the symbol in w0 and releases the stack frame allocated by the
// macro.
272function msac_decode_symbol_adapt16_neon, export=1
273        decode_update   .8h, .16b, 16
274        b               L(renorm)
275endfunc
276
// Repeated 3-symbol adaptive decode on one cdf.
// In:  x0 = context, x1 = cdf (4 halfword lanes are loaded; the adaptation
//      counter is the halfword at byte offset 6).
// Out: w0 = (w13 + 30) >> 1, where w13 starts at -24 and gains 2 * ret per
//      iteration, i.e. 3 + the sum of the decoded ret values.
//
// Each iteration performs the same steps as decode_update with n_symbols
// fixed to 3, followed by an inlined copy of the normalisation / refill tail.
// State is carried in registers across iterations: v0 = cdf,
// v17 = cdf & 0xffc0, v3 = rng (broadcast), v1 = top 16 bits of dif
// (broadcast), w6 = cnt, x7 = dif, w9 = count, w10 = ALLOW_UPDATE_CDF.
// RNG is stored on every iteration; CNT and DIF are written back once at the
// end. The loop exits on the carry computed at label 9.
//
// Note that the numeric label 2 is defined twice: "cbz w10, 2f" binds to the
// first definition (start of the normalisation), "b.gt 2f" to the second
// (refill_eob).
277function msac_decode_hi_tok_neon, export=1
278        ld1             {v0.4h},  [x1]            // cdf
279        add             x16, x0,  #RNG
280        movi            v31.4h, #0x7f, lsl #8     // 0x7f00
281        movrel          x17, coeffs, 30-2*3
282        mvni            v30.4h, #0x3f             // 0xffc0
283        ldrh            w9,  [x1, #6]             // count = cdf[n_symbols]
284        ld1r            {v3.4h},  [x16]           // rng
285        movrel          x16, bits
286        ld1             {v29.4h}, [x17]           // EC_MIN_PROB * (n_symbols - ret)
287        add             x17, x0,  #DIF + 6
288        ld1             {v16.8h}, [x16]
289        mov             w13, #-24                 // result accumulator, see label 9
290        and             v17.8b,  v0.8b,   v30.8b  // cdf & 0xffc0
291        ldr             w10, [x0, #ALLOW_UPDATE_CDF]
292        ld1r            {v1.8h},  [x17]           // dif >> (EC_WIN_SIZE - 16)
293        sub             sp,  sp,  #48             // frame: u at sp+14, v[] from sp+16
294        ldr             w6,  [x0, #CNT]
295        ldr             x7,  [x0, #DIF]
2961:
297        and             v7.8b,   v3.8b,   v31.8b  // rng & 0x7f00
298        sqdmulh         v6.4h,   v17.4h,  v7.4h   // ((cdf >> EC_PROB_SHIFT) * (r - 128)) >> 1
299        add             v4.4h,   v17.4h,  v29.4h  // v = cdf + EC_MIN_PROB * (n_symbols - ret)
300        add             v4.4h,   v6.4h,   v4.4h   // v = ((cdf >> EC_PROB_SHIFT) * r) >> 1 + EC_MIN_PROB * (n_symbols - ret)
301        str             h3,  [sp, #14]            // store original u = s->rng
302        cmhs            v2.8h,   v1.8h,   v4.8h   // c >= v
303        str             q4,  [sp, #16]            // store v values to allow indexed access
304        and             v6.16b,  v2.16b,  v16.16b // One bit per halfword set in the mask
305        addv            h6,  v6.8h                // Aggregate mask bits
306        umov            w3,  v6.h[0]
307        add             w13, w13, #5              // bias, cancelled by the -5 folded into w15 at label 9
308        rbit            w3,  w3
309        add             x8,  sp,  #16             // x8 = &v[0]
310        clz             w15, w3                   // ret
311
312        cbz             w10, 2f                   // adaptation disabled: skip update_cdf
313        // update_cdf
314        movi            v5.8b, #0xff
315        mov             w4,  #-5
316        urhadd          v4.4h,   v5.4h,   v2.4h   // i >= val ? -1 : 32768
317        sub             w4,  w4,  w9, lsr #4      // -((count >> 4) + 5)
318        sub             v4.4h,   v4.4h,   v0.4h   // (32768 - cdf[i]) or (-1 - cdf[i])
319        dup             v6.4h,    w4              // -rate
320
321        sub             w9,  w9,  w9, lsr #5      // count - (count == 32)
322        sub             v0.4h,   v0.4h,   v2.4h   // cdf + (i >= val ? 1 : 0)
323        sshl            v4.4h,   v4.4h,   v6.4h   // ({32768,-1} - cdf[i]) >> rate
324        add             w9,  w9,  #1              // count + (count < 32)
325        add             v0.4h,   v0.4h,   v4.4h   // cdf + (32768 - cdf[i]) >> rate
326        st1             {v0.4h},  [x1]
327        and             v17.8b,  v0.8b,   v30.8b  // cdf & 0xffc0
328        strh            w9,  [x1, #6]
329
3302:
331        add             x8,  x8,  w15, uxtw #1 // x8 = &v[ret]
332        ldrh            w3,  [x8]              // v
333        ldurh           w4,  [x8, #-2]         // u
334        sub             w4,  w4,  w3           // rng = u - v
335        clz             w5,  w4                // clz(rng)
336        eor             w5,  w5,  #16          // d = clz(rng) ^ 16
337        mvn             x7,  x7                // ~dif
338        add             x7,  x7,  x3, lsl #48  // ~dif + (v << 48)
339        lsl             w4,  w4,  w5           // rng << d
340        subs            w6,  w6,  w5           // cnt -= d
341        lsl             x7,  x7,  x5           // (~dif + (v << 48)) << d
342        str             w4,  [x0, #RNG]
343        dup             v3.4h,   w4            // broadcast the new rng for the next iteration
344        mvn             x7,  x7                // ~dif
345        b.hs            9f                     // no borrow from cnt - d: skip the refill
346
347        // refill
348        ldp             x3,  x4,  [x0]         // BUF_POS, BUF_END
349        add             x5,  x3,  #8
350        cmp             x5,  x4
351        b.gt            2f                     // fewer than 8 bytes left: byte-wise refill
352
353        ldr             x3,  [x3]              // next_bits
354        add             w8,  w6,  #23          // shift_bits = cnt + 23
355        add             w6,  w6,  #16          // cnt += 16
356        rev             x3,  x3                // next_bits = bswap(next_bits)
357        sub             x5,  x5,  x8, lsr #3   // buf_pos -= shift_bits >> 3
358        and             w8,  w8,  #24          // shift_bits &= 24
359        lsr             x3,  x3,  x8           // next_bits >>= shift_bits
360        sub             w8,  w8,  w6           // shift_bits -= 16 + cnt
361        str             x5,  [x0, #BUF_POS]
362        lsl             x3,  x3,  x8           // next_bits <<= shift_bits
363        mov             w4,  #48
364        sub             w6,  w4,  w8           // cnt = cnt + 64 - shift_bits
365        eor             x7,  x7,  x3           // dif ^= next_bits
366        b               9f
367
3682:      // refill_eob
369        mov             w14, #40
370        sub             w5,  w14, w6           // c = 40 - cnt
3713:
372        cmp             x3,  x4
373        b.ge            4f                     // buf_pos >= buf_end: no input left
374        ldrb            w8,  [x3], #1          // next byte, buf_pos++
375        lsl             x8,  x8,  x5           // byte << c
376        eor             x7,  x7,  x8           // dif ^= byte << c
377        subs            w5,  w5,  #8           // c -= 8
378        b.ge            3b                     // loop while c >= 0
379
3804:      // refill_eob_end
381        str             x3,  [x0, #BUF_POS]
382        sub             w6,  w14, w5           // cnt = 40 - c
383
3849:
385        lsl             w15, w15, #1
386        sub             w15, w15, #5           // w15 = 2 * ret - 5
387        lsr             x12, x7,  #48          // top 16 bits of the updated dif
388        adds            w13, w13, w15          // carry = tok_br < 3 || tok == 15
389        dup             v1.8h,   w12           // broadcast them for the next compare
390        b.cc            1b                     // loop if !carry
391        add             w13, w13, #30
392        str             w6,  [x0, #CNT]
393        add             sp,  sp,  #48
394        str             x7,  [x0, #DIF]
395        lsr             w0,  w13, #1           // return (w13 + 30) >> 1
396        ret
397endfunc
398
// Decodes one bit with the split point derived from the range alone:
//   v = ((rng & 0xff00) + 8) >> 1, compared against dif as vw = v << 48.
// In:  x0 = context.
// Out: w0 = (dif < vw), produced by the shared tail at L(renorm2), which
//      expects w4 = new rng, w5 = d, w6 = cnt, x7 = ~dif, w15 = result.
// Nothing is stored on the stack here; the 48 bytes are reserved only
// because the shared tail releases a frame of that size before returning.
399function msac_decode_bool_equi_neon, export=1
400        sub             sp,  sp,  #48          // balance the add sp in the shared tail
401        ldr             x7,  [x0, #DIF]        // dif
402        ldp             w5,  w6,  [x0, #RNG]   // w5 = rng, w6 = cnt
403        bic             w4,  w5,  #0xff        // rng with the low byte cleared
404        add             w4,  w4,  #8           // 2 * v
405        subs            x8,  x7,  x4, lsl #47  // x8 = dif - vw, flags = dif vs. vw
406        lsr             w4,  w4,  #1           // v
407        sub             w5,  w5,  w4           // rng - v
408        csel            x7,  x7,  x8,  lo      // dif < vw ? dif : dif - vw
409        csel            w4,  w4,  w5,  lo      // dif < vw ? v : rng - v
410        cset            w15, lo                // result = (dif < vw)
411
412        mvn             x7,  x7                // ~dif, as L(renorm2) expects
413        clz             w5,  w4                // clz(new rng)
414        eor             w5,  w5,  #16          // d = clz(new rng) ^ 16
415        b               L(renorm2)
416endfunc
417
// Decodes one bit with a caller-supplied split factor:
//   v = (((rng >> 8) * (f & ~63)) >> 7) + 4, compared against dif as
//   vw = v << 48.
// In:  x0 = context, w1 = f (its low 6 bits are ignored; w1 is modified).
// Out: w0 = (dif < vw), produced by the shared tail at L(renorm2), which
//      expects w4 = new rng, w5 = d, w6 = cnt, x7 = ~dif, w15 = result.
// Nothing is stored on the stack here; the 48 bytes are reserved only
// because the shared tail releases a frame of that size before returning.
418function msac_decode_bool_neon, export=1
419        sub             sp,  sp,  #48          // balance the add sp in the shared tail
420        ldr             x7,  [x0, #DIF]        // dif
421        ldp             w5,  w6,  [x0, #RNG]   // w5 = rng, w6 = cnt
422        bic             w1,  w1,  #0x3f        // f & ~63
423        lsr             w4,  w5,  #8           // rng >> 8
424        mul             w4,  w4,  w1           // (rng >> 8) * f
425        lsr             w4,  w4,  #7
426        add             w4,  w4,  #4           // v
427        subs            x8,  x7,  x4, lsl #48  // x8 = dif - vw, flags = dif vs. vw
428        sub             w5,  w5,  w4           // rng - v
429        csel            x7,  x7,  x8,  lo      // dif < vw ? dif : dif - vw
430        csel            w4,  w4,  w5,  lo      // dif < vw ? v : rng - v
431        cset            w15, lo                // result = (dif < vw)
432
433        mvn             x7,  x7                // ~dif, as L(renorm2) expects
434        clz             w5,  w4                // clz(new rng)
435        eor             w5,  w5,  #16          // d = clz(new rng) ^ 16
436        b               L(renorm2)
437endfunc
438
// Decodes one bit against a two-entry adaptive cdf and, when the
// ALLOW_UPDATE_CDF field of the context is non-zero, adapts that cdf.
//   v = (((rng >> 8) * (cdf[0] & 0xffc0)) >> 7) + 4, compared against dif as
//   vw = v << 48.
// In:  x0 = context, x1 = cdf (halfword 0 = probability, halfword 1 = count).
// Out: w0 = bit = (dif < vw), produced by the shared tail at L(renorm2),
//      which expects w4 = new rng, w5 = d, w6 = cnt, x7 = ~dif, w15 = result.
// Adaptation, with rate = (count >> 4) + 4:
//   bit == 0: cdf[0] -= cdf[0] >> rate
//   bit == 1: cdf[0] += (32768 - cdf[0]) >> rate
//   count is incremented by one unless it is 32.
// Nothing is stored on the stack here; the 48 bytes are reserved only
// because the shared tail releases a frame of that size before returning.
439function msac_decode_bool_adapt_neon, export=1
440        sub             sp,  sp,  #48          // balance the add sp in the shared tail
441        ldp             w5,  w6,  [x0, #RNG]   // w5 = rng, w6 = cnt
442        ldr             w9,  [x1]              // cdf[0] in bits 0-15, count in bits 16-31
443        ldr             x7,  [x0, #DIF]        // dif
444        ldr             w10, [x0, #ALLOW_UPDATE_CDF]
445
446        and             w2,  w9,  #0xffc0      // f = cdf[0] & ~63
447        lsr             w4,  w5,  #8           // rng >> 8
448        mul             w4,  w4,  w2           // (rng >> 8) * f
449        lsr             w4,  w4,  #7
450        add             w4,  w4,  #4           // v
451        subs            x8,  x7,  x4, lsl #48  // x8 = dif - vw, flags = dif vs. vw
452        sub             w5,  w5,  w4           // rng - v
453        csel            x7,  x7,  x8,  lo      // dif < vw ? dif : dif - vw
454        csel            w4,  w4,  w5,  lo      // dif < vw ? v : rng - v
455        cset            w15, lo                // bit = (dif < vw)
456
457        mvn             x7,  x7                // ~dif, as L(renorm2) expects
458        clz             w5,  w4                // clz(new rng)
459        eor             w5,  w5,  #16          // d = clz(new rng) ^ 16
460
461        cbz             w10, L(renorm2)        // adaptation disabled: normalise right away
462
463        lsr             w2,  w9,  #16          // count = cdf[1]
464        and             w9,  w9,  #0xffff      // cdf[0]
465        sub             w9,  w9,  w15          // cdf[0] - bit
466
467        sub             w3,  w2,  w2, lsr #5   // count - (count == 32)
468        lsr             w2,  w2,  #4           // count >> 4
469        add             w2,  w2,  #4           // rate = (count >> 4) + 4
470        add             w10, w3,  #1           // count + (count < 32)
471
472        sub             w11, w9,  w15, lsl #15 // bit ? cdf[0] - 32769 : cdf[0]
473        asr             w11, w11, w2           // arithmetic >> rate
474        sub             w9,  w9,  w11          // adapted cdf[0]
475
476        strh            w10, [x1, #2]          // cdf[1] = new count
477        strh            w9,  [x1]              // cdf[0] = adapted probability
478
479        b               L(renorm2)
480endfunc
481