xref: /qemu/target/arm/tcg/translate-sme.c (revision bb509d94)
1 /*
2  * AArch64 SME translation
3  *
4  * Copyright (c) 2022 Linaro, Ltd
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include "qemu/osdep.h"
21 #include "cpu.h"
22 #include "tcg/tcg-op.h"
23 #include "tcg/tcg-op-gvec.h"
24 #include "tcg/tcg-gvec-desc.h"
25 #include "translate.h"
26 #include "exec/helper-gen.h"
27 #include "translate-a64.h"
28 #include "fpu/softfloat.h"
29 
30 
31 /*
32  * Include the generated decoder.
33  */
34 
35 #include "decode-sme.c.inc"
36 
37 
38 /*
39  * Resolve tile.size[index] to a host pointer, where tile and index
40  * are always decoded together, dependent on the element size.
41  */
42 static TCGv_ptr get_tile_rowcol(DisasContext *s, int esz, int rs,
43                                 int tile_index, bool vertical)
44 {
45     int tile = tile_index >> (4 - esz);
46     int index = esz == MO_128 ? 0 : extract32(tile_index, 0, 4 - esz);
47     int pos, len, offset;
48     TCGv_i32 tmp;
49     TCGv_ptr addr;
50 
51     /* Compute the final index, which is Rs+imm. */
52     tmp = tcg_temp_new_i32();
53     tcg_gen_trunc_tl_i32(tmp, cpu_reg(s, rs));
54     tcg_gen_addi_i32(tmp, tmp, index);
55 
56     /* Prepare a power-of-two modulo via extraction of @len bits. */
57     len = ctz32(streaming_vec_reg_size(s)) - esz;
58 
59     if (vertical) {
60         /*
61          * Compute the byte offset of the index within the tile:
62          *     (index % (svl / size)) * size
63          *   = (index % (svl >> esz)) << esz
64          * Perform the power-of-two modulo via extraction of the low @len bits.
65          * Perform the multiply by shifting left by @pos bits.
66          * Perform these operations simultaneously via deposit into zero.
67          */
68         pos = esz;
69         tcg_gen_deposit_z_i32(tmp, tmp, pos, len);
70 
71         /*
72          * For big-endian, adjust the indexed column byte offset within
73          * the uint64_t host words that make up env->zarray[].
74          */
75         if (HOST_BIG_ENDIAN && esz < MO_64) {
76             tcg_gen_xori_i32(tmp, tmp, 8 - (1 << esz));
77         }
78     } else {
79         /*
80          * Compute the byte offset of the index within the tile:
81          *     (index % (svl / size)) * (size * sizeof(row))
82          *   = (index % (svl >> esz)) << (esz + log2(sizeof(row)))
83          */
84         pos = esz + ctz32(sizeof(ARMVectorReg));
85         tcg_gen_deposit_z_i32(tmp, tmp, pos, len);
86 
87         /* Row slices are always aligned and need no endian adjustment. */
88     }
89 
90     /* The tile byte offset within env->zarray is the row. */
91     offset = tile * sizeof(ARMVectorReg);
92 
93     /* Include the byte offset of zarray to make this relative to env. */
94     offset += offsetof(CPUARMState, zarray);
95     tcg_gen_addi_i32(tmp, tmp, offset);
96 
97     /* Add the byte offset to env to produce the final pointer. */
98     addr = tcg_temp_new_ptr();
99     tcg_gen_ext_i32_ptr(addr, tmp);
100     tcg_temp_free_i32(tmp);
101     tcg_gen_add_ptr(addr, addr, cpu_env);
102 
103     return addr;
104 }
105 
106 static bool trans_ZERO(DisasContext *s, arg_ZERO *a)
107 {
108     if (!dc_isar_feature(aa64_sme, s)) {
109         return false;
110     }
111     if (sme_za_enabled_check(s)) {
112         gen_helper_sme_zero(cpu_env, tcg_constant_i32(a->imm),
113                             tcg_constant_i32(streaming_vec_reg_size(s)));
114     }
115     return true;
116 }
117 
118 static bool trans_MOVA(DisasContext *s, arg_MOVA *a)
119 {
120     static gen_helper_gvec_4 * const h_fns[5] = {
121         gen_helper_sve_sel_zpzz_b, gen_helper_sve_sel_zpzz_h,
122         gen_helper_sve_sel_zpzz_s, gen_helper_sve_sel_zpzz_d,
123         gen_helper_sve_sel_zpzz_q
124     };
125     static gen_helper_gvec_3 * const cz_fns[5] = {
126         gen_helper_sme_mova_cz_b, gen_helper_sme_mova_cz_h,
127         gen_helper_sme_mova_cz_s, gen_helper_sme_mova_cz_d,
128         gen_helper_sme_mova_cz_q,
129     };
130     static gen_helper_gvec_3 * const zc_fns[5] = {
131         gen_helper_sme_mova_zc_b, gen_helper_sme_mova_zc_h,
132         gen_helper_sme_mova_zc_s, gen_helper_sme_mova_zc_d,
133         gen_helper_sme_mova_zc_q,
134     };
135 
136     TCGv_ptr t_za, t_zr, t_pg;
137     TCGv_i32 t_desc;
138     int svl;
139 
140     if (!dc_isar_feature(aa64_sme, s)) {
141         return false;
142     }
143     if (!sme_smza_enabled_check(s)) {
144         return true;
145     }
146 
147     t_za = get_tile_rowcol(s, a->esz, a->rs, a->za_imm, a->v);
148     t_zr = vec_full_reg_ptr(s, a->zr);
149     t_pg = pred_full_reg_ptr(s, a->pg);
150 
151     svl = streaming_vec_reg_size(s);
152     t_desc = tcg_constant_i32(simd_desc(svl, svl, 0));
153 
154     if (a->v) {
155         /* Vertical slice -- use sme mova helpers. */
156         if (a->to_vec) {
157             zc_fns[a->esz](t_zr, t_za, t_pg, t_desc);
158         } else {
159             cz_fns[a->esz](t_za, t_zr, t_pg, t_desc);
160         }
161     } else {
162         /* Horizontal slice -- reuse sve sel helpers. */
163         if (a->to_vec) {
164             h_fns[a->esz](t_zr, t_za, t_zr, t_pg, t_desc);
165         } else {
166             h_fns[a->esz](t_za, t_zr, t_za, t_pg, t_desc);
167         }
168     }
169 
170     tcg_temp_free_ptr(t_za);
171     tcg_temp_free_ptr(t_zr);
172     tcg_temp_free_ptr(t_pg);
173 
174     return true;
175 }
176 
177 static bool trans_LDST1(DisasContext *s, arg_LDST1 *a)
178 {
179     typedef void GenLdSt1(TCGv_env, TCGv_ptr, TCGv_ptr, TCGv, TCGv_i32);
180 
181     /*
182      * Indexed by [esz][be][v][mte][st], which is (except for load/store)
183      * also the order in which the elements appear in the function names,
184      * and so how we must concatenate the pieces.
185      */
186 
187 #define FN_LS(F)     { gen_helper_sme_ld1##F, gen_helper_sme_st1##F }
188 #define FN_MTE(F)    { FN_LS(F), FN_LS(F##_mte) }
189 #define FN_HV(F)     { FN_MTE(F##_h), FN_MTE(F##_v) }
190 #define FN_END(L, B) { FN_HV(L), FN_HV(B) }
191 
192     static GenLdSt1 * const fns[5][2][2][2][2] = {
193         FN_END(b, b),
194         FN_END(h_le, h_be),
195         FN_END(s_le, s_be),
196         FN_END(d_le, d_be),
197         FN_END(q_le, q_be),
198     };
199 
200 #undef FN_LS
201 #undef FN_MTE
202 #undef FN_HV
203 #undef FN_END
204 
205     TCGv_ptr t_za, t_pg;
206     TCGv_i64 addr;
207     int svl, desc = 0;
208     bool be = s->be_data == MO_BE;
209     bool mte = s->mte_active[0];
210 
211     if (!dc_isar_feature(aa64_sme, s)) {
212         return false;
213     }
214     if (!sme_smza_enabled_check(s)) {
215         return true;
216     }
217 
218     t_za = get_tile_rowcol(s, a->esz, a->rs, a->za_imm, a->v);
219     t_pg = pred_full_reg_ptr(s, a->pg);
220     addr = tcg_temp_new_i64();
221 
222     tcg_gen_shli_i64(addr, cpu_reg(s, a->rm), a->esz);
223     tcg_gen_add_i64(addr, addr, cpu_reg_sp(s, a->rn));
224 
225     if (mte) {
226         desc = FIELD_DP32(desc, MTEDESC, MIDX, get_mem_index(s));
227         desc = FIELD_DP32(desc, MTEDESC, TBI, s->tbid);
228         desc = FIELD_DP32(desc, MTEDESC, TCMA, s->tcma);
229         desc = FIELD_DP32(desc, MTEDESC, WRITE, a->st);
230         desc = FIELD_DP32(desc, MTEDESC, SIZEM1, (1 << a->esz) - 1);
231         desc <<= SVE_MTEDESC_SHIFT;
232     } else {
233         addr = clean_data_tbi(s, addr);
234     }
235     svl = streaming_vec_reg_size(s);
236     desc = simd_desc(svl, svl, desc);
237 
238     fns[a->esz][be][a->v][mte][a->st](cpu_env, t_za, t_pg, addr,
239                                       tcg_constant_i32(desc));
240 
241     tcg_temp_free_ptr(t_za);
242     tcg_temp_free_ptr(t_pg);
243     tcg_temp_free_i64(addr);
244     return true;
245 }
246 
247 typedef void GenLdStR(DisasContext *, TCGv_ptr, int, int, int, int);
248 
249 static bool do_ldst_r(DisasContext *s, arg_ldstr *a, GenLdStR *fn)
250 {
251     int svl = streaming_vec_reg_size(s);
252     int imm = a->imm;
253     TCGv_ptr base;
254 
255     if (!sme_za_enabled_check(s)) {
256         return true;
257     }
258 
259     /* ZA[n] equates to ZA0H.B[n]. */
260     base = get_tile_rowcol(s, MO_8, a->rv, imm, false);
261 
262     fn(s, base, 0, svl, a->rn, imm * svl);
263 
264     tcg_temp_free_ptr(base);
265     return true;
266 }
267 
268 TRANS_FEAT(LDR, aa64_sme, do_ldst_r, a, gen_sve_ldr)
269 TRANS_FEAT(STR, aa64_sme, do_ldst_r, a, gen_sve_str)
270 
271 static bool do_adda(DisasContext *s, arg_adda *a, MemOp esz,
272                     gen_helper_gvec_4 *fn)
273 {
274     int svl = streaming_vec_reg_size(s);
275     uint32_t desc = simd_desc(svl, svl, 0);
276     TCGv_ptr za, zn, pn, pm;
277 
278     if (!sme_smza_enabled_check(s)) {
279         return true;
280     }
281 
282     /* Sum XZR+zad to find ZAd. */
283     za = get_tile_rowcol(s, esz, 31, a->zad, false);
284     zn = vec_full_reg_ptr(s, a->zn);
285     pn = pred_full_reg_ptr(s, a->pn);
286     pm = pred_full_reg_ptr(s, a->pm);
287 
288     fn(za, zn, pn, pm, tcg_constant_i32(desc));
289 
290     tcg_temp_free_ptr(za);
291     tcg_temp_free_ptr(zn);
292     tcg_temp_free_ptr(pn);
293     tcg_temp_free_ptr(pm);
294     return true;
295 }
296 
297 TRANS_FEAT(ADDHA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addha_s)
298 TRANS_FEAT(ADDVA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addva_s)
299 TRANS_FEAT(ADDHA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addha_d)
300 TRANS_FEAT(ADDVA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addva_d)
301 
302 static bool do_outprod(DisasContext *s, arg_op *a, MemOp esz,
303                        gen_helper_gvec_5 *fn)
304 {
305     int svl = streaming_vec_reg_size(s);
306     uint32_t desc = simd_desc(svl, svl, a->sub);
307     TCGv_ptr za, zn, zm, pn, pm;
308 
309     if (!sme_smza_enabled_check(s)) {
310         return true;
311     }
312 
313     /* Sum XZR+zad to find ZAd. */
314     za = get_tile_rowcol(s, esz, 31, a->zad, false);
315     zn = vec_full_reg_ptr(s, a->zn);
316     zm = vec_full_reg_ptr(s, a->zm);
317     pn = pred_full_reg_ptr(s, a->pn);
318     pm = pred_full_reg_ptr(s, a->pm);
319 
320     fn(za, zn, zm, pn, pm, tcg_constant_i32(desc));
321 
322     tcg_temp_free_ptr(za);
323     tcg_temp_free_ptr(zn);
324     tcg_temp_free_ptr(pn);
325     tcg_temp_free_ptr(pm);
326     return true;
327 }
328 
329 static bool do_outprod_fpst(DisasContext *s, arg_op *a, MemOp esz,
330                             gen_helper_gvec_5_ptr *fn)
331 {
332     int svl = streaming_vec_reg_size(s);
333     uint32_t desc = simd_desc(svl, svl, a->sub);
334     TCGv_ptr za, zn, zm, pn, pm, fpst;
335 
336     if (!sme_smza_enabled_check(s)) {
337         return true;
338     }
339 
340     /* Sum XZR+zad to find ZAd. */
341     za = get_tile_rowcol(s, esz, 31, a->zad, false);
342     zn = vec_full_reg_ptr(s, a->zn);
343     zm = vec_full_reg_ptr(s, a->zm);
344     pn = pred_full_reg_ptr(s, a->pn);
345     pm = pred_full_reg_ptr(s, a->pm);
346     fpst = fpstatus_ptr(FPST_FPCR);
347 
348     fn(za, zn, zm, pn, pm, fpst, tcg_constant_i32(desc));
349 
350     tcg_temp_free_ptr(za);
351     tcg_temp_free_ptr(zn);
352     tcg_temp_free_ptr(pn);
353     tcg_temp_free_ptr(pm);
354     tcg_temp_free_ptr(fpst);
355     return true;
356 }
357 
358 TRANS_FEAT(FMOPA_h, aa64_sme, do_outprod_fpst, a, MO_32, gen_helper_sme_fmopa_h)
359 TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst, a, MO_32, gen_helper_sme_fmopa_s)
360 TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst, a, MO_64, gen_helper_sme_fmopa_d)
361 
362 /* TODO: FEAT_EBF16 */
363 TRANS_FEAT(BFMOPA, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_bfmopa)
364 
365 TRANS_FEAT(SMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_smopa_s)
366 TRANS_FEAT(UMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_umopa_s)
367 TRANS_FEAT(SUMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_sumopa_s)
368 TRANS_FEAT(USMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_usmopa_s)
369 
370 TRANS_FEAT(SMOPA_d, aa64_sme_i16i64, do_outprod, a, MO_64, gen_helper_sme_smopa_d)
371 TRANS_FEAT(UMOPA_d, aa64_sme_i16i64, do_outprod, a, MO_64, gen_helper_sme_umopa_d)
372 TRANS_FEAT(SUMOPA_d, aa64_sme_i16i64, do_outprod, a, MO_64, gen_helper_sme_sumopa_d)
373 TRANS_FEAT(USMOPA_d, aa64_sme_i16i64, do_outprod, a, MO_64, gen_helper_sme_usmopa_d)
374