/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
* Copyright 2020-2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <numeric>

#include "dnnl_debug.h"

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/nstl.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/jit_uni_reorder.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/reorder/cpu_reorder_pd.hpp"

#include "cpu/aarch64/jit_generator.hpp"

// #define TR_DEBUG
#if defined(TR_DEBUG)
#define DEBUg(...) \
    do { \
        __VA_ARGS__ \
    } while (0)
#else
#define DEBUg(...)
#endif
#define DEBUG(...) DEBUg(__VA_ARGS__)

using namespace Xbyak_aarch64;
using namespace dnnl::impl::types;

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

namespace tr {

/** Minimal reasonable/desirable kernel size.
 * The constant might be used to determine how a problem should be split
 * between kernel and threading driver. */
const size_t ker_prb_size_min = 64;

/* kernel */
struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)

    void operator()(const call_param_t *c) const override {
        jit_generator::operator()(c);
    }

    status_t create_kernel() override { return jit_generator::create_kernel(); }

    enum {
        len_unroll_max = 256,
        ndims_jit_loop_max = 3,
    };

    struct simple_impl_desc_t {
        int ndims_full_unroll;
        int len_last_dim_unroll;
        int len_unroll;
    };

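    /* Decide how much of the problem fits into a fully unrolled kernel body.
     * Innermost nodes are unrolled completely while the running product of
     * their sizes stays within len_unroll_max; the first node that does not
     * fit is unrolled partially, by the largest divisor of its size that
     * still fits. E.g. (hypothetical shape) nodes of sizes {8, 8, 6, ...}
     * give ndims_full_unroll = 2, len_last_dim_unroll = 3 (256 / 64 = 4,
     * reduced to 3, the largest divisor of 6 that fits) and
     * len_unroll = 192. */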
    static bool simple_impl_desc_init(
            const prb_t &prb, simple_impl_desc_t *desc) {
        const int ndims = prb.ndims;

        int ndims_full_unroll = 0;
        int len_last_dim_unroll = 1;
        int len_unroll = 1;

        for (int d = 0; d < ndims; ++d) {
            auto &node = prb.nodes[d];
            if (len_unroll * node.n <= len_unroll_max) {
                ndims_full_unroll++;
                len_unroll *= node.n;
            } else {
                len_last_dim_unroll = len_unroll_max / len_unroll;
                while (node.n % len_last_dim_unroll)
                    --len_last_dim_unroll;
                len_unroll *= len_last_dim_unroll;
                break;
            }
        }

        if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false;

        if (desc) {
            desc->ndims_full_unroll = ndims_full_unroll;
            desc->len_last_dim_unroll = len_last_dim_unroll;
            desc->len_unroll = len_unroll;
        }

        return true;
    }

    static bool applicable(const prb_t &p) {
        using namespace data_type;

        bool ok = true && p.ndims > 0
                && utils::one_of(p.itype, f32, s32, data_type::s8, u8)
                && utils::one_of(p.otype, f32, s32, data_type::s8, u8)
                && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
                && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
                && simple_impl_desc_init(p, nullptr);
        if (!ok) return false;

        const ptrdiff_t max_stride = (1LL << 31) - 1;
        for (int d = 0; d < p.ndims; ++d) {
            const ptrdiff_t cms = max_stride / p.nodes[d].n;
            bool strides_ok = true
                    && p.nodes[d].is < cms / (int)data_type_size(p.itype)
                    && p.nodes[d].os < cms / (int)data_type_size(p.otype);
            if (!strides_ok) return false;
        }

        return true;
    }

    int n(int d) {
        assert(d < prb_.ndims);
        return (int)prb_.nodes[d].n;
    }
    int is(int d) {
        assert(d < prb_.ndims);
        return (int)prb_.nodes[d].is;
    }
    int os(int d) {
        assert(d < prb_.ndims);
        return (int)prb_.nodes[d].os;
    }
    int ss(int d) {
        assert(d < prb_.ndims);
        return (int)prb_.nodes[d].ss;
    }

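    /* Incrementally advance the input/output/scale offsets when the linear
     * iteration index grows by step_size, treating the node sizes as a
     * mixed-radix counter: add the stride of the carrying dimension, and on
     * wrap-around subtract n(d) * stride(d) and carry into the next
     * dimension. E.g. for a hypothetical problem with nodes n = {4, 3},
     * is = {1, 4}, moving from off = 3 to off = 4 wraps dimension 0
     * (+1 - 4*1) and carries into dimension 1 (+4), so i_off: 3 -> 4. */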
    void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
            int &i_off, int &o_off, int &s_off, int step_size = 1) {
        i_off = prev_i_off;
        o_off = prev_o_off;
        s_off = prev_s_off;

        if (off == 0) return;

        int start_dim = 0, dims_prod = 1;
        for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
            dims_prod *= n(start_dim);
        assert(start_dim < prb_.ndims);
        off /= step_size;

        for (int d = start_dim; d < prb_.ndims; ++d) {
            i_off += is(d);
            o_off += os(d);
            s_off += ss(d);

            if (off % n(d)) break;

            i_off += -n(d) * is(d);
            o_off += -n(d) * os(d);
            s_off += -n(d) * ss(d);
            off /= n(d);

            if (off == 0) break; /* FIXME: is it really required? */
        }
    }

    void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
            int step_size = 1) {
        int dummy = 0;
        step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
                step_size);
    }

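    /* Transpose an 8x8 f32 tile held in z0-z7 (256-bit SVE vectors) in five
     * stages: trn1/trn2 on 32-bit elements, trn1/trn2 on 64-bit pairs,
     * register moves, ext to swap the 128-bit halves, and sel to merge the
     * halves back, optionally converting from itype on load and to otype
     * before the store. */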
    void tr8x8_sve256(int i_off, int o_off) {
        using namespace data_type;

        const auto cvt2ps
                = [=](const int startIdx, const int regNum, data_type_t idt) {
                      switch (idt) {
                          case f32:
                              /* do nothing */
                              break;
                          case s32: cvt_z_s32_f32(startIdx, regNum); break;
                          case data_type::s8:
                              cvt_z_s8_s32(startIdx, regNum);
                              cvt_z_s32_f32(startIdx, regNum);
                              break;
                          case u8:
                              cvt_z_u8_s32(startIdx, regNum);
                              cvt_z_s32_f32(startIdx, regNum);
                              break;
                          default: assert(!"unreachable");
                      }
                  };

        const auto cvt2odt = [=](const int startIdx, const int regNum,
                                     data_type_t odt, data_type_t idt) {
            switch (odt) {
                case s32:
                    if (idt == f32)
                        cvt_z_f32_s32(startIdx, regNum);
                    else if (idt == data_type::s8)
                        cvt_z_s8_s32(startIdx, regNum);
                    else if (idt == u8)
                        cvt_z_u8_s32(startIdx, regNum);
                    break;
                case data_type::s8:
                    if (idt == f32) cvt_z_f32_s32(startIdx, regNum);
                    if (utils::one_of(idt, f32, s32))
                        cvt_z_s32_s8(startIdx, regNum);
                    if (idt == u8) cvt_z_u8_s8(startIdx, regNum);
                    break;
                case u8:
                    if (idt == f32) cvt_z_f32_s32(startIdx, regNum);
                    if (utils::one_of(idt, f32, s32))
                        cvt_z_s32_u8(startIdx, regNum);
                    if (idt == data_type::s8) cvt_z_s8_u8(startIdx, regNum);
                    break;
                default: assert(!"unreachable");
            }
        };

        const int unroll = 8;

        const bool interim_f32 = (prb_.itype != f32)
                || utils::one_of(f32, prb_.itype, prb_.otype);

        const bool need_saturation
                = (utils::one_of(prb_.otype, u8, data_type::s8, s32)
                        && interim_f32);
        const uint64_t sveLen = get_sve_length();

        add_imm(X_TMP_0, XReg(x_ptr_in_off), i_off * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);

        if (unroll * itype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                ld1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                ldr(QReg {i}, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                ldr(DReg {i}, ptr(x_tmp_vec[i]));

        add_imm(X_TMP_0, X_TMP_3, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);

        if (unroll * itype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                ld1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                ldr(QReg {4 + i}, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                ldr(DReg {4 + i}, ptr(x_tmp_vec[i]));

        if (interim_f32) cvt2ps(0, unroll, prb_.itype);

#if 0
        /* Debug code */
        index(z0.s, 0, 1);
        mov(z0.s, P_NOT_256/T_m, 0);
        mov(z_tmp_vec[0].s, 16);
        for(uint32_t i=1; i<8; i++) {
          add(ZRegS{i}, ZRegS{i-1}, z_tmp_vec[0].s);
          mov(ZRegS{i}, P_NOT_256/T_m, 0);
        }
#endif

        ptrue(p_tmp0.s, VL4);
        /* 1st turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            trn1(z_tmp_vec[i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
            trn2(z_tmp_vec[unroll / 2 + i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
        }

        /* 2nd turn */
        trn1(z4.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
        trn1(z5.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
        trn2(z6.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
        trn2(z7.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
        trn1(z_tmp_vec[0].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
        trn1(z_tmp_vec[1].d, z_tmp_vec[6].d, z_tmp_vec[7].d);
        trn2(z_tmp_vec[2].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
        trn2(z_tmp_vec[3].d, z_tmp_vec[6].d, z_tmp_vec[7].d);

        /* 3rd turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            mov(ZRegD {i}, ZRegD {unroll / 2 + i});
            mov(z_tmp_vec[unroll / 2 + i].d, z_tmp_vec[i].d);
        }

        /* 4th turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            ZRegB z {unroll / 2 + i};
            ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b;
            /* Move bits 128-255 to bits 0-127. */
            ext(z, z, 16);
            /* Move bits 0-127 to bits 128-255. */
            ext(z_tmp, z_tmp, sveLen - 16);
        }

        /* 5th turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            ZRegS z0 {i};
            ZRegS z1 {unroll / 2 + i};
            sel(z0, p_tmp0.s, z0, z_tmp_vec[unroll / 2 + i].s);
            sel(z1, p_tmp0, z1, z_tmp_vec[i].s);
        }

        if (need_saturation) {
            init_saturate_f32(ymm_zero, ymm_saturation_ubound, reg_tmp,
                    interim_f32 ? f32 : prb_.itype, prb_.otype);
            for (int i = 0; i < unroll; i++)
                saturate_f32(ZRegS(i), ymm_zero, ymm_saturation_ubound,
                        prb_.otype, p_all);
        }

        if (prb_.otype != f32)
            cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype);

        add_imm(X_TMP_0, XReg(x_ptr_out_off), o_off * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);

        if (unroll * otype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                st1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                str(QReg {i}, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                str(DReg {i}, ptr(x_tmp_vec[i]));

        add_imm(X_TMP_0, X_TMP_3, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);

        if (unroll * otype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                st1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                str(QReg {4 + i}, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                str(DReg {4 + i}, ptr(x_tmp_vec[i]));
    }

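    /* The 8x8 transpose path applies when SVE vectors are at least 256 bits
     * wide, both innermost nodes have size 8 with unit innermost
     * input/output strides (i.e. a pure 8x8 transpose of 4-byte or smaller
     * types), and neither scaling nor beta accumulation is requested. */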
    bool can_do_tr8x8() {
        using namespace data_type;

        return get_sve_length() >= Xbyak_aarch64::util::SVE_256
                && prb_.ndims >= 2
                && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32)
                        && utils::one_of(
                                prb_.otype, u8, data_type::s8, s32, f32)))
                && utils::everyone_is(8, n(0), n(1))
                && utils::everyone_is(1, os(0), is(1))
                && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
    }

    bool process_unroll_tr8x8(int len) {
        if (!can_do_tr8x8()) return false;

        const int step_size = n(0) * n(1);
        int i_off = 0, o_off = 0;
        for (int off = 0; off < len; off += step_size) {
            step(off, i_off, o_off, i_off, o_off, step_size);
            tr8x8_sve256(i_off, o_off);
        }

        return true;
    }

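    /* Fast path for a plain contiguous copy (or a pure s32<->f32
     * conversion): both tensors must have unit innermost stride, no scaling
     * and beta == 0, so the body reduces to vector loads, an optional
     * conversion, and vector stores. */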
    template <cpu_isa_t isa>
    bool process_direct_copy(int len) {
        using namespace data_type;

        const int simd_w = cpu_isa_traits<isa>::vlen == 16
                ? cpu_isa_traits<isa>::vlen / itype_sz /* use 128-bit VReg */
                : cpu_isa_traits<isa>::vlen / itype_sz
                        / 2; /* use lower half of 512-bit ZReg */

        bool can_do = true && mayiuse(isa)
                && utils::everyone_is(1, os(0), is(0))
                && (false || prb_.itype == prb_.otype
                        || (prb_.itype == s32 && prb_.otype == f32)
                        || (prb_.itype == f32 && prb_.otype == s32))
                && len % simd_w == 0 && n(0) % len == 0
                && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
        if (!can_do) return false;

        for (int off = 0; off < len;) {
            const int unroll
                    = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w);

            int ur = 0;
            int tmp_ur = 0;
            while (ur < unroll) {
                int count = 0;
                const int vlen = cpu_isa_traits<isa>::vlen;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_in_off,
                            (off + ur * simd_w) * itype_sz, X_DEFAULT_ADDR);
                    ur++;
                } while (ur < unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {
                    /*                    if (vlen == 64)
                        ldr(ZReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                        else */
                    if (vlen == 64 || vlen == 32)
                        ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z,
                                ptr(x_tmp_vec[i]));
                    else if (vlen == 16)
                        ldr(QReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                    else
                        assert(!"unreachable");
                }
                tmp_ur += count;
            }

            if (prb_.itype != prb_.otype) {
                const int vlen = cpu_isa_traits<isa>::vlen;
                for (int ur = 0; ur < unroll; ++ur) {
                    if (prb_.itype == s32 && prb_.otype == f32) {
                        if (vlen == 64 || vlen == 32) {
                            ZRegS r(ur);
                            /* MSB side 256 bits are ignored. */
                            scvtf(r, p_all / T_m, r);
                        } else if (vlen == 16) {
                            VReg4S r(ur);
                            scvtf(r, r);
                        } else
                            assert(!"unreachable");
                    } else if (prb_.itype == f32 && prb_.otype == s32) {
                        /* Out-of-order execution can be expected. */
                        if (vlen == 64 || vlen == 32) {
                            ZRegS r(ur);
                            frinti(r, p_all / T_m, r);
                            fcvtzs(r, p_all / T_m, r);
                        } else if (vlen == 16) {
                            VReg4S r(ur);
                            frinti(r, r);
                            fcvtzs(r, r);
                        } else
                            assert(!"unreachable");
                    } else
                        assert(!"unreachable");
                }
            }

            ur = 0;
            tmp_ur = 0;
            while (ur < unroll) {
                int count = 0;
                const int vlen = cpu_isa_traits<isa>::vlen;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_out_off,
                            (off + ur * simd_w) * otype_sz, X_DEFAULT_ADDR);
                    ur++;
                } while (ur < unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {
                    if (vlen == 64 || vlen == 32)
                        st1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z,
                                ptr(x_tmp_vec[i]));
                    else if (vlen == 16)
                        str(QReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                    else
                        assert(!"unreachable");
                }
                tmp_ur += count;
            }

            off += unroll * simd_w;
        }

        return true;
    }

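    /* Generic (strided) body for one unrolled step: loads reg_unroll
     * elements (vectorized by 4 when the input offsets are contiguous),
     * converts through f32 if needed, applies common/per-element scales and
     * beta accumulation, saturates, converts to otype, and stores (again
     * vectorized by 4 when the output offsets are contiguous). */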
    void process_unroll_generic_step(int reg_unroll, const int *i_off,
            const int *o_off, const int *s_off) {
        using namespace data_type;

        auto cvt2ps
                = [=](const int startIdx, const int regNum, data_type_t idt) {
                      switch (idt) {
                          case f32:
                              /* do nothing */
                              break;
                          case s32: cvt_v_s32_f32(startIdx, regNum); break;
                          case data_type::s8:
                              cvt_v_s8_s32(startIdx, regNum);
                              cvt_v_s32_f32(startIdx, regNum);
                              break;
                          case u8:
                              cvt_v_u8_s32(startIdx, regNum);
                              cvt_v_s32_f32(startIdx, regNum);
                              break;
                          default: assert(!"unreachable");
                      }
                  };

        auto cvt2odt = [=](const int startIdx, const int regNum,
                               data_type_t odt, data_type_t idt) {
            switch (odt) {
                case s32:
                    if (idt == f32)
                        cvt_v_f32_s32(startIdx, regNum);
                    else if (idt == data_type::s8)
                        cvt_v_s8_s32(startIdx, regNum);
                    else if (idt == u8)
                        cvt_v_u8_s32(startIdx, regNum);
                    break;
                case data_type::s8:
                    if (idt == f32) cvt_v_f32_s32(startIdx, regNum);
                    if (idt == f32 || idt == s32)
                        cvt_v_s32_s8(startIdx, regNum);
                    if (idt == u8) { cvt_v_u8_s8(startIdx, regNum); }
                    break;
                case u8:
                    if (idt == f32) cvt_v_f32_s32(startIdx, regNum);
                    if (idt == f32 || idt == s32)
                        cvt_v_s32_u8(startIdx, regNum);
                    if (idt == data_type::s8) cvt_v_s8_u8(startIdx, regNum);
                    break;
                default: assert(!"unreachable");
            }
        };

        /* check whether loading 4 values at once is possible */
        bool can_load_xmm = reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (i_off[ur] != i_off[ur - 1] + 1) can_load_xmm = false;
        const int load_step = can_load_xmm ? 4 : 1;

        /* check whether storing 4 values at once is possible */
        bool can_store_xmm = reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false;
        const int ur_step = can_store_xmm ? 4 : 1;

        const bool interim_f32 = false
                || utils::one_of(f32, prb_.itype, prb_.otype)
                || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f;

        const bool need_saturation
                = (utils::one_of(prb_.otype, u8, data_type::s8, s32)
                        && interim_f32);

        if (!can_load_xmm && can_store_xmm) {
            assert(ur_step == 4);
            /* load with stride */
            for (int ur = 0; ur < reg_unroll; ur += ur_step) {

                /* x_tmp_vec = X_TMP_0 - X_TMP_4
                 Do not use X_TMP_? as the last arg. */
                for (int r = 0; r < ur_step; ++r) {
                    add_imm(x_tmp_vec[r], x_ptr_in_off,
                            i_off[ur + r] * itype_sz, X_DEFAULT_ADDR);
                }

                for (int r = 0; r < ur_step; ++r) {
                    if (itype_sz == 4)
                        ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
                    else if (itype_sz == 2)
                        ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
                    else
                        ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
                }
            }
        } else {
            int ur = 0;
            int tmp_ur = 0;
            while (ur < reg_unroll) {
                int count = 0;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_in_off,
                            i_off[ur] * itype_sz, X_DEFAULT_ADDR);
                    ur += load_step;
                } while (ur < reg_unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {

                    switch (load_step * itype_sz) {
                        case 16: ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        default: assert(!"unreachable");
                    }
                    tmp_ur += load_step;
                }
            }
        }

        /* xmm[:] <-- (f32)xmm[:] */
        if (interim_f32) {
            const int cvt_step = nstl::max(load_step, ur_step);
            for (int ur = 0; ur < reg_unroll; ur += cvt_step)
                cvt2ps(ur, 1, prb_.itype);
        }

        if (can_load_xmm && !can_store_xmm) {
            const bool fast_return = true // transposition on the fly
                    && prb_.scale_type != scale_type_t::MANY
                    && prb_.beta == 0.f;
            if (fast_return) {
                if (prb_.scale_type == scale_type_t::COMMON)
                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
                if (prb_.otype != f32) {
                    init_saturate_f32(xmm_zero, xmm_saturation_ubound, reg_tmp,
                            interim_f32 ? f32 : prb_.itype, prb_.otype);
                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        if (need_saturation)
                            saturate_f32(VReg4S(ur), xmm_zero,
                                    xmm_saturation_ubound, prb_.otype, p_all);

                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        cvt2odt(ur, 1, prb_.otype,
                                interim_f32 ? f32 : prb_.itype);
                }
                /* load_step is 1 or 4. */
                for (int ur = 0; ur < reg_unroll; ur += load_step) {
                    for (int r = 0; r < load_step; ++r) {
                        add_imm(x_tmp_vec[r], x_ptr_out_off,
                                o_off[ur + r] * otype_sz, X_DEFAULT_ADDR);
                    }

                    for (int r = 0; r < load_step; ++r) {
                        if (otype_sz == 4)
                            st1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
                        else if (otype_sz == 2)
                            st1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
                        else
                            st1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
                    }
                }
                return;
            }

            /* scatter elements of xmm into 4 xmms */
            if (itype_sz == 4 || interim_f32) {
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r) {
                        VReg4S v(ur);
                        VReg4S v_r(ur + r);
                        dup(VReg16B(ur + r), VReg16B(ur)[0]);
                        ins(VReg4S(ur + r)[0], VReg4S(ur)[r]);
                    }
            } else {
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r)
                        ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur),
                                itype_sz * r);
            }
        }

        /* scale and beta processing */
        if (can_store_xmm) {
            /* xmm <-- scale * xmm[:] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step)
                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
            } else if (prb_.scale_type == scale_type_t::MANY) {
                enum class scale_load_type_t { bcast, load, gather };

                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    scale_load_type_t scale_load_type
                            = scale_load_type_t::bcast; // the best case

                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 0)
                            scale_load_type = scale_load_type_t::load;

                    if (scale_load_type == scale_load_type_t::bcast) {
                        VReg4S v(xmm_scale.getIdx());
                        VReg4S v_dst(ur);
                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
                                X_DEFAULT_ADDR);
                        ldr(W_TMP_0, ptr(X_TMP_0));
                        dup(v, W_TMP_0);
                        fmul(v_dst, v_dst, v);
                        continue;
                    }

                    // bcast doesn't work, the next try -- load
                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 1)
                            scale_load_type = scale_load_type_t::gather;

                    if (scale_load_type == scale_load_type_t::load) {
                        uint32_t idx = xmm_scale.getIdx();
                        VReg4S v_dst(ur);
                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
                                X_DEFAULT_ADDR);

                        ldr(QReg {idx}, ptr(X_TMP_0));
                        fmul(v_dst, v_dst, VReg4S {idx});
                        continue;
                    }

                    // load doesn't work as well
                    // so gather the scale factors one by one
                    /* ur_step is 1 or 4. */
                    for (int r = ur; r < ur + ur_step; ++r) {
                        /* x_tmp_vec = X_TMP_0 - X_TMP_4
                         Do not use X_TMP_? as the last arg. */
                        add_imm(x_tmp_vec[r - ur], x_ptr_scale_off,
                                s_off[r] * stype_sz, X_DEFAULT_ADDR);
                    }
                    for (int r = ur; r < ur + ur_step; ++r) {
                        VReg4S v(xmm_scale.getIdx());
                        ld1(v[r - ur], ptr(x_tmp_vec[r - ur]));
                    }
                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
                }
            }

            /* dst <-- beta * dst + xmm[:] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                int ur = 0;
                int tmp_ur = 0;

                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_out_off,
                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < x_tmp_vec_size);

                    assert(count <= z_tmp_vec_size);
                    /* Firstly, data is loaded. */
                    for (int i = 0; i < count; i++) {

                        if (prb_.otype == f32 || prb_.otype == s32) {
                            ldr(QReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug
                        } else if (prb_.otype == data_type::s8
                                || prb_.otype == u8) {
                            ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug
                        } else
                            assert(!"unreachable");
                    }

                    /* Secondly, it is added. */
                    if (prb_.otype == f32) {
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    } else {
                        for (int i = 0; i < count; i++) {
                            /* cvt2ps() generates successive instructions
                               that have the same destination operand,
                               so out-of-order execution can be expected. */
                            cvt2ps(tmp_vec_idx[i], 1, prb_.otype);
                        }
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    }
                }
            }
        } else {
            /* xmm[0] <-- scale * xmm[0] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    VReg4S tmp(ur);
                    fmul(tmp, tmp, VReg4S(xmm_scale.getIdx()));
                }
            } else if (prb_.scale_type == scale_type_t::MANY) {
                int ur = 0;
                int tmp_ur = 0;
                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_scale_off,
                                s_off[ur] * stype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < x_tmp_vec_size);

                    for (int i = 0; i < count; i++)
                        ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                    for (int i = 0; i < count; i++) {
                        VReg4S tmp(tmp_ur + ur_step * i);
                        fmul(tmp, tmp, VReg4S(tmp_vec_idx[i]));
                    }

                    tmp_ur += ur_step * count;
                }
            }

            /* dst <-- beta * dst + xmm[0] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                int ur = 0;
                int tmp_ur = 0;
                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_out_off,
                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < (x_tmp_vec_size / 2));

                    assert(static_cast<size_t>(count) <= z_tmp_vec.size());

                    if (prb_.otype == f32) {
                        /* addss: dest[31:0] <- src1[31:0] + src2[31:0]
                         dest[MAXVL-1:32] (Unmodified) */
                        for (int i = 0; i < count; i++) {
                            ld1(VReg4S(z_tmp_vec[i].getIdx())[0],
                                    ptr(x_tmp_vec[i]));
                        }
                        for (int i = 0; i < count; i++) {
                            SReg s {tmp_vec_idx[i]};
                            fadd(s, s, SReg(tmp_ur + ur_step * i));
                        }
                        for (int i = 0; i < count; i++) {
                            mov(VReg4S(tmp_ur)[0], VReg4S(tmp_vec_idx[i])[0]);
                            tmp_ur += ur_step;
                        }
                    } else {
                        for (int i = 0; i < count; i++) {
                            if (prb_.otype == s32) {
                                ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                            } else if (utils::one_of(
                                               prb_.otype, data_type::s8, u8)) {
                                ldr(BReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                            } else {
                                assert(!"unsupported o_type");
                            }
                            cvt2ps(tmp_vec_idx[i], 1, prb_.otype);
                        }
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    }
                }
            }
        }

        if (need_saturation) {
            init_saturate_f32(
                    xmm_zero, xmm_saturation_ubound, reg_tmp, f32, prb_.otype);
            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                saturate_f32(VReg4S(ur), xmm_zero, xmm_saturation_ubound,
                        prb_.otype, p_all);
            }
        }

        for (int ur = 0; ur < reg_unroll; ur += ur_step) {
            if (prb_.otype != f32)
                cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype);
        }

        int ur = 0;
        int tmp_ur = 0;
        while (ur < reg_unroll) {
            int count = 0;

            do {
                add_imm(x_tmp_vec[count++], x_ptr_out_off, o_off[ur] * otype_sz,
                        X_DEFAULT_ADDR);
                ur += ur_step;
            } while (ur < reg_unroll && count < x_tmp_vec_size);

            for (int i = 0; i < count; i++) {

                switch (ur_step * otype_sz) {
                    case 16: str(QReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 8: str(DReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 4: str(SReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 2: str(HReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 1: str(BReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    default: assert(!"unreachable");
                }
                tmp_ur += ur_step;
            }
        }
    }

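    /* Walks the unrolled length in blocks of 8 elements. The offset arrays
     * are double-buffered (curr flips between 0 and 1) so that the first
     * entry of a new block can be derived from the last entry of the
     * previous block via step(). */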
    void process_unroll_generic(int len) {
        const int blk = 8;

        int i_off[2 * blk] = {0};
        int o_off[2 * blk] = {0};
        int s_off[2 * blk] = {0};

        int curr = 0; // will switch between 0 and 1

        for (int off = 0; off < len; off += blk) {
            const int reg_unroll = nstl::min(off + blk, len) - off;

            /* compute offsets */
            for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
                const int ur_c = curr * blk + ur;
                const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
                step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p],
                        i_off[ur_c], o_off[ur_c], s_off[ur_c]);
            }

            process_unroll_generic_step(reg_unroll, i_off + curr * blk,
                    o_off + curr * blk, s_off + curr * blk);

            curr = 1 - curr;
        }
    }

    void loop_begin(Label &l, XReg reg_cnt, int len) {
        mov(reg_cnt, len);
        L(l);
    }

    void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step,
            int s_step) {
        add_imm(reg_off_in, reg_off_in, i_step * itype_sz, X_TMP_0);
        add_imm(reg_off_out, reg_off_out, o_step * otype_sz, X_TMP_0);
        add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz, X_TMP_0);
        add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz, X_TMP_0);

        if (prb_.scale_type == scale_type_t::MANY) {
            add_imm(reg_off_scale, reg_off_scale, s_step * stype_sz, X_TMP_0);
            add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz,
                    X_TMP_0);
        }
        subs(reg_cnt, reg_cnt, 1);
        b(NE, l);

        sub_imm(reg_off_in, reg_off_in, len * i_step * itype_sz, X_TMP_0);
        sub_imm(reg_off_out, reg_off_out, len * o_step * otype_sz, X_TMP_0);
        sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz, X_TMP_0);
        sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz, X_TMP_0);

        if (prb_.scale_type == scale_type_t::MANY) {
            sub_imm(reg_off_scale, reg_off_scale, len * s_step * stype_sz,
                    X_TMP_0);
            sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz,
                    X_TMP_0);
        }
    }

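    /* Emits up to ndims_jit_loop_max nested runtime loops around a fully
     * unrolled body chosen by simple_impl_desc_init(); the innermost jit
     * loop iterates n(nfu) / len_last_dim_unroll times. */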
    bool simple_impl() {
        simple_impl_desc_t d;
        if (!simple_impl_desc_init(prb_, &d)) return false;

        const int nfu = d.ndims_full_unroll;
        const int ldu = d.len_last_dim_unroll;
        const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
        assert(n_jit_loops <= ndims_jit_loop_max);

        eor(reg_off_in, reg_off_in, reg_off_in);
        eor(reg_off_out, reg_off_out, reg_off_out);
        mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
        mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
        if (prb_.scale_type == scale_type_t::MANY) {
            eor(reg_off_scale, reg_off_scale, reg_off_scale);
            mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));
        }

        Label l_loop[3];
        XReg reg_cnt[3] = {x15, x14, x13};

        if (n_jit_loops > 2) loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));

        if (n_jit_loops > 1) loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));

        if (n_jit_loops > 0)
            loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);

        bool optimized = false;
        optimized = optimized || process_direct_copy<sve_512>(d.len_unroll);
        optimized = optimized || process_direct_copy<asimd>(d.len_unroll);
        optimized = optimized || process_unroll_tr8x8(d.len_unroll);
        if (!optimized) process_unroll_generic(d.len_unroll);

        if (n_jit_loops > 0)
            loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu,
                    os(nfu + 0) * ldu, ss(nfu + 0) * ldu);

        if (n_jit_loops > 1)
            loop_end(l_loop[1], reg_cnt[1], n(nfu + 1), is(nfu + 1),
                    os(nfu + 1), ss(nfu + 1));

        if (n_jit_loops > 2)
            loop_end(l_loop[2], reg_cnt[2], n(nfu + 2), is(nfu + 2),
                    os(nfu + 2), ss(nfu + 2));

        return true;
    }

    void impl() {
        if (simple_impl()) return;
        assert(!"no implementation available");
    }

#define UNROLL_INST(inst, reg, ...) \
    for (size_t i = startIdx; i < startIdx + regNum; i++) { \
        reg tmp(i); \
        inst(__VA_ARGS__); \
    }
#define UNROLL_INST2(inst, ...) \
    for (size_t i = startIdx; i < startIdx + regNum; i++) \
        inst(__VA_ARGS__);

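    /* The helpers below convert regNum consecutive vector registers starting
     * at startIdx in place. For example, UNROLL_INST(scvtf, ZRegS, tmp,
     * p_all / T_m, tmp) emits scvtf(ZRegS(i), p_all / T_m, ZRegS(i)) for
     * each i in [startIdx, startIdx + regNum). */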
cvt_z_s32_f32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1051     void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) {
1052         UNROLL_INST(scvtf, ZRegS, tmp, p_all / T_m, tmp);
1053     }
1054 
cvt_v_s32_f32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1055     void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) {
1056         UNROLL_INST(scvtf, VReg4S, tmp, tmp);
1057     }
1058 
cvt_z_f32_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1059     void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) {
1060         UNROLL_INST(frinti, ZRegS, tmp, p_all / T_m, tmp);
1061         UNROLL_INST(fcvtzs, ZRegS, tmp, p_all / T_m, tmp);
1062     }
1063 
cvt_v_f32_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1064     void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) {
1065         UNROLL_INST(frinti, VReg4S, tmp, tmp);
1066         UNROLL_INST(fcvtzs, VReg4S, tmp, tmp);
1067     }
1068 
cvt_z_s8_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1069     void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) {
1070         cvt_z_b_s(startIdx, regNum);
1071         UNROLL_INST(sxtb, ZRegS, tmp, p_all / T_m, tmp);
1072     }
1073 
cvt_v_s8_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1074     void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) {
1075         UNROLL_INST(sxtl, VReg, tmp.h8, tmp.b8);
1076         UNROLL_INST(sxtl, VReg, tmp.s4, tmp.h4);
1077     }
1078 
cvt_z_s8_f32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1079     void cvt_z_s8_f32(const size_t startIdx, const size_t regNum) {
1080         cvt_z_b_s(startIdx, regNum);
1081         cvt_z_s32_f32(startIdx, regNum);
1082     }
1083 
cvt_v_s8_f32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1084     void cvt_v_s8_f32(const size_t startIdx, const size_t regNum) {
1085         cvt_v_b_s(startIdx, regNum);
1086         cvt_v_s32_f32(startIdx, regNum);
1087     }
1088 
cvt_z_b_sdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1089     void cvt_z_b_s(const size_t startIdx, const size_t regNum) {
1090         assert(z_tmp7.getIdx() < startIdx
1091                 || startIdx + regNum - 1 < z_tmp7.getIdx());
1092 
1093         dup(z_tmp7.b, 0);
1094         UNROLL_INST(zip1, ZRegB, tmp, tmp, z_tmp7.b);
1095         UNROLL_INST(zip1, ZRegH, tmp, tmp, z_tmp7.h);
1096     }
1097 
cvt_v_b_sdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1098     void cvt_v_b_s(const size_t startIdx, const size_t regNum) {
1099         assert(v_tmp7.getIdx() < startIdx
1100                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1101 
1102         mov_imm(W_TMP_0, 0);
1103         dup(v_tmp7.b16, W_TMP_0);
1104         UNROLL_INST(zip1, VReg16B, tmp, tmp, v_tmp7.b16);
1105         UNROLL_INST(zip1, VReg8H, tmp, tmp, v_tmp7.h8);
1106     }
1107 
cvt_z_u8_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1108     void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) {
1109         cvt_z_b_s(startIdx, regNum);
1110         UNROLL_INST(uxtb, ZRegS, tmp, p_all / T_m, tmp);
1111     }
1112 
cvt_v_u8_s32dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1113     void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) {
1114         UNROLL_INST(uxtl, VReg, tmp.h8, tmp.b8);
1115         UNROLL_INST(uxtl, VReg, tmp.s4, tmp.h4);
1116     }
1117 
cvt_z_s32_s8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1118     void cvt_z_s32_s8(const size_t startIdx, const size_t regNum) {
1119         assert(z_tmp7.getIdx() < startIdx
1120                 || startIdx + regNum - 1 < z_tmp7.getIdx());
1121 
1122         dup(z_tmp7.s, 0);
1123         UNROLL_INST2(smin, ZRegS(i), 127);
1124         UNROLL_INST2(smax, ZRegS(i), -128);
1125         UNROLL_INST(uzp1, ZRegH, tmp, tmp, z_tmp7.h);
1126         UNROLL_INST(uzp1, ZRegB, tmp, tmp, z_tmp7.b);
1127     }
1128 
cvt_v_s32_s8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1129     void cvt_v_s32_s8(const size_t startIdx, const size_t regNum) {
1130         assert(v_tmp7.getIdx() < startIdx
1131                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1132 
1133         mov_imm(W_TMP_0, 127);
1134         dup(v_tmp7.s4, W_TMP_0);
1135         mov_imm(W_TMP_0, -128);
1136         UNROLL_INST2(smin, VReg4S(i), VReg4S(i), v_tmp7.s4);
1137         dup(v_tmp7.s4, W_TMP_0);
1138         UNROLL_INST2(smax, VReg4S(i), VReg4S(i), v_tmp7.s4);
1139         mov_imm(W_TMP_0, 0);
1140         dup(v_tmp7.s4, W_TMP_0);
1141         UNROLL_INST(uzp1, VReg8H, tmp, tmp, v_tmp7.h8);
1142         UNROLL_INST(uzp1, VReg16B, tmp, tmp, v_tmp7.b16);
1143     }
1144 
cvt_z_u8_s8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1145     void cvt_z_u8_s8(const size_t startIdx, const size_t regNum) {
1146         UNROLL_INST2(umin, ZRegB(i), 127);
1147     }
1148 
cvt_v_u8_s8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1149     void cvt_v_u8_s8(const size_t startIdx, const size_t regNum) {
1150         assert(v_tmp7.getIdx() < startIdx
1151                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1152 
1153         mov_imm(W_TMP_0, 127);
1154         dup(v_tmp7.b16, W_TMP_0);
1155         UNROLL_INST(umin, VReg16B, tmp, tmp, v_tmp7.b16);
1156     }
1157 
cvt_z_u32_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1158     void cvt_z_u32_u8(const size_t startIdx, const size_t regNum) {
1159         UNROLL_INST2(umin, ZRegS(i), 255);
1160         UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp);
1161         UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp);
1162     }
1163 
cvt_v_u32_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1164     void cvt_v_u32_u8(const size_t startIdx, const size_t regNum) {
1165         assert(v_tmp7.getIdx() < startIdx
1166                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1167 
1168         mov_imm(W_TMP_0, 255);
1169         dup(v_tmp7.s4, W_TMP_0);
1170         UNROLL_INST(umin, VReg4S, tmp, tmp, v_tmp7.s4);
1171         UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp);
1172         UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp);
1173     }
1174 
cvt_z_s32_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1175     void cvt_z_s32_u8(const size_t startIdx, const size_t regNum) {
1176         assert(z_tmp7.getIdx() < startIdx
1177                 || startIdx + regNum - 1 < z_tmp7.getIdx());
1178 
1179         dupm(z_tmp7.s, 255);
1180         UNROLL_INST2(smax, ZRegS(i), 0);
1181         UNROLL_INST2(smin, ZRegS(i), p_all / T_m, z_tmp7.s);
1182         UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp);
1183         UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp);
1184         UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0);
1185     }
1186 
cvt_v_s32_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1187     void cvt_v_s32_u8(const size_t startIdx, const size_t regNum) {
1188         assert(v_tmp7.getIdx() < startIdx
1189                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1190 
1191         mov_imm(W_TMP_0, 0);
1192         dup(v_tmp7.s4, W_TMP_0);
1193         mov_imm(W_TMP_0, 255);
1194         UNROLL_INST(smax, VReg4S, tmp, tmp, v_tmp7.s4);
1195         dup(v_tmp7.s4, W_TMP_0);
1196         UNROLL_INST(smin, VReg4S, tmp, tmp, v_tmp7.s4);
1197         UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp);
1198         UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp);
1199     }
1200 
cvt_z_s8_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1201     void cvt_z_s8_u8(const size_t startIdx, const size_t regNum) {
1202         UNROLL_INST2(smax, ZRegB(i), 0);
1203     }
1204 
cvt_v_s8_u8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1205     void cvt_v_s8_u8(const size_t startIdx, const size_t regNum) {
1206         assert(v_tmp7.getIdx() < startIdx
1207                 || startIdx + regNum - 1 < v_tmp7.getIdx());
1208 
1209         mov_imm(W_TMP_0, 0);
1210         dup(v_tmp7.b16, W_TMP_0);
1211         UNROLL_INST(smax, VReg16B, tmp, tmp, v_tmp7.b16);
1212     }
1213 #undef UNROLL_INST
1214 #undef UNROLL_INST
1215 
jit_uni_reorder_kernel_f32_tdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1216     jit_uni_reorder_kernel_f32_t(const desc_t &desc) : kernel_t(desc) {
1217         itype_sz = data_type_size(prb_.itype);
1218         otype_sz = data_type_size(prb_.otype);
1219         stype_sz = sizeof(float);
1220     }
1221 
generatednnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1222     void generate() override {
1223         using namespace Xbyak_aarch64::util;
1224         uint64_t sveLen = get_sve_length();
1225 
1226         preamble();
1227 #define PARAM(x) offsetof(call_param_t, x)
1228         if (prb_.scale_type == scale_type_t::COMMON) {
1229             add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_1);
1230             ldr(X_TMP_0, ptr(X_DEFAULT_ADDR));
1231             ldr(W_TMP_1, ptr(X_TMP_0));
1232             dup(xmm_scale, W_TMP_1);
1233         } else if (prb_.scale_type == scale_type_t::MANY) {
1234             add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_0);
1235             ldr(reg_ptr_scale, ptr(X_DEFAULT_ADDR));
1236         }
1237         add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2);
1238         add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2);
1239         ldr(reg_ptr_in, ptr(X_TMP_0));
1240         ldr(reg_ptr_out, ptr(X_TMP_1));
1241 #undef PARAM
1242 
1243         mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
1244         mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
1245         mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));
1246 
1247         if (sveLen) { /* SVE is available. */
1248             ptrue(p_lsb_256.b, VL32);
1249             ptrue(p_all.b);
1250         }
1251 
        if (can_do_tr8x8()) {
            dup(ymm_zero, 0);

            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
                mov_imm(reg_tmp, 0x7f7f7f7f7f7f7f7f);
                mov(VReg4S(ymm_8x127b.getIdx())[0], WReg(reg_tmp.getIdx()));
            }
        } else if (mayiuse(sve_512)) {
            movi(xmm_zero, 0);

            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
                mov(WReg(reg_tmp.getIdx()), 0x7f7f7f7f);
                mov(xmm_4x127b[0], WReg(reg_tmp.getIdx()));
            }
        }

        impl();
        postamble();
    }

private:
    int itype_sz;
    int otype_sz;
    int stype_sz;

    XReg reg_ptr_in = x6;
    XReg reg_ptr_out = x2;
    XReg reg_ptr_scale = abi_not_param1;

    XReg reg_off_in = x8;
    XReg reg_off_out = x9;
    XReg reg_off_scale = x10;

    XReg reg_tmp = x0;

    VReg4S xmm_scale = v15.s;
    VReg4S xmm_zero = v14.s;
    VReg4S xmm_4x127b = v13.s; // TODO: unite with ymm_zero
    ZRegS ymm_zero = z14.s;
    ZRegS ymm_8x127b = z13.s;
    VReg4S xmm_tmp = v12.s;
    VReg4S xmm_saturation_ubound = v12.s;
    ZRegS ymm_saturation_ubound = z12.s;

    /* Note: x22 - x28 are already used as temporary registers
       in jit_generator.hpp.
       x_ptr_(in|out|scale)_off keeps the (base + offset) address. */
    XReg x_ptr_in_off = x16;
    XReg x_ptr_out_off = x18;
    XReg x_ptr_scale_off = x20;

    /* Caution: choose predicate registers that are not used by the x64
       implementation. */
    PReg p_lsb_256 = p7;
    PReg p_all = p6;
    PReg p_tmp0 = p5;

    const std::vector<uint32_t> tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27};
    ZReg z_tmp0 = z20;
    ZReg z_tmp1 = z21;
    ZReg z_tmp2 = z22;
    ZReg z_tmp3 = z23;
    ZReg z_tmp4 = z24;
    ZReg z_tmp5 = z25;
    ZReg z_tmp6 = z26;
    ZReg z_tmp7 = z27;
    VReg v_tmp7 = v27;

    const std::vector<ZReg> z_tmp_vec
            = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7};
    constexpr static int z_tmp_vec_size = 8;
};

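/* Initializes a kernel descriptor for (at most) the ndims_ker_max innermost
 * dimensions of prb. If ndims_ker_max is not given (<= 0), it is derived so
 * that the kernel covers at least ker_prb_size_min elements; the loop below
 * then shrinks the dimension count until an applicable kernel is found. */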
status_t kernel_t::desc_init(
        kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) {
    desc.prb = prb;
    desc.prb.ioff = desc.prb.ooff = 0;

    if (ndims_ker_max > prb.ndims) return status::invalid_arguments;

    auto ndims_ker_max_f = [&]() {
        size_t cur_size = 1;
        for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
            if (cur_size >= ker_prb_size_min) return d;
        return prb.ndims;
    };

    if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f();

    /* traverse through kernel implementations */
    /* TODO: find a better way to do that... */
    desc.id = 0;
    for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
        desc.prb.ndims = ndims_ker;
        if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb))
            return status::success;
    }

    return status::unimplemented;
}

kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
    switch (desc.id) {
        case 0: return new jit_uni_reorder_kernel_f32_t(desc);
        default: assert(!"unknown kernel id"); return nullptr;
    }

    return nullptr;
}
} // namespace tr

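/* Reshapes the problem so that memory is traversed in a cache-friendly order:
 * prefer keeping sequential reads intact and, when strides are large and
 * 64-element aligned, split 16-element blocks off the offending nodes. */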
static void prb_block_for_cache(tr::prb_t &prb) {
    /* If the strides of the 0th and 1st nodes are already cache friendly,
     * no blocking is needed at all. */
    const bool cache_blocking_needed = false
            || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16)
            || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0
                    && prb.nodes[1].n > 16);
    if (!cache_blocking_needed) return;

    int unit_input_stride_idx = -1;
    for (auto idx = 0; idx < prb.ndims; ++idx) {
        if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx;
    }

    /* Re-prioritize the sequential read over the sequential write:
     *                             /-> [n0:is0:1][16n1:1:osk]...
     * [n0:is0:1]...[nk:1:osk] -->     or
     *                             \-> [16n1:1:osk][n0:is0:1]... */
    if (unit_input_stride_idx != -1) {
        const auto output_stride = prb.nodes[unit_input_stride_idx].os;
        const auto num_elems = prb.nodes[unit_input_stride_idx].n;

        const bool split_needed = (num_elems > 16) && (num_elems % 16 == 0);
        const int move_location = (output_stride % 4 != 0) ? 0 : 1;
        if (split_needed) prb_node_split(prb, unit_input_stride_idx, 16);

        /* Because of the cache-unfriendly nature of the unit-output-stride
         * node, move the unit-input-stride node to (or near) the front. */
        prb_node_move(prb, unit_input_stride_idx, move_location);
    }

    /* Potentially, split the node with os=1 in two and pull the node with
     * is=1 in between them for better cache reuse:
     * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */
    if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) {
        const auto input_stride = prb.nodes[0].is;
        const auto num_elems = prb.nodes[0].n;

        const bool split_needed = true && (num_elems > 16)
                && (num_elems % 16 == 0) && (input_stride >= 256)
                && (input_stride % 64 == 0);
        if (split_needed) {
            prb_node_split(prb, 0, 16);
            prb_node_move(prb, 1, 2);
        }
    }
}

/** Finds the maximum number of dimensions the kernel should process and
 * optionally splits one of the dimensions to achieve a better balance between
 * the parallel driver and the kernel. */
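/* Illustrative trace (hypothetical sizes): for nodes of sizes {8, 8, 4, 64}
 * and nthr = 16, sz_total = 16384 and sz_drv_min = min(16 * 16, 16384 / 1024)
 * = 16. Walking from the innermost driver side, the 64-element node alone
 * gives sz_drv_cur = 64 >= sz_drv_min, so kdims = 3 and sz_ker_cur
 * = 8 * 8 * 4 = 256 >= tr::ker_prb_size_min; neither borrowing branch fires
 * and ndims_ker_max becomes 3. */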
static void prb_thread_kernel_balance(
        tr::prb_t &prb, int &ndims_ker_max, int nthr) {
    size_t sz_total = 1;
    for (int d = 0; d < prb.ndims; ++d)
        sz_total *= prb.nodes[d].n;

    /* sz_drv_min is the minimal size of the parallel
     * driver required for good parallelization */
    const size_t sz_drv_min
            = nstl::min<size_t>(16 * nthr, utils::div_up(sz_total, 1024));

    /* kdims -- # of dimensions processed by the kernel
     * sz_ker_cur -- product of the dimensions processed by the kernel
     * sz_drv_cur -- product of the dimensions processed by the driver */

    int kdims = prb.ndims;
    size_t sz_drv_cur = 1;
    for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
        sz_drv_cur *= prb.nodes[kdims - 1].n;

    size_t sz_ker_cur = 1;
    for (int d = 0; d < kdims; ++d)
        sz_ker_cur *= prb.nodes[d].n;

    /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
     *
     * It might happen that for the chosen kdims sz_ker_cur is too small
     * (less than tr::ker_prb_size_min). In that case try to split the
     * innermost driver dimension into two, to increase sz_ker_cur. */
    bool want_borrow_ker_from_drv = true && kdims < prb.ndims
            && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min;
    if (want_borrow_ker_from_drv) {
        /* sz_want_borrow is the minimal size such that:
         *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
         *  o) the current innermost driver dimension is divisible by
         *     sz_want_borrow (so that we can evenly split that
         *     dimension into two)
         *
         *  In the worst case the minimal sz_want_borrow is equal
         *  to the innermost driver dimension itself. In that case
         *  we will sacrifice it in favor of the kernel (is it fine?). */
        size_t sz_want_borrow = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
        for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow)
            ;
        if (sz_want_borrow != prb.nodes[kdims].n)
            prb_node_split(prb, kdims, sz_want_borrow);
        kdims += 1;
    }

    /* On the other hand it might happen that for the chosen kdims
     * sz_drv_cur is too small (less than sz_drv_min). In that case
     * try to split the outermost kernel dimension into two, to increase
     * sz_drv_cur. */
    bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min
            && sz_drv_cur < sz_drv_min;
    if (want_borrow_drv_from_ker) {
        size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
        for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow)
            ;
        if (sz_want_borrow != prb.nodes[kdims - 1].n)
            prb_node_split(
                    prb, kdims - 1, prb.nodes[kdims - 1].n / sz_want_borrow);
    }

    ndims_ker_max = kdims;

    if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
        DEBUG({
            printf("split: ");
            prb_dump(prb);
            printf("ndims_ker_max = %d\n", ndims_ker_max);
        });
    }
}

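/* Builds the reorder primitive descriptor: initialize the problem from the
 * memory descriptors, normalize and simplify it, apply cache blocking,
 * balance work between the threading driver and the kernel, and finally
 * create the kernel descriptor. */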
status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
        engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
        const memory_desc_t *src_md, engine_t *dst_engine,
        const memory_desc_t *dst_md) {
    auto prb = tr::prb_t();

    status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
    if (prb_init_status != status::success) return prb_init_status;

    DEBUG({
        printf("init : ");
        prb_dump(prb);
    });
    // Sort the prb array in increasing sizes of the output stride
    prb_normalize(prb);
    DEBUG({
        printf("norm : ");
        prb_dump(prb);
    });
    /* Combine the variables, which appear together on both
     * sides of the reorder */
    prb_simplify(prb);
    DEBUG({
        printf("smpl : ");
        prb_dump(prb);
    });

    prb_block_for_cache(prb);
    DEBUG({
        printf("cache: ");
        prb_dump(prb);
    });

    int ndims_ker_max;
    int nthr = dnnl_get_max_threads();
    prb_thread_kernel_balance(prb, ndims_ker_max, nthr);

    tr::kernel_t::desc_t ker_desc;
    status_t ker_init_status
            = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
    if (ker_init_status != status::success) return ker_init_status;

    const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
    if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
        return status::unimplemented;

    DEBUG({
        printf("ker  : ");
        prb_dump(ker_desc.prb);
    });

    auto _pd = new pd_t(
            attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
    if (_pd == nullptr) return status::out_of_memory;
    if (_pd->init(engine, src_engine, dst_engine) != status::success) {
        delete _pd;
        return status::unimplemented;
    }
    _pd->prb_ = prb;
    _pd->ker_desc_ = ker_desc;
    _pd->init_scratchpad_md();
    _pd->nthr_ = nthr;
    return safe_ptr_assign(*reorder_pd, _pd);
}

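/* The omp_driver_<N>d helpers parallelize over the outermost (driver)
 * dimensions and invoke the JIT kernel for each innermost block. 'off' is
 * the number of kernel dimensions, so pd()->prb_.nodes + off points at the
 * first driver node. */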
void jit_uni_reorder_t::omp_driver_0d(
        int off, const char *in, char *out, const float *scale) const {
    tr::call_param_t c {in, out, scale};
    (*kernel_)(&c);
}

void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off,
        const char *in, char *out, const float *scale) const {
    const tr::node_t *ns = pd()->prb_.nodes + off;
    for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
        auto c = tr::call_param_t();
        c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
        c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
        c.scale = scale + d0 * ns[0].ss;
        (*kernel_)(&c);
    });
}

void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off,
        const char *in, char *out, const float *scale) const {
    const tr::node_t *ns = pd()->prb_.nodes + off;
    for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
            [&](ptrdiff_t d1, ptrdiff_t d0) {
                auto c = tr::call_param_t();
                c.in = in
                        + (d0 * ns[0].is + d1 * ns[1].is)
                                * data_type_size(pd()->prb_.itype);
                c.out = out
                        + (d0 * ns[0].os + d1 * ns[1].os)
                                * data_type_size(pd()->prb_.otype);
                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
                (*kernel_)(&c);
            });
}

void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off,
        const char *in, char *out, const float *scale) const {
    const tr::node_t *ns = pd()->prb_.nodes + off;
    for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
            (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
                auto c = tr::call_param_t();
                c.in = in
                        + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
                                * data_type_size(pd()->prb_.itype);
                c.out = out
                        + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
                                * data_type_size(pd()->prb_.otype);
                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
                (*kernel_)(&c);
            });
}

void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off,
        const char *in, char *out, const float *scale) const {
    const tr::node_t *ns = pd()->prb_.nodes + off;
    for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
            (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
            [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
                auto c = tr::call_param_t();
                c.in = in
                        + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
                                  + d3 * ns[3].is)
                                * data_type_size(pd()->prb_.itype);
                c.out = out
                        + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
                                  + d3 * ns[3].os)
                                * data_type_size(pd()->prb_.otype);
                c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
                        + d3 * ns[3].ss;
                (*kernel_)(&c);
            });
}

void jit_uni_reorder_t::omp_driver(
        const char *in, char *out, const float *scale) const {
    in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
    out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);

    DEBUG({
        printf("prb : ");
        tr::prb_dump(pd()->prb_);
    });
    DEBUG({
        printf("ker : ");
        tr::prb_dump(pd()->ker_desc_.prb);
    });

    int ndims = pd()->prb_.ndims;
    int ndims_ker = pd()->ker_desc_.prb.ndims;
    assert(ndims - ndims_ker <= ndims_driver_max);

    if (ndims - ndims_ker == 0) {
        omp_driver_0d(ndims_ker, in, out, scale);
    } else {
        parallel(pd()->nthr_, [&](const int ithr, const int nthr) {
            switch (ndims - ndims_ker) {
                case 1:
                    omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale);
                    break;
                case 2:
                    omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale);
                    break;
                case 3:
                    omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale);
                    break;
                case 4:
                    omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale);
                    break;
                default: assert(!"unimplemented");
            }
        });
    }
}

status_t jit_uni_reorder_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_)));
    return kernel_->create_kernel();
}

status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
    auto out = CTX_OUT_CLEAN_MEM(char *, DNNL_ARG_TO, status);
    CHECK(status);
    DEFINE_SCALES_BUFFER(scales);

    omp_driver(in, out, scales);

    return status::success;
}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl