1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 * Copyright 2020-2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include <assert.h>
19 #include <numeric>
20
21 #include "dnnl_debug.h"
22
23 #include "common/c_types_map.hpp"
24 #include "common/memory_desc_wrapper.hpp"
25 #include "common/nstl.hpp"
26 #include "common/primitive.hpp"
27 #include "common/type_helpers.hpp"
28 #include "common/utils.hpp"
29
30 #include "cpu/aarch64/jit_uni_reorder.hpp"
31 #include "cpu/cpu_primitive.hpp"
32 #include "cpu/reorder/cpu_reorder_pd.hpp"
33
34 #include "cpu/aarch64/jit_generator.hpp"
35
36 // #define TR_DEBUG
37 #if defined(TR_DEBUG)
38 #define DEBUg(...) \
39 do { \
40 __VA_ARGS__ \
41 } while (0)
42 #else
43 #define DEBUg(...)
44 #endif
45 #define DEBUG(...) DEBUg(__VA_ARGS__)
46
47 using namespace Xbyak_aarch64;
48 using namespace dnnl::impl::types;
49
50 namespace dnnl {
51 namespace impl {
52 namespace cpu {
53 namespace aarch64 {
54
55 namespace tr {
56
/** Minimal reasonable/desirable kernel size.
 * The constant might be used to determine how a problem should be split
 * between kernel and threading driver.
 * NOTE(review): presumably declared `extern` in the accompanying header for
 * use by the driver — verify before changing linkage (e.g. to constexpr). */
const size_t ker_prb_size_min = 64;
61
62 /* kernel */
63 struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONSdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t64 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
65
    // Execute the generated kernel for one call: forwards the call-parameter
    // block to the JIT-ed code buffer managed by jit_generator.
    void operator()(const call_param_t *c) const override {
        jit_generator::operator()(c);
    }
69
create_kerneldnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t70 status_t create_kernel() override { return jit_generator::create_kernel(); }
71
    enum {
        // Maximum number of elements handled by the fully unrolled body.
        len_unroll_max = 256,
        // Maximum number of outer dimensions driven by emitted JIT loops.
        ndims_jit_loop_max = 3,
    };

    // Describes how a problem is split between full unrolling of the
    // innermost dimensions, partial unrolling of one dimension, and the
    // remaining outer JIT loops.
    struct simple_impl_desc_t {
        int ndims_full_unroll; // number of innermost dims fully unrolled
        int len_last_dim_unroll; // unroll factor of the first non-full dim
        int len_unroll; // total elements processed per unrolled body
    };
82
simple_impl_desc_initdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t83 static bool simple_impl_desc_init(
84 const prb_t &prb, simple_impl_desc_t *desc) {
85 const int ndims = prb.ndims;
86
87 int ndims_full_unroll = 0;
88 int len_last_dim_unroll = 1;
89 int len_unroll = 1;
90
91 for (int d = 0; d < ndims; ++d) {
92 auto &node = prb.nodes[d];
93 if (len_unroll * node.n <= len_unroll_max) {
94 ndims_full_unroll++;
95 len_unroll *= node.n;
96 } else {
97 len_last_dim_unroll = len_unroll_max / len_unroll;
98 while (node.n % len_last_dim_unroll)
99 --len_last_dim_unroll;
100 len_unroll *= len_last_dim_unroll;
101 break;
102 }
103 }
104
105 if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false;
106
107 if (desc) {
108 desc->ndims_full_unroll = ndims_full_unroll;
109 desc->len_last_dim_unroll = len_last_dim_unroll;
110 desc->len_unroll = len_unroll;
111 }
112
113 return true;
114 }
115
applicablednnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t116 static bool applicable(const prb_t &p) {
117 using namespace data_type;
118
119 bool ok = true && p.ndims > 0
120 && utils::one_of(p.itype, f32, s32, data_type::s8, u8)
121 && utils::one_of(p.otype, f32, s32, data_type::s8, u8)
122 && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
123 && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
124 && simple_impl_desc_init(p, nullptr);
125 if (!ok) return false;
126
127 const ptrdiff_t max_stride = (1LL << 31) - 1;
128 for (int d = 0; d < p.ndims; ++d) {
129 const ptrdiff_t cms = max_stride / p.nodes[d].n;
130 bool strides_ok = true
131 && p.nodes[d].is < cms / (int)data_type_size(p.itype)
132 && p.nodes[d].os < cms / (int)data_type_size(p.otype);
133 if (!strides_ok) return false;
134 }
135
136 return true;
137 }
138
ndnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t139 int n(int d) {
140 assert(d < prb_.ndims);
141 return (int)prb_.nodes[d].n;
142 }
isdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t143 int is(int d) {
144 assert(d < prb_.ndims);
145 return (int)prb_.nodes[d].is;
146 }
osdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t147 int os(int d) {
148 assert(d < prb_.ndims);
149 return (int)prb_.nodes[d].os;
150 }
ssdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t151 int ss(int d) {
152 assert(d < prb_.ndims);
153 return (int)prb_.nodes[d].ss;
154 }
155
    // Given the linear element index `off` (a multiple of step_size) and the
    // input/output/scale offsets of the previous element, compute the
    // offsets of the current element. Works like an odometer: advance the
    // first dimension not already covered by one step, and on wrap-around
    // carry into the next dimension.
    void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
            int &i_off, int &o_off, int &s_off, int step_size = 1) {
        i_off = prev_i_off;
        o_off = prev_o_off;
        s_off = prev_s_off;

        if (off == 0) return;

        // Skip the innermost dimensions fully covered by one step of
        // step_size elements.
        int start_dim = 0, dims_prod = 1;
        for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
            dims_prod *= n(start_dim);
        assert(start_dim < prb_.ndims);
        off /= step_size;

        for (int d = start_dim; d < prb_.ndims; ++d) {
            // Advance dimension d by one position...
            i_off += is(d);
            o_off += os(d);
            s_off += ss(d);

            // ...and stop unless it wrapped (off divisible by n(d)).
            if (off % n(d)) break;

            // Wrap-around: rewind the n(d) advances of this dimension and
            // carry into the next (outer) dimension.
            i_off += -n(d) * is(d);
            o_off += -n(d) * os(d);
            s_off += -n(d) * ss(d);
            off /= n(d);

            if (off == 0) break; /* FIXME: is it really required? */
        }
    }
185
stepdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t186 void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
187 int step_size = 1) {
188 int dummy = 0;
189 step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
190 step_size);
191 }
192
    // Emit code that transposes one 8x8 tile of elements starting at
    // (i_off, o_off) using the low 256 bits of SVE Z registers, including
    // data-type conversion and (if needed) saturation. Only valid when
    // can_do_tr8x8() holds: n(0) == n(1) == 8, unit strides os(0)/is(1),
    // no scales, beta == 0.
    void tr8x8_sve256(int i_off, int o_off) {
        using namespace data_type;

        // Convert Z registers [startIdx, startIdx + regNum) from idt to f32.
        const auto cvt2ps
                = [=](const int startIdx, const int regNum, data_type_t idt) {
                      switch (idt) {
                          case f32:
                              /* do nothing */
                              break;
                          case s32: cvt_z_s32_f32(startIdx, regNum); break;
                          case data_type::s8:
                              cvt_z_s8_s32(startIdx, regNum);
                              cvt_z_s32_f32(startIdx, regNum);
                              break;
                          case u8:
                              cvt_z_u8_s32(startIdx, regNum);
                              cvt_z_s32_f32(startIdx, regNum);
                              break;
                          default: assert(!"unreachable");
                      }
                  };

        // Convert Z registers [startIdx, startIdx + regNum) from idt to odt.
        const auto cvt2odt = [=](const int startIdx, const int regNum,
                                     data_type_t odt, data_type_t idt) {
            switch (odt) {
                case s32:
                    if (idt == f32)
                        cvt_z_f32_s32(startIdx, regNum);
                    else if (idt == data_type::s8)
                        cvt_z_s8_s32(startIdx, regNum);
                    else if (idt == u8)
                        cvt_z_u8_s32(startIdx, regNum);
                    break;
                case data_type::s8:
                    if (idt == f32) cvt_z_f32_s32(startIdx, regNum);
                    if (utils::one_of(idt, f32, s32))
                        cvt_z_s32_s8(startIdx, regNum);
                    if (idt == u8) cvt_z_u8_s8(startIdx, regNum);
                    break;
                case u8:
                    if (idt == f32) cvt_z_f32_s32(startIdx, regNum);
                    if (utils::one_of(idt, f32, s32))
                        cvt_z_s32_u8(startIdx, regNum);
                    if (idt == data_type::s8) cvt_z_s8_u8(startIdx, regNum);
                    break;
                default: assert(!"unreachable");
            }
        };

        const int unroll = 8;

        // NOTE(review): this expression is tautologically true: if
        // prb_.itype != f32 the first operand holds, otherwise one_of()
        // holds because itype == f32. Presumably a narrower condition was
        // intended — confirm against the x64 counterpart. Behavior stays
        // correct either way: integer<->integer reorders just take a detour
        // through f32.
        const bool interim_f32 = (prb_.itype != f32)
                || utils::one_of(f32, prb_.itype, prb_.otype);

        const bool need_saturation
                = (utils::one_of(prb_.otype, u8, data_type::s8, s32)
                        && interim_f32);
        const uint64_t sveLen = get_sve_length();

        // Addresses of the first 4 input rows, is(0) elements apart.
        add_imm(X_TMP_0, XReg(x_ptr_in_off), i_off * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);

        // Load rows 0..3 into z0..z3; the load width (32/16/8 bytes) is
        // 8 elements of the input type.
        if (unroll * itype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                ld1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                ldr(QReg {i}, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                ldr(DReg {i}, ptr(x_tmp_vec[i]));

        // Addresses of input rows 4..7, continuing from X_TMP_3.
        add_imm(X_TMP_0, X_TMP_3, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, is(0) * itype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, is(0) * itype_sz, X_DEFAULT_ADDR);

        // Load rows 4..7 into z4..z7.
        if (unroll * itype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                ld1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                ldr(QReg {4 + i}, ptr(x_tmp_vec[i]));
        else if (unroll * itype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                ldr(DReg {4 + i}, ptr(x_tmp_vec[i]));

        if (interim_f32) cvt2ps(0, unroll, prb_.itype);

#if 0
    /* Deubg code */
    index(z0.s, 0, 1);
    mov(z0.s, P_NOT_256/T_m, 0);
    mov(z_tmp_vec[0].s, 16);
    for(uint32_t i=1; i<8; i++) {
        add(ZRegS{i}, ZRegS{i-1}, z_tmp_vec[0].s);
        mov(ZRegS{i}, P_NOT_256/T_m, 0);
    }
#endif

        // The 8x8 transpose is performed in five steps: 32-bit interleaves,
        // 64-bit interleaves, register shuffles, 128-bit lane rotations and
        // a final predicated select that merges the halves.
        ptrue(p_tmp0.s, VL4);
        /* 1st turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            trn1(z_tmp_vec[i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
            trn2(z_tmp_vec[unroll / 2 + i].s, ZRegS {2 * i}, ZRegS {2 * i + 1});
        }

        /* 2nd turn */
        trn1(z4.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
        trn1(z5.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
        trn2(z6.d, z_tmp_vec[0].d, z_tmp_vec[1].d);
        trn2(z7.d, z_tmp_vec[4].d, z_tmp_vec[5].d);
        trn1(z_tmp_vec[0].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
        trn1(z_tmp_vec[1].d, z_tmp_vec[6].d, z_tmp_vec[7].d);
        trn2(z_tmp_vec[2].d, z_tmp_vec[2].d, z_tmp_vec[3].d);
        trn2(z_tmp_vec[3].d, z_tmp_vec[6].d, z_tmp_vec[7].d);

        /* 3rd turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            mov(ZRegD {i}, ZRegD {unroll / 2 + i});
            mov(z_tmp_vec[unroll / 2 + i].d, z_tmp_vec[i].d);
        }

        /* 4th turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            ZRegB z {unroll / 2 + i};
            ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b;
            /* Move bit 128-255 to 0-127. */
            ext(z, z, 16);
            /* Move bit 0-127 to 128-255. */
            ext(z_tmp, z_tmp, sveLen - 16);
        }

        /* 5th turn */
        for (uint32_t i = 0; i < unroll / 2; i++) {
            ZRegS z0 {i};
            ZRegS z1 {unroll / 2 + i};
            sel(z0, p_tmp0.s, z0, z_tmp_vec[unroll / 2 + i].s);
            sel(z1, p_tmp0, z1, z_tmp_vec[i].s);
        }

        // Clamp f32 values into the representable range of the output type
        // before the narrowing conversion.
        if (need_saturation) {
            init_saturate_f32(ymm_zero, ymm_saturation_ubound, reg_tmp,
                    interim_f32 ? f32 : prb_.itype, prb_.otype);
            for (int i = 0; i < unroll; i++)
                saturate_f32(ZRegS(i), ymm_zero, ymm_saturation_ubound,
                        prb_.otype, p_all);
        }

        if (prb_.otype != f32)
            cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype);

        // Addresses of the first 4 output rows, os(1) elements apart.
        add_imm(X_TMP_0, XReg(x_ptr_out_off), o_off * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);

        // Store rows 0..3.
        if (unroll * otype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                st1w(ZRegS {i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                str(QReg {i}, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                str(DReg {i}, ptr(x_tmp_vec[i]));

        // Addresses of output rows 4..7.
        add_imm(X_TMP_0, X_TMP_3, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_1, X_TMP_0, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_2, X_TMP_1, os(1) * otype_sz, X_DEFAULT_ADDR);
        add_imm(X_TMP_3, X_TMP_2, os(1) * otype_sz, X_DEFAULT_ADDR);

        // Store rows 4..7.
        if (unroll * otype_sz == 32)
            for (uint32_t i = 0; i < 4; i++)
                st1w(ZRegS {4 + i}, p_lsb_256 / T_z, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 16)
            for (uint32_t i = 0; i < 4; i++)
                str(QReg {4 + i}, ptr(x_tmp_vec[i]));
        else if (unroll * otype_sz == 8)
            for (uint32_t i = 0; i < 4; i++)
                str(DReg {4 + i}, ptr(x_tmp_vec[i]));
    }
377
can_do_tr8x8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t378 bool can_do_tr8x8() {
379 using namespace data_type;
380
381 return get_sve_length() >= Xbyak_aarch64::util::SVE_256
382 && prb_.ndims >= 2
383 && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32)
384 && utils::one_of(
385 prb_.otype, u8, data_type::s8, s32, f32)))
386 && utils::everyone_is(8, n(0), n(1))
387 && utils::everyone_is(1, os(0), is(1))
388 && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
389 }
390
process_unroll_tr8x8dnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t391 bool process_unroll_tr8x8(int len) {
392 if (!can_do_tr8x8()) return false;
393
394 const int step_size = n(0) * n(1);
395 int i_off = 0, o_off = 0;
396 for (int off = 0; off < len; off += step_size) {
397 step(off, i_off, o_off, i_off, o_off, step_size);
398 tr8x8_sve256(i_off, o_off);
399 }
400
401 return true;
402 }
403
    // Emit a straight vectorized copy of `len` elements (both sides are
    // dense, i.e. os(0) == is(0) == 1), optionally converting between s32
    // and f32. Returns false (emitting nothing) when the path does not
    // apply.
    template <cpu_isa_t isa>
    bool process_direct_copy(int len) {
        using namespace data_type;

        // Elements per vector register for this ISA.
        const int simd_w = cpu_isa_traits<isa>::vlen == 16
                ? cpu_isa_traits<isa>::vlen / itype_sz /* use 128-bit VReg */
                : cpu_isa_traits<isa>::vlen / itype_sz
                        / 2; /* use lower half of 512-bit ZReg */

        bool can_do = true && mayiuse(isa)
                && utils::everyone_is(1, os(0), is(0))
                && (false || prb_.itype == prb_.otype
                        || (prb_.itype == s32 && prb_.otype == f32)
                        || (prb_.itype == f32 && prb_.otype == s32))
                && len % simd_w == 0 && n(0) % len == 0
                && prb_.scale_type == scale_type_t::NONE && prb_.beta == 0.f;
        if (!can_do) return false;

        for (int off = 0; off < len;) {
            // Batch of registers per iteration (one register is reserved
            // for s32 output conversion).
            const int unroll
                    = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w);

            // Loads: compute up to x_tmp_vec_size addresses at a time, then
            // issue the corresponding vector loads.
            int ur = 0;
            int tmp_ur = 0;
            while (ur < unroll) {
                int count = 0;
                const int vlen = cpu_isa_traits<isa>::vlen;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_in_off,
                            (off + ur * simd_w) * itype_sz, X_DEFAULT_ADDR);
                    ur++;
                } while (ur < unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {
                    /* if (vlen == 64)
                   ldr(ZReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                else */
                    if (vlen == 64 || vlen == 32)
                        ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z,
                                ptr(x_tmp_vec[i]));
                    else if (vlen == 16)
                        ldr(QReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                    else
                        assert(!"unreachable");
                }
                tmp_ur += count;
            }

            // In-place conversion between s32 and f32, if the types differ.
            if (prb_.itype != prb_.otype) {
                const int vlen = cpu_isa_traits<isa>::vlen;
                for (int ur = 0; ur < unroll; ++ur) {
                    if (prb_.itype == s32 && prb_.otype == f32) {
                        if (vlen == 64 || vlen == 32) {
                            ZRegS r(ur);
                            /* MSB side 256 bits are ignored. */
                            scvtf(r, p_all / T_m, r);
                        } else if (vlen == 16) {
                            VReg4S r(ur);
                            scvtf(r, r);
                        } else
                            assert(!"unreachable");
                    } else if (prb_.itype == f32 && prb_.otype == s32) {
                        /* Out of order can be expected. */
                        if (vlen == 64 || vlen == 32) {
                            ZRegS r(ur);
                            frinti(r, p_all / T_m, r);
                            fcvtzs(r, p_all / T_m, r);
                        } else if (vlen == 16) {
                            VReg4S r(ur);
                            frinti(r, r);
                            fcvtzs(r, r);
                        } else
                            assert(!"unreachable");
                    } else
                        assert(!"unreachable");
                }
            }

            // Stores: mirror of the load loop above.
            ur = 0;
            tmp_ur = 0;
            while (ur < unroll) {
                int count = 0;
                const int vlen = cpu_isa_traits<isa>::vlen;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_out_off,
                            (off + ur * simd_w) * otype_sz, X_DEFAULT_ADDR);
                    ur++;
                } while (ur < unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {
                    if (vlen == 64 || vlen == 32)
                        st1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z,
                                ptr(x_tmp_vec[i]));
                    else if (vlen == 16)
                        str(QReg(tmp_ur + i), ptr(x_tmp_vec[i]));
                    else
                        assert(!"unreachable");
                }
                tmp_ur += count;
            }

            off += unroll * simd_w;
        }

        return true;
    }
512
    // Emit code for one batch of up to `reg_unroll` elements with arbitrary
    // per-element input/output/scale offsets (element indices, scaled to
    // bytes at use). Loads/stores are vectorized (4 lanes) when the offsets
    // of the batch are contiguous, otherwise done lane by lane.
    void process_unroll_generic_step(int reg_unroll, const int *i_off,
            const int *o_off, const int *s_off) {
        using namespace data_type;

        // Convert V registers [startIdx, startIdx + regNum) from idt to f32.
        auto cvt2ps
                = [=](const int startIdx, const int regNum, data_type_t idt) {
                      switch (idt) {
                          case f32:
                              /* do nothing */
                              break;
                          case s32: cvt_v_s32_f32(startIdx, regNum); break;
                          case data_type::s8:
                              cvt_v_s8_s32(startIdx, regNum);
                              cvt_v_s32_f32(startIdx, regNum);
                              break;
                          case u8:
                              cvt_v_u8_s32(startIdx, regNum);
                              cvt_v_s32_f32(startIdx, regNum);
                              break;
                          default: assert(!"unreachable");
                      }
                  };

        // Convert V registers [startIdx, startIdx + regNum) from idt to odt.
        auto cvt2odt = [=](const int startIdx, const int regNum,
                               data_type_t odt, data_type_t idt) {
            switch (odt) {
                case s32:
                    if (idt == f32)
                        cvt_v_f32_s32(startIdx, regNum);
                    else if (idt == data_type::s8)
                        cvt_v_s8_s32(startIdx, regNum);
                    else if (idt == u8)
                        cvt_v_u8_s32(startIdx, regNum);
                    break;
                case data_type::s8:
                    if (idt == f32) cvt_v_f32_s32(startIdx, regNum);
                    if (idt == f32 || idt == s32)
                        cvt_v_s32_s8(startIdx, regNum);
                    if (idt == u8) { cvt_v_u8_s8(startIdx, regNum); }
                    break;
                case u8:
                    if (idt == f32) cvt_v_f32_s32(startIdx, regNum);
                    if (idt == f32 || idt == s32)
                        cvt_v_s32_u8(startIdx, regNum);
                    if (idt == data_type::s8) cvt_v_s8_u8(startIdx, regNum);
                    break;
                default: assert(!"unreachable");
            }
        };

        /* check whether loading 4 values at once is possible */
        bool can_load_xmm = reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (i_off[ur] != i_off[ur - 1] + 1) can_load_xmm = false;
        const int load_step = can_load_xmm ? 4 : 1;

        /* check whether storing 4 values at once is possible */
        bool can_store_xmm = reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (o_off[ur] != o_off[ur - 1] + 1) can_store_xmm = false;
        const int ur_step = can_store_xmm ? 4 : 1;

        // Compute in f32 whenever f32 is involved or post-ops (scale/beta)
        // require floating-point arithmetic.
        const bool interim_f32 = false
                || utils::one_of(f32, prb_.itype, prb_.otype)
                || prb_.scale_type != scale_type_t::NONE || prb_.beta != 0.f;

        const bool need_saturation
                = (utils::one_of(prb_.otype, u8, data_type::s8, s32)
                        && interim_f32);

        if (!can_load_xmm && can_store_xmm) {
            assert(ur_step == 4);
            /* load with stride */
            for (int ur = 0; ur < reg_unroll; ur += ur_step) {

                /* x_tmp_vec = X_TMP_0 - X_TMP_4
                   Do not use X_TMP_? as the last arg. */
                for (int r = 0; r < ur_step; ++r) {
                    add_imm(x_tmp_vec[r], x_ptr_in_off,
                            i_off[ur + r] * itype_sz, X_DEFAULT_ADDR);
                }

                // Gather 4 scattered elements into the lanes of one vector.
                for (int r = 0; r < ur_step; ++r) {
                    if (itype_sz == 4)
                        ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
                    else if (itype_sz == 2)
                        ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
                    else
                        ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
                }
            }
        } else {
            // Contiguous (or scalar) loads, batched by address registers.
            int ur = 0;
            int tmp_ur = 0;
            while (ur < reg_unroll) {
                int count = 0;

                do {
                    add_imm(x_tmp_vec[count++], x_ptr_in_off,
                            i_off[ur] * itype_sz, X_DEFAULT_ADDR);
                    ur += load_step;
                } while (ur < reg_unroll && count < x_tmp_vec_size);

                for (int i = 0; i < count; i++) {

                    switch (load_step * itype_sz) {
                        case 16: ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                        default: assert(!"unreachable");
                    }
                    tmp_ur += load_step;
                }
            }
        }

        /* xmm[:] <-- (f32)xmm[:] */
        if (interim_f32) {
            const int cvt_step = nstl::max(load_step, ur_step);
            for (int ur = 0; ur < reg_unroll; ur += cvt_step)
                cvt2ps(ur, 1, prb_.itype);
        }

        if (can_load_xmm && !can_store_xmm) {
            // Vector loads but scattered stores: when no per-element scale
            // and no beta, convert in registers and scatter store directly.
            const bool fast_return = true // transposition on the fly
                    && prb_.scale_type != scale_type_t::MANY
                    && prb_.beta == 0.f;
            if (fast_return) {
                if (prb_.scale_type == scale_type_t::COMMON)
                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
                if (prb_.otype != f32) {
                    init_saturate_f32(xmm_zero, xmm_saturation_ubound, reg_tmp,
                            interim_f32 ? f32 : prb_.itype, prb_.otype);
                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        if (need_saturation)
                            saturate_f32(VReg4S(ur), xmm_zero,
                                    xmm_saturation_ubound, prb_.otype, p_all);

                    for (int ur = 0; ur < reg_unroll; ur += load_step)
                        cvt2odt(ur, 1, prb_.otype,
                                interim_f32 ? f32 : prb_.itype);
                }
                /* load_step is 1 or 4. */
                for (int ur = 0; ur < reg_unroll; ur += load_step) {
                    for (int r = 0; r < load_step; ++r) {
                        add_imm(x_tmp_vec[r], x_ptr_out_off,
                                o_off[ur + r] * otype_sz, X_DEFAULT_ADDR);
                    }

                    for (int r = 0; r < load_step; ++r) {
                        if (otype_sz == 4)
                            st1(VReg4S(ur)[r], ptr(x_tmp_vec[r]));
                        else if (otype_sz == 2)
                            st1(VReg8H(ur)[r], ptr(x_tmp_vec[r]));
                        else
                            st1(VReg16B(ur)[r], ptr(x_tmp_vec[r]));
                    }
                }
                return;
            }

            /* scatter elements of xmm into 4 xmms */
            if (itype_sz == 4 || interim_f32) {
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r) {
                        VReg4S v(ur);
                        VReg4S v_r(ur + r);
                        dup(VReg16B(ur + r), VReg16B(ur)[0]);
                        ins(VReg4S(ur + r)[0], VReg4S(ur)[r]);
                    }
            } else {
                // Sub-word types: rotate the source so element r lands in
                // the lowest lane of register ur + r.
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r)
                        ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur),
                                itype_sz * r);
            }
        }

        /* scale and beta processing */
        if (can_store_xmm) {
            /* xmm <-- scale * xmm[:] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step)
                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
            } else if (prb_.scale_type == scale_type_t::MANY) {
                enum class scale_load_type_t { bcast, load, gather };

                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    scale_load_type_t scale_load_type
                            = scale_load_type_t::bcast; // the best case

                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 0)
                            scale_load_type = scale_load_type_t::load;

                    if (scale_load_type == scale_load_type_t::bcast) {
                        // All 4 lanes share one scale: broadcast it.
                        VReg4S v(xmm_scale.getIdx());
                        VReg4S v_dst(ur);
                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
                                X_DEFAULT_ADDR);
                        ldr(W_TMP_0, ptr(X_TMP_0));
                        dup(v, W_TMP_0);
                        fmul(v_dst, v_dst, v);
                        continue;
                    }

                    // bcast doesn't work, the next try -- load
                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 1)
                            scale_load_type = scale_load_type_t::gather;

                    if (scale_load_type == scale_load_type_t::load) {
                        // Scales are contiguous: one vector load.
                        uint32_t idx = xmm_scale.getIdx();
                        VReg4S v_dst(ur);
                        add_imm(X_TMP_0, x_ptr_scale_off, s_off[ur] * stype_sz,
                                X_DEFAULT_ADDR);

                        ldr(QReg {idx}, ptr(X_TMP_0));
                        fmul(v_dst, v_dst, VReg4S {idx});
                        continue;
                    }

                    // load doesn't work as well
                    // so gather the scale factors one by one
                    /*ur_step is 1 or 4. */
                    for (int r = ur; r < ur + ur_step; ++r) {
                        /* x_tmp_vec = X_TMP_0 - X_TMP_4
                           Do not use X_TMP_? as the last arg. */
                        add_imm(x_tmp_vec[r - ur], x_ptr_scale_off,
                                s_off[r] * stype_sz, X_DEFAULT_ADDR);
                    }
                    for (int r = ur; r < ur + ur_step; ++r) {
                        VReg4S v(xmm_scale.getIdx());
                        ld1(v[r - ur], ptr(x_tmp_vec[r - ur]));
                    }
                    fmul(VReg4S(ur), VReg4S(ur), xmm_scale);
                }
            }

            /* dst <-- beta * dst + xmm[:] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                int ur = 0;
                int tmp_ur = 0;

                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_out_off,
                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < x_tmp_vec_size);

                    assert(count <= z_tmp_vec_size);
                    /* Firstly, data is loaded. */
                    for (int i = 0; i < count; i++) {

                        // NOTE(review): the `// bug` markers below are from
                        // the original author; reason not stated — verify
                        // whether the full-width loads can over-read when
                        // fewer than 4 lanes remain.
                        if (prb_.otype == f32 || prb_.otype == s32) {
                            ldr(QReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug
                        } else if (prb_.otype == data_type::s8
                                || prb_.otype == u8) {
                            ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug
                        } else
                            assert(!"unreachable");
                    }

                    /* Secondly, it is added. */
                    if (prb_.otype == f32) {
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    } else {
                        for (int i = 0; i < count; i++) {
                            /* cvt2ps() generate successive instructions
                               which have save destination operand,
                               but out of order can be expected. */
                            cvt2ps(tmp_vec_idx[i], 1, prb_.otype);
                        }
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    }
                }
            }
        } else {
            /* xmm[0] <-- scale * xmm[0] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    VReg4S tmp(ur);
                    fmul(tmp, tmp, VReg4S(xmm_scale.getIdx()));
                }
            } else if (prb_.scale_type == scale_type_t::MANY) {
                int ur = 0;
                int tmp_ur = 0;
                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_scale_off,
                                s_off[ur] * stype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < x_tmp_vec_size);

                    for (int i = 0; i < count; i++)
                        ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                    for (int i = 0; i < count; i++) {
                        VReg4S tmp(tmp_ur + ur_step * i);
                        fmul(tmp, tmp, VReg4S(tmp_vec_idx[i]));
                    }

                    tmp_ur += ur_step * count;
                }
            }

            /* dst <-- beta * dst + xmm[0] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                int ur = 0;
                int tmp_ur = 0;
                while (ur < reg_unroll) {
                    int count = 0;

                    do {
                        add_imm(x_tmp_vec[count++], x_ptr_out_off,
                                o_off[ur] * otype_sz, X_DEFAULT_ADDR);
                        ur += ur_step;
                    } while (ur < reg_unroll && count < (x_tmp_vec_size / 2));

                    assert(static_cast<size_t>(count) <= z_tmp_vec.size());

                    if (prb_.otype == f32) {
                        /* addss: dest[31:0] <- src1[31:0] + src2[31:0]
                           dset[MAXVL-1:32] (Unmodified) */
                        for (int i = 0; i < count; i++) {
                            ld1(VReg4S(z_tmp_vec[i].getIdx())[0],
                                    ptr(x_tmp_vec[i]));
                        }
                        for (int i = 0; i < count; i++) {
                            SReg s {tmp_vec_idx[i]};
                            fadd(s, s, SReg(tmp_ur + ur_step * i));
                        }
                        for (int i = 0; i < count; i++) {
                            mov(VReg4S(tmp_ur)[0], VReg4S(tmp_vec_idx[i])[0]);
                            tmp_ur += ur_step;
                        }
                    } else {
                        for (int i = 0; i < count; i++) {
                            if (prb_.otype == s32) {
                                ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                            } else if (utils::one_of(
                                               prb_.otype, data_type::s8, u8)) {
                                ldr(BReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i]));
                            } else {
                                assert(!"unsupported o_type");
                            }
                            cvt2ps(tmp_vec_idx[i], 1, prb_.otype);
                        }
                        for (int i = 0; i < count; i++) {
                            VReg4S v(tmp_ur);
                            fadd(v, v, VReg4S(tmp_vec_idx[i]));
                            tmp_ur += ur_step;
                        }
                    }
                }
            }
        }

        // Clamp into the output type's range before the narrowing convert.
        if (need_saturation) {
            init_saturate_f32(
                    xmm_zero, xmm_saturation_ubound, reg_tmp, f32, prb_.otype);
            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                saturate_f32(VReg4S(ur), xmm_zero, xmm_saturation_ubound,
                        prb_.otype, p_all);
            }
        }

        for (int ur = 0; ur < reg_unroll; ur += ur_step) {
            if (prb_.otype != f32)
                cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype);
        }

        // Final stores, batched by address registers.
        int ur = 0;
        int tmp_ur = 0;
        while (ur < reg_unroll) {
            int count = 0;

            do {
                add_imm(x_tmp_vec[count++], x_ptr_out_off, o_off[ur] * otype_sz,
                        X_DEFAULT_ADDR);
                ur += ur_step;
            } while (ur < reg_unroll && count < x_tmp_vec_size);

            for (int i = 0; i < count; i++) {

                switch (ur_step * otype_sz) {
                    case 16: str(QReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 8: str(DReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 4: str(SReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 2: str(HReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    case 1: str(BReg(tmp_ur), ptr(x_tmp_vec[i])); break;
                    default: assert(!"unreachable");
                }
                tmp_ur += ur_step;
            }
        }
    }
927
process_unroll_genericdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t928 void process_unroll_generic(int len) {
929 const int blk = 8;
930
931 int i_off[2 * blk] = {0};
932 int o_off[2 * blk] = {0};
933 int s_off[2 * blk] = {0};
934
935 int curr = 0; // will switch between 0 and 1
936
937 for (int off = 0; off < len; off += blk) {
938 const int reg_unroll = nstl::min(off + blk, len) - off;
939
940 /* compute offsets */
941 for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
942 const int ur_c = curr * blk + ur;
943 const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
944 step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p],
945 i_off[ur_c], o_off[ur_c], s_off[ur_c]);
946 }
947
948 process_unroll_generic_step(reg_unroll, i_off + curr * blk,
949 o_off + curr * blk, s_off + curr * blk);
950
951 curr = 1 - curr;
952 }
953 }
954
    // Emit the head of a JIT loop: initialize the trip counter and bind the
    // back-edge label.
    void loop_begin(Label &l, XReg reg_cnt, int len) {
        mov(reg_cnt, len);
        L(l);
    }
959
    // Emit the tail of a JIT loop: advance all data offsets/pointers by one
    // step, decrement the counter and branch back to `l`; once the loop
    // falls through, rewind the offsets to their pre-loop values so the
    // enclosing loop can advance them itself.
    void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step,
            int s_step) {
        add_imm(reg_off_in, reg_off_in, i_step * itype_sz, X_TMP_0);
        add_imm(reg_off_out, reg_off_out, o_step * otype_sz, X_TMP_0);
        add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz, X_TMP_0);
        add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz, X_TMP_0);

        if (prb_.scale_type == scale_type_t::MANY) {
            add_imm(reg_off_scale, reg_off_scale, s_step * stype_sz, X_TMP_0);
            add_imm(x_ptr_scale_off, x_ptr_scale_off, s_step * stype_sz,
                    X_TMP_0);
        }
        subs(reg_cnt, reg_cnt, 1);
        b(NE, l);

        // Undo the `len` accumulated advances.
        sub_imm(reg_off_in, reg_off_in, len * i_step * itype_sz, X_TMP_0);
        sub_imm(reg_off_out, reg_off_out, len * o_step * otype_sz, X_TMP_0);
        sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz, X_TMP_0);
        sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz, X_TMP_0);

        if (prb_.scale_type == scale_type_t::MANY) {
            sub_imm(reg_off_scale, reg_off_scale, len * s_step * stype_sz,
                    X_TMP_0);
            sub_imm(x_ptr_scale_off, x_ptr_scale_off, len * s_step * stype_sz,
                    X_TMP_0);
        }
    }
987
    // Emit the whole kernel: up to three JIT loops wrapped around an
    // unrolled body produced by one of the specialized paths (direct copy,
    // 8x8 transpose) or the generic element-wise path. Returns false when
    // the problem cannot be split per simple_impl_desc_init().
    bool simple_impl() {
        simple_impl_desc_t d;
        if (!simple_impl_desc_init(prb_, &d)) return false;

        const int nfu = d.ndims_full_unroll;
        const int ldu = d.len_last_dim_unroll;
        const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
        assert(n_jit_loops <= ndims_jit_loop_max);

        // Zero the running offsets and seed the walking pointer registers
        // from the base pointers.
        eor(reg_off_in, reg_off_in, reg_off_in);
        eor(reg_off_out, reg_off_out, reg_off_out);
        mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
        mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
        if (prb_.scale_type == scale_type_t::MANY) {
            eor(reg_off_scale, reg_off_scale, reg_off_scale);
            mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));
        }

        Label l_loop[3];
        XReg reg_cnt[3] = {x15, x14, x13};

        // Open loops outermost first; loop 0 iterates over the partially
        // unrolled dimension, hence the division by ldu.
        if (n_jit_loops > 2) loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));

        if (n_jit_loops > 1) loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));

        if (n_jit_loops > 0)
            loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);

        // Try the specialized bodies in priority order; only the first
        // applicable one emits code.
        bool optimized = false;
        optimized = optimized || process_direct_copy<sve_512>(d.len_unroll);
        optimized = optimized || process_direct_copy<asimd>(d.len_unroll);
        optimized = optimized || process_unroll_tr8x8(d.len_unroll);
        if (!optimized) process_unroll_generic(d.len_unroll);

        // Close loops innermost first, mirroring the opens above.
        if (n_jit_loops > 0)
            loop_end(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu, is(nfu + 0) * ldu,
                    os(nfu + 0) * ldu, ss(nfu + 0) * ldu);

        if (n_jit_loops > 1)
            loop_end(l_loop[1], reg_cnt[1], n(nfu + 1), is(nfu + 1),
                    os(nfu + 1), ss(nfu + 1));

        if (n_jit_loops > 2)
            loop_end(l_loop[2], reg_cnt[2], n(nfu + 2), is(nfu + 2),
                    os(nfu + 2), ss(nfu + 2));

        return true;
    }
1036
impldnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1037 void impl() {
1038 if (simple_impl()) return;
1039 assert(!"no implementation available");
1040 }
1041
/* Emit `inst` once per vector register in [startIdx, startIdx + regNum).
 * UNROLL_INST binds the current register as `tmp` (of type `reg`) for use in
 * the argument list; UNROLL_INST2 exposes the raw register index `i`.
 * Both rely on `startIdx`/`regNum` being in scope at the expansion site. */
#define UNROLL_INST(inst, reg, ...) \
    for (size_t i = startIdx; i < startIdx + regNum; i++) { \
        reg tmp(i); \
        inst(__VA_ARGS__); \
    }
#define UNROLL_INST2(inst, ...) \
    for (size_t i = startIdx; i < startIdx + regNum; i++) \
        inst(__VA_ARGS__);
1050
    // SVE: in-place s32 -> f32 conversion on z[startIdx, startIdx + regNum).
    void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(scvtf, ZRegS, tmp, p_all / T_m, tmp);
    }
1054
    // ASIMD: in-place s32 -> f32 conversion on v[startIdx, startIdx + regNum).
    void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(scvtf, VReg4S, tmp, tmp);
    }
1058
    // SVE: f32 -> s32. frinti first rounds to an integral value using the
    // current FPCR rounding mode, so the truncating fcvtzs that follows is
    // exact (i.e. the effective rounding is the FPCR mode, not truncation).
    void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(frinti, ZRegS, tmp, p_all / T_m, tmp);
        UNROLL_INST(fcvtzs, ZRegS, tmp, p_all / T_m, tmp);
    }
1063
    // ASIMD counterpart of cvt_z_f32_s32: round per FPCR, then truncate.
    void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(frinti, VReg4S, tmp, tmp);
        UNROLL_INST(fcvtzs, VReg4S, tmp, tmp);
    }
1068
    // SVE: widen s8 bytes into s32 lanes: cvt_z_b_s zero-extends each byte
    // into its own 32-bit lane, then sxtb sign-extends the low byte in-lane.
    void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) {
        cvt_z_b_s(startIdx, regNum);
        UNROLL_INST(sxtb, ZRegS, tmp, p_all / T_m, tmp);
    }
1073
    // ASIMD: sign-extend the low 4 bytes into s32 lanes (b8 -> h8 -> s4).
    void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(sxtl, VReg, tmp.h8, tmp.b8);
        UNROLL_INST(sxtl, VReg, tmp.s4, tmp.h4);
    }
1078
    // SVE: s8 -> f32.
    // NOTE(review): cvt_z_b_s only *zero*-extends the bytes and, unlike
    // cvt_z_s8_s32, no sxtb is applied before scvtf, so a negative s8 input
    // would convert as its unsigned value (e.g. -1 -> 255.0f). Confirm that
    // sign handling happens at the call site, or that this path is never
    // reached with negative data.
    void cvt_z_s8_f32(const size_t startIdx, const size_t regNum) {
        cvt_z_b_s(startIdx, regNum);
        cvt_z_s32_f32(startIdx, regNum);
    }
1083
    // ASIMD: s8 -> f32.
    // NOTE(review): same concern as cvt_z_s8_f32 — cvt_v_b_s zero-extends
    // (zip with zero) rather than sign-extending (cf. cvt_v_s8_s32's sxtl),
    // so negative s8 inputs would convert incorrectly; confirm callers.
    void cvt_v_s8_f32(const size_t startIdx, const size_t regNum) {
        cvt_v_b_s(startIdx, regNum);
        cvt_v_s32_f32(startIdx, regNum);
    }
1088
    // SVE: scatter each input byte into its own 32-bit lane, zero-extended,
    // by interleaving with zero twice (bytes -> halves, halves -> words).
    // Clobbers z_tmp7; the assert keeps z_tmp7 out of the target range.
    void cvt_z_b_s(const size_t startIdx, const size_t regNum) {
        assert(z_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < z_tmp7.getIdx());

        dup(z_tmp7.b, 0);
        UNROLL_INST(zip1, ZRegB, tmp, tmp, z_tmp7.b);
        UNROLL_INST(zip1, ZRegH, tmp, tmp, z_tmp7.h);
    }
1097
    // ASIMD counterpart of cvt_z_b_s: zero-extend bytes into 32-bit lanes by
    // interleaving with a zeroed v_tmp7 (clobbered; guarded by the assert).
    void cvt_v_b_s(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 0);
        dup(v_tmp7.b16, W_TMP_0);
        UNROLL_INST(zip1, VReg16B, tmp, tmp, v_tmp7.b16);
        UNROLL_INST(zip1, VReg8H, tmp, tmp, v_tmp7.h8);
    }
1107
    // SVE: widen u8 bytes into u32 lanes. cvt_z_b_s already zero-extends;
    // the uxtb afterwards re-clears the upper lane bits (defensive,
    // presumably redundant here — TODO confirm).
    void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) {
        cvt_z_b_s(startIdx, regNum);
        UNROLL_INST(uxtb, ZRegS, tmp, p_all / T_m, tmp);
    }
1112
    // ASIMD: zero-extend the low 4 bytes into u32 lanes (b8 -> h8 -> s4).
    void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) {
        UNROLL_INST(uxtl, VReg, tmp.h8, tmp.b8);
        UNROLL_INST(uxtl, VReg, tmp.s4, tmp.h4);
    }
1117
    // SVE: s32 -> s8 with saturation: clamp lanes to [-128, 127], then pack
    // even elements twice (words -> halves -> bytes) against a zeroed
    // z_tmp7, so the bytes above the packed result are zero.
    // Clobbers z_tmp7 (assert keeps it outside the target range).
    void cvt_z_s32_s8(const size_t startIdx, const size_t regNum) {
        assert(z_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < z_tmp7.getIdx());

        dup(z_tmp7.s, 0);
        UNROLL_INST2(smin, ZRegS(i), 127);
        UNROLL_INST2(smax, ZRegS(i), -128);
        UNROLL_INST(uzp1, ZRegH, tmp, tmp, z_tmp7.h);
        UNROLL_INST(uzp1, ZRegB, tmp, tmp, z_tmp7.b);
    }
1128
    // ASIMD: s32 -> s8 with saturation. The clamp bounds (127 / -128) are
    // broadcast through v_tmp7, which is finally zeroed for the packing
    // step. Clobbers v_tmp7 (guarded by the assert).
    void cvt_v_s32_s8(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 127);
        dup(v_tmp7.s4, W_TMP_0);
        mov_imm(W_TMP_0, -128);
        UNROLL_INST2(smin, VReg4S(i), VReg4S(i), v_tmp7.s4);
        dup(v_tmp7.s4, W_TMP_0);
        UNROLL_INST2(smax, VReg4S(i), VReg4S(i), v_tmp7.s4);
        mov_imm(W_TMP_0, 0);
        dup(v_tmp7.s4, W_TMP_0);
        UNROLL_INST(uzp1, VReg8H, tmp, tmp, v_tmp7.h8);
        UNROLL_INST(uzp1, VReg16B, tmp, tmp, v_tmp7.b16);
    }
1144
    // SVE: u8 -> s8: clamp to the s8 range by unsigned min with 127.
    void cvt_z_u8_s8(const size_t startIdx, const size_t regNum) {
        UNROLL_INST2(umin, ZRegB(i), 127);
    }
1148
    // ASIMD: u8 -> s8: clamp with a 127-filled v_tmp7 (clobbered; guarded
    // by the assert).
    void cvt_v_u8_s8(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 127);
        dup(v_tmp7.b16, W_TMP_0);
        UNROLL_INST(umin, VReg16B, tmp, tmp, v_tmp7.b16);
    }
1157
    // SVE: u32 -> u8 with saturation: clamp to 255, then pack even elements
    // twice (uzp1 with itself: words -> halves -> bytes).
    void cvt_z_u32_u8(const size_t startIdx, const size_t regNum) {
        UNROLL_INST2(umin, ZRegS(i), 255);
        UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp);
        UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp);
    }
1163
    // ASIMD: u32 -> u8 with saturation; the 255 clamp constant is broadcast
    // through v_tmp7 (clobbered; guarded by the assert).
    void cvt_v_u32_u8(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 255);
        dup(v_tmp7.s4, W_TMP_0);
        UNROLL_INST(umin, VReg4S, tmp, tmp, v_tmp7.s4);
        UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp);
        UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp);
    }
1174
    // SVE: s32 -> u8 with saturation: clamp to [0, 255] (255 mask via dupm
    // into z_tmp7), pack even elements twice, then zero the inactive bytes
    // through the P_NOT_128 predicate — presumably the bytes above the low
    // 128 bits, so only the packed result survives (TODO confirm P_NOT_128
    // semantics). Clobbers z_tmp7 (guarded by the assert).
    void cvt_z_s32_u8(const size_t startIdx, const size_t regNum) {
        assert(z_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < z_tmp7.getIdx());

        dupm(z_tmp7.s, 255);
        UNROLL_INST2(smax, ZRegS(i), 0);
        UNROLL_INST2(smin, ZRegS(i), p_all / T_m, z_tmp7.s);
        UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp);
        UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp);
        UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0);
    }
1186
    // ASIMD: s32 -> u8 with saturation: clamp to [0, 255] via constants
    // broadcast through v_tmp7 (clobbered; guarded by the assert), then pack
    // even elements twice.
    void cvt_v_s32_u8(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 0);
        dup(v_tmp7.s4, W_TMP_0);
        mov_imm(W_TMP_0, 255);
        UNROLL_INST(smax, VReg4S, tmp, tmp, v_tmp7.s4);
        dup(v_tmp7.s4, W_TMP_0);
        UNROLL_INST(smin, VReg4S, tmp, tmp, v_tmp7.s4);
        UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp);
        UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp);
    }
1200
    // SVE: s8 -> u8: clamp negative values to zero (signed max with 0).
    void cvt_z_s8_u8(const size_t startIdx, const size_t regNum) {
        UNROLL_INST2(smax, ZRegB(i), 0);
    }
1204
    // ASIMD: s8 -> u8: clamp negatives to zero with a zeroed v_tmp7
    // (clobbered; guarded by the assert).
    void cvt_v_s8_u8(const size_t startIdx, const size_t regNum) {
        assert(v_tmp7.getIdx() < startIdx
                || startIdx + regNum - 1 < v_tmp7.getIdx());

        mov_imm(W_TMP_0, 0);
        dup(v_tmp7.b16, W_TMP_0);
        UNROLL_INST(smax, VReg16B, tmp, tmp, v_tmp7.b16);
    }
/* Scope the helper macros to this class. The second line previously
 * repeated `#undef UNROLL_INST` (a copy-paste slip silently accepted
 * because #undef of an already-undefined name is legal), leaving
 * UNROLL_INST2 defined for the rest of the translation unit. */
#undef UNROLL_INST
#undef UNROLL_INST2
1215
jit_uni_reorder_kernel_f32_tdnnl::impl::cpu::aarch64::tr::jit_uni_reorder_kernel_f32_t1216 jit_uni_reorder_kernel_f32_t(const desc_t &desc) : kernel_t(desc) {
1217 itype_sz = data_type_size(prb_.itype);
1218 otype_sz = data_type_size(prb_.otype);
1219 stype_sz = sizeof(float);
1220 }
1221
    /* Top-level code generation: unpacks the call_param_t argument into
     * registers, sets up predicates and clamp/zero constants, then emits the
     * kernel body via impl(). */
    void generate() override {
        using namespace Xbyak_aarch64::util;
        uint64_t sveLen = get_sve_length();

        preamble();
#define PARAM(x) offsetof(call_param_t, x)
        if (prb_.scale_type == scale_type_t::COMMON) {
            // load the single common scale and broadcast it into xmm_scale
            add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_1);
            ldr(X_TMP_0, ptr(X_DEFAULT_ADDR));
            ldr(W_TMP_1, ptr(X_TMP_0));
            dup(xmm_scale, W_TMP_1);
        } else if (prb_.scale_type == scale_type_t::MANY) {
            // keep a pointer to the per-element scale array
            add_imm(X_DEFAULT_ADDR, abi_param1, PARAM(scale), X_TMP_0);
            ldr(reg_ptr_scale, ptr(X_DEFAULT_ADDR));
        }
        add_imm(X_TMP_0, abi_param1, PARAM(in), X_TMP_2);
        add_imm(X_TMP_1, abi_param1, PARAM(out), X_TMP_2);
        ldr(reg_ptr_in, ptr(X_TMP_0));
        ldr(reg_ptr_out, ptr(X_TMP_1));
#undef PARAM

        // x_ptr_*_off mirror the base pointers; loops advance them in place
        mov(x_ptr_in_off, XReg(reg_ptr_in.getIdx()));
        mov(x_ptr_out_off, XReg(reg_ptr_out.getIdx()));
        mov(x_ptr_scale_off, XReg(reg_ptr_scale.getIdx()));

        if (sveLen) { /* SVE is available. */
            ptrue(p_lsb_256.b, VL32); // lowest 32 bytes of a z register
            ptrue(p_all.b); // all lanes active
        }

        if (can_do_tr8x8()) {
            dup(ymm_zero, 0);

            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
                // 8 bytes of 0x7f: saturation constant for u8 -> s8
                mov_imm(reg_tmp, 0x7f7f7f7f7f7f7f7f);
                mov(VReg4S(ymm_8x127b.getIdx())[0], WReg(reg_tmp.getIdx()));
            }
        } else if (mayiuse(sve_512)) {
            movi(xmm_zero, 0);

            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
                // 4 bytes of 0x7f: saturation constant for u8 -> s8
                mov(WReg(reg_tmp.getIdx()), 0x7f7f7f7f);
                mov(xmm_4x127b[0], WReg(reg_tmp.getIdx()));
            }
        }

        impl();
        postamble();
    }
1271
private:
    int itype_sz; // input element size, bytes
    int otype_sz; // output element size, bytes
    int stype_sz; // scale element size, bytes (always sizeof(float))

    XReg reg_ptr_in = x6; // base pointer: source
    XReg reg_ptr_out = x2; // base pointer: destination
    XReg reg_ptr_scale = abi_not_param1; // base pointer: per-element scales

    // running byte offsets maintained by the emitted loop nest
    XReg reg_off_in = x8;
    XReg reg_off_out = x9;
    XReg reg_off_scale = x10;

    XReg reg_tmp = x0;

    VReg4S xmm_scale = v15.s; // broadcast common scale
    VReg4S xmm_zero = v14.s;
    VReg4S xmm_4x127b = v13.s; // TODO: unite with ymm_zero
    ZRegS ymm_zero = z14.s;
    ZRegS ymm_8x127b = z13.s; // 0x7f clamp constant for u8 -> s8
    VReg4S xmm_tmp = v12.s;
    VReg4S xmm_saturation_ubound = v12.s;
    ZRegS ymm_saturation_ubound = z12.s;

    /* Note: x22 - x28 are already used as temporary registers
       in jit_generator.hpp.
       x_ptr_(in|out|scale)_off keeps (base + offset) address. */
    XReg x_ptr_in_off = x16;
    // NOTE(review): x18 is the reserved platform register on some AArch64
    // ABIs (e.g. Apple, Windows) — confirm this code never runs there, or
    // that jit_generator guarantees x18 is safe to clobber.
    XReg x_ptr_out_off = x18;
    XReg x_ptr_scale_off = x20;

    /* Caution: Choose predicate registers not used by x64's implementation. */
    PReg p_lsb_256 = p7; // lowest 256 bits (32 bytes) of a vector active
    PReg p_all = p6; // all lanes active
    PReg p_tmp0 = p5;

    // scratch vector registers z20 - z27 (v_tmp7 aliases z27's low 128 bits)
    const std::vector<uint32_t> tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27};
    ZReg z_tmp0 = z20;
    ZReg z_tmp1 = z21;
    ZReg z_tmp2 = z22;
    ZReg z_tmp3 = z23;
    ZReg z_tmp4 = z24;
    ZReg z_tmp5 = z25;
    ZReg z_tmp6 = z26;
    ZReg z_tmp7 = z27;
    VReg v_tmp7 = v27;

    const std::vector<ZReg> z_tmp_vec
            = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7};
    constexpr static int z_tmp_vec_size = 8;
1322 };
1323
desc_init(kernel_t::desc_t & desc,const prb_t & prb,int ndims_ker_max)1324 status_t kernel_t::desc_init(
1325 kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) {
1326 desc.prb = prb;
1327 desc.prb.ioff = desc.prb.ooff = 0;
1328
1329 if (ndims_ker_max > prb.ndims) return status::invalid_arguments;
1330
1331 auto ndims_ker_max_f = [&]() {
1332 size_t cur_size = 1;
1333 for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
1334 if (cur_size >= ker_prb_size_min) return d;
1335 return prb.ndims;
1336 };
1337
1338 if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f();
1339
1340 /* traverse through kernel implementations */
1341 /* TODO: find a better way to do that... */
1342 desc.id = 0;
1343 for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
1344 desc.prb.ndims = ndims_ker;
1345 if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb))
1346 return status::success;
1347 }
1348
1349 return status::unimplemented;
1350 }
1351
/* Factory: instantiates the kernel implementation selected by desc_init(). */
kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
    if (desc.id == 0) return new jit_uni_reorder_kernel_f32_t(desc);

    assert(!"unknown kernel id");
    return nullptr;
}
1360 } // namespace tr
1361
prb_block_for_cache(tr::prb_t & prb)1362 static void prb_block_for_cache(tr::prb_t &prb) {
1363 /* If strides for 0th and 1st nodes are cache friendly
1364 * then one can altogether do away with blocking ! */
1365 const bool cache_blocking_needed = false
1366 || (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16)
1367 || (prb.ndims > 1 && prb.nodes[1].is % 64 == 0
1368 && prb.nodes[1].n > 16);
1369 if (!cache_blocking_needed) return;
1370
1371 int unit_input_stride_idx = -1;
1372 for (auto idx = 0; idx < prb.ndims; ++idx) {
1373 if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx;
1374 }
1375
1376 /* Re-prioritize the sequential read over sequential write:
1377 * /-> [n0:is0:1][16n1:1:osk]...
1378 * [n0:is0:1]...[nk:1:osk] --> or
1379 * \-> [16n1:1:osk][n0:is0:1]... */
1380 if (unit_input_stride_idx != -1) {
1381 const auto output_stride = prb.nodes[unit_input_stride_idx].os;
1382 const auto num_elems = prb.nodes[unit_input_stride_idx].n;
1383
1384 const bool split_needed = (num_elems > 16) && (num_elems % 16 == 0);
1385 const int move_location = (output_stride % 4 != 0) ? 0 : 1;
1386 if (split_needed) prb_node_split(prb, unit_input_stride_idx, 16);
1387
1388 /* Because of cache-unfriendly nature of unit-output stride node, let
1389 * us move unit-input stride node on or near front! */
1390 prb_node_move(prb, unit_input_stride_idx, move_location);
1391 }
1392
1393 /* Potentially, split the node with os=1 in two and pull in the node with
1394 * is=1 between them for better cache reuse:
1395 * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */
1396 if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) {
1397 const auto input_stride = prb.nodes[0].is;
1398 const auto num_elems = prb.nodes[0].n;
1399
1400 const bool split_needed = true && (num_elems > 16)
1401 && (num_elems % 16 == 0) && (input_stride >= 256)
1402 && (input_stride % 64 == 0);
1403 if (split_needed) {
1404 prb_node_split(prb, 0, 16);
1405 prb_node_move(prb, 1, 2);
1406 }
1407 }
1408 }
1409
/** finds the maximum number of dimensions the kernel should process and
 * optionally splits one of the dimensions to achieve a better balance between
 * the parallel driver and the kernel. */
static void prb_thread_kernel_balance(
        tr::prb_t &prb, int &ndims_ker_max, int nthr) {
    // total number of elements in the problem
    size_t sz_total = 1;
    for (int d = 0; d < prb.ndims; ++d)
        sz_total *= prb.nodes[d].n;

    /* sz_drv_min is the minimal size for the parallel
     * driver required for good parallelization */
    const size_t sz_drv_min
            = nstl::min<size_t>(16 * nthr, utils::div_up(sz_total, 1024));

    /* kdims -- # of dimensions processed by a kernel
     * sz_ker_cur -- product of the dimension processed by a kernel
     * sz_drv_cur -- product of the dimension processed by a driver */

    // peel outermost dims into the driver until it has enough work
    int kdims = prb.ndims;
    size_t sz_drv_cur = 1;
    for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
        sz_drv_cur *= prb.nodes[kdims - 1].n;

    size_t sz_ker_cur = 1;
    for (int d = 0; d < kdims; ++d)
        sz_ker_cur *= prb.nodes[d].n;

    /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
     *
     * It might happen that for chosen kdims the sz_ker_cur is too small
     * (less than tr::ker_prb_size_min). In that case try to split the
     * innermost driver dimension into two, to increase sz_ker_cur. */
    bool want_borrow_ker_from_drv = true && kdims < prb.ndims
            && sz_ker_cur < tr::ker_prb_size_min && sz_drv_cur > sz_drv_min;
    if (want_borrow_ker_from_drv) {
        /* sz_want_borrow is the minimal sz, so that:
         *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
         *  o) current innermost driver dimension is divisible by
         *     sz_want_borrow (so that we can evenly split that
         *     dimension into two)
         *
         *  In the worst case the minimal sz_want_borrow is equal
         *  to the innermost driver dimension itself. In that case
         *  we will sacrifice it in favor of kernel (is it fine?). */
        size_t sz_want_borrow = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
        // bump to the nearest divisor of the innermost driver dim
        // (kdims < prb.ndims here, so nodes[kdims] is a valid driver node)
        for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow)
            ;
        if (sz_want_borrow != prb.nodes[kdims].n)
            prb_node_split(prb, kdims, sz_want_borrow);
        kdims += 1; // the borrowed (split-off) part now belongs to the kernel
    }

    /* On the other hand it might happen that for chosen kdims
     * the sz_drv_cur is too small (less than sz_drv_min). In that case
     * try to split the outermost kernel dimension into two, to increase
     * sz_drv_cur. */
    bool want_borrow_drv_from_ker = true && sz_ker_cur > tr::ker_prb_size_min
            && sz_drv_cur < sz_drv_min;
    if (want_borrow_drv_from_ker) {
        size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
        // bump to the nearest divisor of the outermost kernel dim
        for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow)
            ;
        if (sz_want_borrow != prb.nodes[kdims - 1].n)
            prb_node_split(
                    prb, kdims - 1, prb.nodes[kdims - 1].n / sz_want_borrow);
    }

    ndims_ker_max = kdims;

    if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
        DEBUG({
            printf("split: ");
            prb_dump(prb);
            printf("ndims_ker_max = %d\n", ndims_ker_max);
        });
    }
}
1487
/* Builds the reorder primitive descriptor: derives the problem from the
 * memory descriptors, massages it (normalize / simplify / cache-block),
 * balances work between the threading driver and the JIT kernel, and
 * verifies a kernel is applicable before constructing the pd. */
status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
        engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
        const memory_desc_t *src_md, engine_t *dst_engine,
        const memory_desc_t *dst_md) {
    auto prb = tr::prb_t();

    status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
    if (prb_init_status != status::success) return prb_init_status;

    DEBUG({
        printf("init : ");
        prb_dump(prb);
    });
    // Sort the prb array in increasing sizes of the output stride
    prb_normalize(prb);
    DEBUG({
        printf("norm : ");
        prb_dump(prb);
    });
    /* Combine the variables, which appear together on both
     * sides of the reorder */
    prb_simplify(prb);
    DEBUG({
        printf("smpl : ");
        prb_dump(prb);
    });

    prb_block_for_cache(prb);
    DEBUG({
        printf("cache: ");
        prb_dump(prb);
    });

    // decide how many innermost dims the kernel covers vs. the driver
    int ndims_ker_max;
    int nthr = dnnl_get_max_threads();
    prb_thread_kernel_balance(prb, ndims_ker_max, nthr);

    tr::kernel_t::desc_t ker_desc;
    status_t ker_init_status
            = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
    if (ker_init_status != status::success) return ker_init_status;

    // the host-side driver handles the remaining (outer) dimensions
    const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
    if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
        return status::unimplemented;

    DEBUG({
        printf("ker  : ");
        prb_dump(ker_desc.prb);
    });

    auto _pd = new pd_t(
            attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
    if (_pd == nullptr) return status::out_of_memory;
    if (_pd->init(engine, src_engine, dst_engine) != status::success) {
        delete _pd; // avoid leaking the half-initialized pd
        return status::unimplemented;
    }
    _pd->prb_ = prb;
    _pd->ker_desc_ = ker_desc;
    _pd->init_scratchpad_md();
    _pd->nthr_ = nthr;
    return safe_ptr_assign(*reorder_pd, _pd);
}
1552
omp_driver_0d(int off,const char * in,char * out,const float * scale) const1553 void jit_uni_reorder_t::omp_driver_0d(
1554 int off, const char *in, char *out, const float *scale) const {
1555 tr::call_param_t c {in, out, scale};
1556 (*kernel_)(&c);
1557 }
1558
omp_driver_1d(int ithr,int nthr,int off,const char * in,char * out,const float * scale) const1559 void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off,
1560 const char *in, char *out, const float *scale) const {
1561 const tr::node_t *ns = pd()->prb_.nodes + off;
1562 for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
1563 auto c = tr::call_param_t();
1564 c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
1565 c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
1566 c.scale = scale + d0 * ns[0].ss;
1567 (*kernel_)(&c);
1568 });
1569 }
1570
omp_driver_2d(int ithr,int nthr,int off,const char * in,char * out,const float * scale) const1571 void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off,
1572 const char *in, char *out, const float *scale) const {
1573 const tr::node_t *ns = pd()->prb_.nodes + off;
1574 for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
1575 [&](ptrdiff_t d1, ptrdiff_t d0) {
1576 auto c = tr::call_param_t();
1577 c.in = in
1578 + (d0 * ns[0].is + d1 * ns[1].is)
1579 * data_type_size(pd()->prb_.itype);
1580 c.out = out
1581 + (d0 * ns[0].os + d1 * ns[1].os)
1582 * data_type_size(pd()->prb_.otype);
1583 c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
1584 (*kernel_)(&c);
1585 });
1586 }
1587
omp_driver_3d(int ithr,int nthr,int off,const char * in,char * out,const float * scale) const1588 void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off,
1589 const char *in, char *out, const float *scale) const {
1590 const tr::node_t *ns = pd()->prb_.nodes + off;
1591 for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
1592 (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
1593 auto c = tr::call_param_t();
1594 c.in = in
1595 + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
1596 * data_type_size(pd()->prb_.itype);
1597 c.out = out
1598 + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
1599 * data_type_size(pd()->prb_.otype);
1600 c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
1601 (*kernel_)(&c);
1602 });
1603 }
1604
omp_driver_4d(int ithr,int nthr,int off,const char * in,char * out,const float * scale) const1605 void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off,
1606 const char *in, char *out, const float *scale) const {
1607 const tr::node_t *ns = pd()->prb_.nodes + off;
1608 for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
1609 (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
1610 [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
1611 auto c = tr::call_param_t();
1612 c.in = in
1613 + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
1614 + d3 * ns[3].is)
1615 * data_type_size(pd()->prb_.itype);
1616 c.out = out
1617 + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
1618 + d3 * ns[3].os)
1619 * data_type_size(pd()->prb_.otype);
1620 c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
1621 + d3 * ns[3].ss;
1622 (*kernel_)(&c);
1623 });
1624 }
1625
omp_driver(const char * in,char * out,const float * scale) const1626 void jit_uni_reorder_t::omp_driver(
1627 const char *in, char *out, const float *scale) const {
1628 in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
1629 out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
1630
1631 DEBUG({
1632 printf("prb : ");
1633 tr::prb_dump(pd()->prb_);
1634 });
1635 DEBUG({
1636 printf("ker : ");
1637 tr::prb_dump(pd()->ker_desc_.prb);
1638 });
1639
1640 int ndims = pd()->prb_.ndims;
1641 int ndims_ker = pd()->ker_desc_.prb.ndims;
1642 assert(ndims - ndims_ker <= ndims_driver_max);
1643
1644 if (ndims - ndims_ker == 0) {
1645 omp_driver_0d(ndims_ker, in, out, scale);
1646 } else {
1647 parallel(pd()->nthr_, [&](const int ithr, const int nthr) {
1648 switch (ndims - ndims_ker) {
1649 case 1:
1650 omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale);
1651 break;
1652 case 2:
1653 omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale);
1654 break;
1655 case 3:
1656 omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale);
1657 break;
1658 case 4:
1659 omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale);
1660 break;
1661 default: assert(!"unimplemented");
1662 }
1663 });
1664 }
1665 }
1666
init(engine_t * engine)1667 status_t jit_uni_reorder_t::init(engine_t *engine) {
1668 CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_)));
1669 return kernel_->create_kernel();
1670 }
1671
/* Runs the reorder: fetches src/dst/scales from the execution context and
 * drives the JIT kernel through omp_driver(). */
status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
    // CTX_OUT_CLEAN_MEM reports failures through `status`
    auto out = CTX_OUT_CLEAN_MEM(char *, DNNL_ARG_TO, status);
    CHECK(status);
    DEFINE_SCALES_BUFFER(scales);

    omp_driver(in, out, scales);

    return status::success;
}
1683
1684 } // namespace aarch64
1685 } // namespace cpu
1686 } // namespace impl
1687 } // namespace dnnl
1688