1 /*
2 * Copyright © 2019, VideoLAN and dav1d authors
3 * Copyright © 2019, Two Orioles, LLC
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are met:
8 *
9 * 1. Redistributions of source code must retain the above copyright notice, this
10 * list of conditions and the following disclaimer.
11 *
12 * 2. Redistributions in binary form must reproduce the above copyright notice,
13 * this list of conditions and the following disclaimer in the documentation
14 * and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #include "tests/checkasm/checkasm.h"
29
30 #include "src/cpu.h"
31 #include "src/msac.h"
32
33 #include <stdio.h>
34 #include <string.h>
35
36 #define BUF_SIZE 8192
37
38 /* The normal code doesn't use function pointers */
39 typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
40 size_t n_symbols);
41 typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
42 typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
43 typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
44
45 typedef struct {
46 decode_symbol_adapt_fn decode_symbol_adapt4;
47 decode_symbol_adapt_fn decode_symbol_adapt8;
48 decode_symbol_adapt_fn decode_symbol_adapt16;
49 decode_adapt_fn decode_bool_adapt;
50 decode_bool_equi_fn decode_bool_equi;
51 decode_bool_fn decode_bool;
52 decode_adapt_fn decode_hi_tok;
53 } MsacDSPContext;
54
/*
 * Fill a 16-entry CDF array with random data for a decoder using n entries.
 *
 * Resulting layout:
 *   cdf[0] > cdf[1] > ... > cdf[n - 1] > cdf[n], with cdf[0] <= 32767
 *   cdf[n]              counter, reset to 0
 *   cdf[n + 1 .. 15]    random padding (truncated to 16 bits)
 *
 * Each step adds at least 1 to the following entry, and the modulus shrinks
 * with the index, so cdf[0] never exceeds 32767 and the modulus is never 0.
 * NOTE(review): callers in this file pass 1 <= n <= 15; n <= 0 would write
 * below cdf[0].
 */
static void randomize_cdf(uint16_t *const cdf, const int n) {
    int pos = 15;

    while (pos > n)
        cdf[pos--] = rnd();
    cdf[pos] = 0;

    do {
        const int room = 32768 - cdf[pos] - pos;
        cdf[pos - 1] = cdf[pos] + rnd() % room + 1;
    } while (--pos > 0);
}
64
65 /* memcmp() on structs can have weird behavior due to padding etc. */
msac_cmp(const MsacContext * const a,const MsacContext * const b)66 static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
67 return a->buf_pos != b->buf_pos || a->buf_end != b->buf_end ||
68 a->dif != b->dif || a->rng != b->rng || a->cnt != b->cnt ||
69 a->allow_update_cdf != b->allow_update_cdf;
70 }
71
msac_dump(unsigned c_res,unsigned a_res,const MsacContext * const a,const MsacContext * const b,const uint16_t * const cdf_a,const uint16_t * const cdf_b,const int num_cdf)72 static void msac_dump(unsigned c_res, unsigned a_res,
73 const MsacContext *const a, const MsacContext *const b,
74 const uint16_t *const cdf_a, const uint16_t *const cdf_b,
75 const int num_cdf)
76 {
77 if (c_res != a_res)
78 fprintf(stderr, "c_res %u a_res %u\n", c_res, a_res);
79 if (a->buf_pos != b->buf_pos)
80 fprintf(stderr, "buf_pos %p vs %p\n", a->buf_pos, b->buf_pos);
81 if (a->buf_end != b->buf_end)
82 fprintf(stderr, "buf_end %p vs %p\n", a->buf_end, b->buf_end);
83 if (a->dif != b->dif)
84 fprintf(stderr, "dif %zx vs %zx\n", a->dif, b->dif);
85 if (a->rng != b->rng)
86 fprintf(stderr, "rng %u vs %u\n", a->rng, b->rng);
87 if (a->cnt != b->cnt)
88 fprintf(stderr, "cnt %d vs %d\n", a->cnt, b->cnt);
89 if (a->allow_update_cdf)
90 fprintf(stderr, "allow_update_cdf %d vs %d\n",
91 a->allow_update_cdf, b->allow_update_cdf);
92 if (num_cdf && memcmp(cdf_a, cdf_b, sizeof(*cdf_a) * (num_cdf + 1))) {
93 fprintf(stderr, "cdf:\n");
94 for (int i = 0; i <= num_cdf; i++)
95 fprintf(stderr, " %5u", cdf_a[i]);
96 fprintf(stderr, "\n");
97 for (int i = 0; i <= num_cdf; i++)
98 fprintf(stderr, " %5u", cdf_b[i]);
99 fprintf(stderr, "\n");
100 for (int i = 0; i <= num_cdf; i++)
101 fprintf(stderr, " %c", cdf_a[i] != cdf_b[i] ? 'x' : '.');
102 fprintf(stderr, "\n");
103 }
104 }
105
/*
 * CHECK_SYMBOL_ADAPT(n, n_min, n_max): test c->decode_symbol_adapt<n>
 * against the C reference for every symbol count ns in [n_min, n_max].
 * This runs in two passes, with cdf_update = 0 and cdf_update = 1.
 * dav1d_msac_init() receives !cdf_update, which is presumably a "disable
 * CDF adaptation" flag (TODO confirm against src/msac.h).
 *
 * For each (cdf_update, ns) pair, both implementations start from the same
 * decoder state and the same randomized CDF, then decode 64 symbols. After
 * every call the macro compares the return value, the decoder state and the
 * first ns + 1 CDF entries. A mismatch is reported through fail(); when
 * fail() returns nonzero, the difference is printed with msac_dump().
 * Only the cdf_update == 1, ns == n - 1 case is benchmarked.
 *
 * The macro relies on names from the expanding scope: c, buf, s_c, s_a,
 * cdf (two 16-entry uint16_t rows) and a preceding declare_func().
 */
#define CHECK_SYMBOL_ADAPT(n, n_min, n_max) do { \
    if (check_func(c->decode_symbol_adapt##n, \
                   "msac_decode_symbol_adapt%d", n)) \
    { \
        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { \
            for (int ns = n_min; ns <= n_max; ns++) { \
                /* same starting state and CDF for ref and new */ \
                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); \
                s_a = s_c; \
                randomize_cdf(cdf[0], ns); \
                memcpy(cdf[1], cdf[0], sizeof(*cdf)); \
                for (int i = 0; i < 64; i++) { \
                    unsigned c_res = call_ref(&s_c, cdf[0], ns); \
                    unsigned a_res = call_new(&s_a, cdf[1], ns); \
                    /* ns + 1 entries: also covers the counter at cdf[ns] */ \
                    if (c_res != a_res || msac_cmp(&s_c, &s_a) || \
                        memcmp(cdf[0], cdf[1], sizeof(**cdf) * (ns + 1))) \
                    { \
                        if (fail()) \
                            msac_dump(c_res, a_res, &s_c, &s_a, \
                                      cdf[0], cdf[1], ns); \
                    } \
                } \
                if (cdf_update && ns == n - 1) \
                    bench_new(&s_a, cdf[1], ns); \
            } \
        } \
    } \
} while (0)
133
/*
 * Check the three decode_symbol_adapt widths against the C reference.
 * Symbol counts tested: 1..4 for adapt4, 1..7 for adapt8, 3..15 for adapt16.
 * CHECK_SYMBOL_ADAPT expands against the locals c, buf, s_c, s_a and cdf.
 */
static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c;
    MsacContext s_a;
    /* two CDF rows (reference / tested), declared through ALIGN_STK_32 */
    ALIGN_STK_32(uint16_t, cdf, 2, [16]);

    declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);

    CHECK_SYMBOL_ADAPT(4, 1, 4);
    CHECK_SYMBOL_ADAPT(8, 1, 7);
    CHECK_SYMBOL_ADAPT(16, 3, 15);

    report("decode_symbol");
}
144
/*
 * Check decode_bool_adapt against the C reference in two passes
 * (dav1d_msac_init() receives !update). Each pass starts from one random
 * probability in [1, 32767] with the second CDF entry zeroed, decodes 64
 * bools, and compares the result, the decoder state and both CDF entries
 * after every call. Only the update == 1 pass is benchmarked.
 */
static void check_decode_bool_adapt(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (!check_func(c->decode_bool_adapt, "msac_decode_bool_adapt"))
        return;

    uint16_t cdf[2][2];
    for (int update = 0; update < 2; update++) {
        dav1d_msac_init(&s_c, buf, BUF_SIZE, !update);
        s_a = s_c;

        const uint16_t prob = rnd() % 32767 + 1;
        cdf[0][0] = prob;
        cdf[1][0] = prob;
        cdf[0][1] = 0;
        cdf[1][1] = 0;

        for (int iter = 0; iter < 64; iter++) {
            const unsigned c_res = call_ref(&s_c, cdf[0]);
            const unsigned a_res = call_new(&s_a, cdf[1]);
            const int mismatch = c_res != a_res || msac_cmp(&s_c, &s_a) ||
                                 memcmp(cdf[0], cdf[1], sizeof(cdf[0]));
            if (mismatch && fail())
                msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 1);
        }

        if (update)
            bench_new(&s_a, cdf[1]);
    }
}
171
/*
 * Check decode_bool_equi against the C reference: decode 64 bools from the
 * same initial state (last dav1d_msac_init() argument = 1) and compare the
 * result and the decoder state after every call, then benchmark.
 */
static void check_decode_bool_equi(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s);
    if (!check_func(c->decode_bool_equi, "msac_decode_bool_equi"))
        return;

    dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
    s_a = s_c;
    for (int iter = 0; iter < 64; iter++) {
        const unsigned c_res = call_ref(&s_c);
        const unsigned a_res = call_new(&s_a);
        if ((c_res != a_res || msac_cmp(&s_c, &s_a)) && fail())
            msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
    }
    bench_new(&s_a);
}
190
/*
 * Check decode_bool against the C reference: decode 64 bools, each with a
 * fresh random 15-bit probability argument, comparing the result and the
 * decoder state after every call. The benchmark uses a fixed 16384.
 */
static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, unsigned f);
    if (!check_func(c->decode_bool, "msac_decode_bool"))
        return;

    dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
    s_a = s_c;
    for (int iter = 0; iter < 64; iter++) {
        const unsigned prob = rnd() & 0x7fff;
        const unsigned c_res = call_ref(&s_c, prob);
        const unsigned a_res = call_new(&s_a, prob);
        if ((c_res != a_res || msac_cmp(&s_c, &s_a)) && fail())
            msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
    }
    bench_new(&s_a, 16384);
}
211
/*
 * Run the three bool decoder checks in a fixed order (adapt, equi, plain)
 * and report them together under "decode_bool".
 */
static void check_decode_bool_funcs(MsacDSPContext *const c, uint8_t *const buf) {
    static void (*const checks[])(MsacDSPContext *, uint8_t *) = {
        check_decode_bool_adapt,
        check_decode_bool_equi,
        check_decode_bool,
    };

    for (size_t k = 0; k < sizeof(checks) / sizeof(checks[0]); k++)
        checks[k](c, buf);
    report("decode_bool");
}
218
/*
 * Check decode_hi_tok against the C reference in two passes
 * (dav1d_msac_init() receives !update), using a randomized CDF with 3 used
 * entries. Up to 64 tokens are decoded per pass. The result, the decoder
 * state and the full 16-entry CDF rows are compared after every call, and
 * the pass stops at the first mismatch. The report is emitted even when the
 * function is skipped.
 */
static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;
    ALIGN_STK_16(uint16_t, cdf, 2, [16]);

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (check_func(c->decode_hi_tok, "msac_decode_hi_tok")) {
        for (int update = 0; update < 2; update++) {
            dav1d_msac_init(&s_c, buf, BUF_SIZE, !update);
            s_a = s_c;
            randomize_cdf(cdf[0], 3);
            memcpy(cdf[1], cdf[0], sizeof(cdf[0]));

            for (int iter = 0; iter < 64; iter++) {
                const unsigned c_res = call_ref(&s_c, cdf[0]);
                const unsigned a_res = call_new(&s_a, cdf[1]);
                const int mismatch = c_res != a_res || msac_cmp(&s_c, &s_a) ||
                                     memcmp(cdf[0], cdf[1], sizeof(cdf[0]));
                if (mismatch) {
                    if (fail())
                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 3);
                    break;
                }
            }

            if (update)
                bench_new(&s_a, cdf[1]);
        }
    }
    report("decode_hi_tok");
}
247
checkasm_check_msac(void)248 void checkasm_check_msac(void) {
249 MsacDSPContext c;
250 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
251 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
252 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
253 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_c;
254 c.decode_bool_equi = dav1d_msac_decode_bool_equi_c;
255 c.decode_bool = dav1d_msac_decode_bool_c;
256 c.decode_hi_tok = dav1d_msac_decode_hi_tok_c;
257
258 #if (ARCH_AARCH64 || ARCH_ARM) && HAVE_ASM
259 if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
260 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_neon;
261 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_neon;
262 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_neon;
263 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_neon;
264 c.decode_bool_equi = dav1d_msac_decode_bool_equi_neon;
265 c.decode_bool = dav1d_msac_decode_bool_neon;
266 c.decode_hi_tok = dav1d_msac_decode_hi_tok_neon;
267 }
268 #elif ARCH_X86 && HAVE_ASM
269 if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
270 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
271 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
272 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
273 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
274 c.decode_bool_equi = dav1d_msac_decode_bool_equi_sse2;
275 c.decode_bool = dav1d_msac_decode_bool_sse2;
276 c.decode_hi_tok = dav1d_msac_decode_hi_tok_sse2;
277 }
278
279 #if ARCH_X86_64
280 if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_AVX2) {
281 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
282 }
283 #endif
284 #endif
285
286 uint8_t buf[BUF_SIZE];
287 for (int i = 0; i < BUF_SIZE; i++)
288 buf[i] = rnd();
289
290 check_decode_symbol(&c, buf);
291 check_decode_bool_funcs(&c, buf);
292 check_decode_hi_tok(&c, buf);
293 }
294