xref: /qemu/tests/tcg/i386/test-avx.c (revision 45b5933f)
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 
6 typedef void (*testfn)(void);
7 
8 typedef struct {
9     uint64_t q0, q1, q2, q3;
10 } __attribute__((aligned(32))) v4di;
11 
12 typedef struct {
13     uint64_t mm[8];
14     v4di ymm[16];
15     uint64_t r[16];
16     uint64_t flags;
17     uint32_t ff;
18     uint64_t pad;
19     v4di mem[4];
20     v4di mem0[4];
21 } reg_state;
22 
23 typedef struct {
24     int n;
25     testfn fn;
26     const char *s;
27     reg_state *init;
28 } TestDef;
29 
30 reg_state initI;
31 reg_state initF16;
32 reg_state initF32;
33 reg_state initF64;
34 
35 static void dump_ymm(const char *name, int n, const v4di *r, int ff)
36 {
37     printf("%s%d = %016lx %016lx %016lx %016lx\n",
38            name, n, r->q3, r->q2, r->q1, r->q0);
39     if (ff == 64) {
40         double v[4];
41         memcpy(v, r, sizeof(v));
42         printf("        %16g %16g %16g %16g\n",
43                 v[3], v[2], v[1], v[0]);
44     } else if (ff == 32) {
45         float v[8];
46         memcpy(v, r, sizeof(v));
47         printf(" %8g %8g %8g %8g %8g %8g %8g %8g\n",
48                 v[7], v[6], v[5], v[4], v[3], v[2], v[1], v[0]);
49     }
50 }
51 
52 static void dump_regs(reg_state *s)
53 {
54     int i;
55 
56     for (i = 0; i < 16; i++) {
57         dump_ymm("ymm", i, &s->ymm[i], 0);
58     }
59     for (i = 0; i < 4; i++) {
60         dump_ymm("mem", i, &s->mem0[i], 0);
61     }
62 }
63 
64 static void compare_state(const reg_state *a, const reg_state *b)
65 {
66     int i;
67     for (i = 0; i < 8; i++) {
68         if (a->mm[i] != b->mm[i]) {
69             printf("MM%d = %016lx\n", i, b->mm[i]);
70         }
71     }
72     for (i = 0; i < 16; i++) {
73         if (a->r[i] != b->r[i]) {
74             printf("r%d = %016lx\n", i, b->r[i]);
75         }
76     }
77     for (i = 0; i < 16; i++) {
78         if (memcmp(&a->ymm[i], &b->ymm[i], 32)) {
79             dump_ymm("ymm", i, &b->ymm[i], a->ff);
80         }
81     }
82     for (i = 0; i < 4; i++) {
83         if (memcmp(&a->mem0[i], &a->mem[i], 32)) {
84             dump_ymm("mem", i, &a->mem[i], a->ff);
85         }
86     }
87     if (a->flags != b->flags) {
88         printf("FLAGS = %016lx\n", b->flags);
89     }
90 }
91 
92 #define LOADMM(r, o) "movq " #r ", " #o "[%0]\n\t"
93 #define LOADYMM(r, o) "vmovdqa " #r ", " #o "[%0]\n\t"
94 #define STOREMM(r, o) "movq " #o "[%1], " #r "\n\t"
95 #define STOREYMM(r, o) "vmovdqa " #o "[%1], " #r "\n\t"
96 #define MMREG(F) \
97     F(mm0, 0x00) \
98     F(mm1, 0x08) \
99     F(mm2, 0x10) \
100     F(mm3, 0x18) \
101     F(mm4, 0x20) \
102     F(mm5, 0x28) \
103     F(mm6, 0x30) \
104     F(mm7, 0x38)
105 #define YMMREG(F) \
106     F(ymm0, 0x040) \
107     F(ymm1, 0x060) \
108     F(ymm2, 0x080) \
109     F(ymm3, 0x0a0) \
110     F(ymm4, 0x0c0) \
111     F(ymm5, 0x0e0) \
112     F(ymm6, 0x100) \
113     F(ymm7, 0x120) \
114     F(ymm8, 0x140) \
115     F(ymm9, 0x160) \
116     F(ymm10, 0x180) \
117     F(ymm11, 0x1a0) \
118     F(ymm12, 0x1c0) \
119     F(ymm13, 0x1e0) \
120     F(ymm14, 0x200) \
121     F(ymm15, 0x220)
122 #define LOADREG(r, o) "mov " #r ", " #o "[rax]\n\t"
123 #define STOREREG(r, o) "mov " #o "[rax], " #r "\n\t"
124 #define REG(F) \
125     F(rbx, 0x248) \
126     F(rcx, 0x250) \
127     F(rdx, 0x258) \
128     F(rsi, 0x260) \
129     F(rdi, 0x268) \
130     F(r8, 0x280) \
131     F(r9, 0x288) \
132     F(r10, 0x290) \
133     F(r11, 0x298) \
134     F(r12, 0x2a0) \
135     F(r13, 0x2a8) \
136     F(r14, 0x2b0) \
137     F(r15, 0x2b8) \
138 
139 static void run_test(const TestDef *t)
140 {
141     reg_state result;
142     reg_state *init = t->init;
143     memcpy(init->mem, init->mem0, sizeof(init->mem));
144     printf("%5d %s\n", t->n, t->s);
145     asm volatile(
146             MMREG(LOADMM)
147             YMMREG(LOADYMM)
148             "sub rsp, 128\n\t"
149             "push rax\n\t"
150             "push rbx\n\t"
151             "push rcx\n\t"
152             "push rdx\n\t"
153             "push %1\n\t"
154             "push %2\n\t"
155             "mov rax, %0\n\t"
156             "pushf\n\t"
157             "pop rbx\n\t"
158             "shr rbx, 8\n\t"
159             "shl rbx, 8\n\t"
160             "mov rcx, 0x2c0[rax]\n\t"
161             "and rcx, 0xff\n\t"
162             "or rbx, rcx\n\t"
163             "push rbx\n\t"
164             "popf\n\t"
165             REG(LOADREG)
166             "mov rax, 0x240[rax]\n\t"
167             "call [rsp]\n\t"
168             "mov [rsp], rax\n\t"
169             "mov rax, 8[rsp]\n\t"
170             REG(STOREREG)
171             "mov rbx, [rsp]\n\t"
172             "mov 0x240[rax], rbx\n\t"
173             "mov rbx, 0\n\t"
174             "mov 0x270[rax], rbx\n\t"
175             "mov 0x278[rax], rbx\n\t"
176             "pushf\n\t"
177             "pop rbx\n\t"
178             "and rbx, 0xff\n\t"
179             "mov 0x2c0[rax], rbx\n\t"
180             "add rsp, 16\n\t"
181             "pop rdx\n\t"
182             "pop rcx\n\t"
183             "pop rbx\n\t"
184             "pop rax\n\t"
185             "add rsp, 128\n\t"
186             MMREG(STOREMM)
187             YMMREG(STOREYMM)
188             : : "r"(init), "r"(&result), "r"(t->fn)
189             : "memory", "cc",
190             "rsi", "rdi",
191             "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
192             "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7",
193             "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5",
194             "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11",
195             "ymm12", "ymm13", "ymm14", "ymm15"
196             );
197     compare_state(init, &result);
198 }
199 
200 #define TEST(n, cmd, type) \
201 static void __attribute__((naked)) test_##n(void) \
202 { \
203     asm volatile(cmd); \
204     asm volatile("ret"); \
205 }
206 #include "test-avx.h"
207 
208 
209 static const TestDef test_table[] = {
210 #define TEST(n, cmd, type) {n, test_##n, cmd, &init##type},
211 #include "test-avx.h"
212     {-1, NULL, "", NULL}
213 };
214 
215 static void run_all(void)
216 {
217     const TestDef *t;
218     for (t = test_table; t->fn; t++) {
219         run_test(t);
220     }
221 }
222 
223 #define ARRAY_LEN(x) (sizeof(x) / sizeof(x[0]))
224 
225 uint16_t val_f16[] = { 0x4000, 0xbc00, 0x44cd, 0x3a66, 0x4200, 0x7a1a, 0x4780, 0x4826 };
226 float val_f32[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5, 8.3};
227 double val_f64[] = {2.0, -1.0, 4.8, 0.8, 3, -42.0, 5e6, 7.5};
228 v4di val_i64[] = {
229     {0x3d6b3b6a9e4118f2lu, 0x355ae76d2774d78clu,
230      0xac3ff76c4daa4b28lu, 0xe7fabd204cb54083lu},
231     {0xd851c54a56bf1f29lu, 0x4a84d1d50bf4c4fflu,
232      0x56621e553d52b56clu, 0xd0069553da8f584alu},
233     {0x5826475e2c5fd799lu, 0xfd32edc01243f5e9lu,
234      0x738ba2c66d3fe126lu, 0x5707219c6e6c26b4lu},
235 };
236 
237 v4di deadbeef = {0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull,
238                  0xa5a5a5a5deadbeefull, 0xa5a5a5a5deadbeefull};
239 /* &gather_mem[0x10] is 512 bytes from the base; indices must be >=-64, <64
240  * to account for scaling by 8 */
241 v4di indexq = {0x000000000000001full, 0x000000000000003dull,
242                0xffffffffffffffffull, 0xffffffffffffffdfull};
243 v4di indexd = {0x00000002ffffffcdull, 0xfffffff500000010ull,
244                0x0000003afffffff0ull, 0x000000000000000eull};
245 
246 v4di gather_mem[0x20];
247 _Static_assert(sizeof(gather_mem) == 1024);
248 
249 void init_f16reg(v4di *r)
250 {
251     memset(r, 0, sizeof(*r));
252     memcpy(r, val_f16, sizeof(val_f16));
253 }
254 
255 void init_f32reg(v4di *r)
256 {
257     static int n;
258     float v[8];
259     int i;
260     for (i = 0; i < 8; i++) {
261         v[i] = val_f32[n++];
262         if (n == ARRAY_LEN(val_f32)) {
263             n = 0;
264         }
265     }
266     memcpy(r, v, sizeof(*r));
267 }
268 
269 void init_f64reg(v4di *r)
270 {
271     static int n;
272     double v[4];
273     int i;
274     for (i = 0; i < 4; i++) {
275         v[i] = val_f64[n++];
276         if (n == ARRAY_LEN(val_f64)) {
277             n = 0;
278         }
279     }
280     memcpy(r, v, sizeof(*r));
281 }
282 
283 void init_intreg(v4di *r)
284 {
285     static uint64_t mask;
286     static int n;
287 
288     r->q0 = val_i64[n].q0 ^ mask;
289     r->q1 = val_i64[n].q1 ^ mask;
290     r->q2 = val_i64[n].q2 ^ mask;
291     r->q3 = val_i64[n].q3 ^ mask;
292     n++;
293     if (n == ARRAY_LEN(val_i64)) {
294         n = 0;
295         mask *= 0x104C11DB7;
296     }
297 }
298 
299 static void init_all(reg_state *s)
300 {
301     int i;
302 
303     s->r[3] = (uint64_t)&s->mem[0]; /* rdx */
304     s->r[4] = (uint64_t)&gather_mem[ARRAY_LEN(gather_mem) / 2]; /* rsi */
305     s->r[5] = (uint64_t)&s->mem[2]; /* rdi */
306     s->flags = 2;
307     for (i = 0; i < 16; i++) {
308         s->ymm[i] = deadbeef;
309     }
310     s->ymm[13] = indexd;
311     s->ymm[14] = indexq;
312     for (i = 0; i < 4; i++) {
313         s->mem0[i] = deadbeef;
314     }
315 }
316 
317 int main(int argc, char *argv[])
318 {
319     int i;
320 
321     init_all(&initI);
322     init_intreg(&initI.ymm[10]);
323     init_intreg(&initI.ymm[11]);
324     init_intreg(&initI.ymm[12]);
325     init_intreg(&initI.mem0[1]);
326     printf("Int:\n");
327     dump_regs(&initI);
328 
329     init_all(&initF16);
330     init_f16reg(&initF16.ymm[10]);
331     init_f16reg(&initF16.ymm[11]);
332     init_f16reg(&initF16.ymm[12]);
333     init_f16reg(&initF16.mem0[1]);
334     initF16.ff = 16;
335     printf("F16:\n");
336     dump_regs(&initF16);
337 
338     init_all(&initF32);
339     init_f32reg(&initF32.ymm[10]);
340     init_f32reg(&initF32.ymm[11]);
341     init_f32reg(&initF32.ymm[12]);
342     init_f32reg(&initF32.mem0[1]);
343     initF32.ff = 32;
344     printf("F32:\n");
345     dump_regs(&initF32);
346 
347     init_all(&initF64);
348     init_f64reg(&initF64.ymm[10]);
349     init_f64reg(&initF64.ymm[11]);
350     init_f64reg(&initF64.ymm[12]);
351     init_f64reg(&initF64.mem0[1]);
352     initF64.ff = 64;
353     printf("F64:\n");
354     dump_regs(&initF64);
355 
356     for (i = 0; i < ARRAY_LEN(gather_mem); i++) {
357         init_intreg(&gather_mem[i]);
358     }
359 
360     if (argc > 1) {
361         int n = atoi(argv[1]);
362         run_test(&test_table[n]);
363     } else {
364         run_all();
365     }
366     return 0;
367 }
368