1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include "common/c_types_map.hpp"
19 #include "common/dnnl_thread.hpp"
20 #include "common/nstl.hpp"
21 #include "common/utils.hpp"
22
23 #include "cpu/aarch64/jit_generator.hpp"
24
25 #include "cpu/aarch64/jit_uni_eltwise_int.hpp"
26
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace aarch64 {
31
32 using namespace Xbyak_aarch64;
33
// Argument bundle handed to the generated kernel via abi_param1. Field
// offsets are taken with offsetof() inside generate(), so this layout must
// stay in sync with the jit code — do not reorder members.
struct jit_args_t {
    const void *from; // source buffer, already offset to the thread's chunk
    const void *for_comparison; // filled by the caller; not read by generate()
    const void *to; // destination buffer
    size_t work_amount; // number of elements to process
};
40
// Common base for the int eltwise jit kernels: keeps the eltwise descriptor
// and exposes a call operator that jumps into the generated code.
struct jit_uni_eltwise_int_kernel : public jit_generator {
    jit_uni_eltwise_int_kernel(const eltwise_desc_t &desc) : desc_(desc) {}

    // Run the generated kernel on one argument bundle.
    void operator()(jit_args_t *p) { jit_generator::operator()(p); }

protected:
    // Data type of the data tensor (s32/s8/u8 for this primitive).
    data_type_t data_type() const { return desc_.data_desc.data_type; }
    // Size of one element in bytes (4 for s32, 1 for s8/u8).
    int dtype_size() const { return types::data_type_size(data_type()); }

    const eltwise_desc_t &desc() const { return desc_; }

private:
    // NOTE(review): held by reference — the descriptor must outlive this
    // kernel. The fwd primitive's init() passes *pd()->desc(), which lives
    // in the pd, so the lifetime holds today; confirm if callers change.
    const eltwise_desc_t &desc_;
};
55
56 /* jit kernels */
57 namespace {
58 using namespace Xbyak_aarch64;
59
// JIT kernel emitting relu/linear over s32/s8/u8 data: a vectorized main
// loop followed by a one-element-at-a-time scalar tail loop.
template <cpu_isa_t isa>
struct jit_uni_subkernel_int_t : public jit_uni_eltwise_int_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_subkernel_int)

    jit_uni_subkernel_int_t(const eltwise_desc_t &desc)
        : jit_uni_eltwise_int_kernel(desc) {
        using namespace data_type;

        // Relu and linear for int types: s32, s8, u8; Only forward direction
        assert(utils::one_of(desc.alg_kind, alg_kind::eltwise_relu,
                alg_kind::eltwise_linear));
        assert(utils::one_of(data_type(), s32, data_type::s8, u8));
        assert(isa == sve_512);
    }

    // Emit the kernel body: load args, broadcast alpha/beta, then run the
    // vector pass and the scalar tail pass over work_amount elements.
    void generate() override {
        XReg param = abi_param1;

        const size_t vlen = cpu_isa_traits<isa>::vlen;
        // Elements per vector iteration. Data is processed as 32-bit lanes
        // (8-bit types are widened on load), hence division by sizeof(float).
        const size_t simd_w = vlen / sizeof(float);
        // Per-pass constants: index 0 is the vector loop, index 1 the scalar
        // tail loop.
        const size_t loop_dec[] = {simd_w, 1};
        const size_t uf[] = {1, 1}; // unroll factor of each pass
        const size_t shift[] = {dtype_size() * simd_w, (size_t)dtype_size()};
        const bool loop_vectorize[] = {true, false};

        preamble();

        // Load the jit_args_t fields into their dedicated registers.
#define GET_OFF(field) offsetof(jit_args_t, field)
        add_imm(X_TMP_0, param, GET_OFF(from), X_TMP_1);
        ldr(reg_from, ptr(X_TMP_0));

        add_imm(X_TMP_0, param, GET_OFF(to), X_TMP_1);
        ldr(reg_to, ptr(X_TMP_0));

        add_imm(X_TMP_0, param, GET_OFF(work_amount), X_TMP_1);
        ldr(reg_work_amount, ptr(X_TMP_0));
#undef GET_OFF

        // Broadcast alpha/beta (raw f32 bit patterns) to all vector lanes.
        mov_imm(W_TMP_0, float2int(desc().alpha));
        mov_imm(W_TMP_1, float2int(desc().beta));
        dup(ts_alpha, W_TMP_0);
        dup(ts_beta, W_TMP_1);

        eor(t_zero.d, t_zero.d, t_zero.d);

        if (isa == sve_512) {
            ptrue(p_vl1.b, VL1); // single-lane predicate (scalar stores)
            ptrue(p_all_one.b); // all-lanes-active predicate
        }

        Label loop_label[3];

        // Pass 0: vector loop; pass 1: scalar tail; loop_label[2] is the
        // common exit.
        for (int id = 0; id < 2; id++) {
            L(loop_label[id]);
            // Exit this pass when fewer than uf*loop_dec elements remain.
            mov_imm(X_TMP_0, uf[id] * loop_dec[id] - 1);
            cmp(reg_work_amount, X_TMP_0);

            b(LE, loop_label[id + 1]);

            compute_step(
                    loop_vectorize[id], uf[id], shift[id], desc().alg_kind);

            add_imm(reg_from, reg_from, uf[id] * shift[id], X_TMP_0);
            add_imm(reg_to, reg_to, uf[id] * shift[id], X_TMP_0);
            sub_imm(reg_work_amount, reg_work_amount, uf[id] * loop_dec[id],
                    X_TMP_0);
            b(loop_label[id]);
        }

        L(loop_label[2]);
        postamble();
    }

private:
    using TReg = typename cpu_isa_traits<isa>::TReg;
    using TRegS = typename cpu_isa_traits<isa>::TRegS;

    // Register assignments (fixed for the lifetime of the kernel).
    const XReg reg_from = x1;
    const XReg reg_to = x8;
    const XReg reg_work_amount = x6;
    const XReg imm_addr64 = x3;

    const TReg t_tmp0 = TReg(31);

    const TReg t_saturation_ubound = TReg(26);
    const TRegS ts_alpha = TRegS(27);
    const TRegS ts_beta = TRegS(28);
    const TReg t_zero = TReg(29);

    const PReg p_vl1 = p0;
    const PReg p_mask = p1;
    const PReg p_mask_int8 = p_vl1; // Mask for store 1 byte in case of SVE_512
    const PReg p_all_one = p3;

    bool is32bit() const { return data_type() == data_type::s32; }

    // Load 32bit data type (s32)
    void load_32bit(
            const bool vectorize, const TReg &vr_from, const XReg &mem_from) {

        if (vectorize) {
            // load full TReg size
            uni_ldr(vr_from, mem_from);
        } else {
            // load exactly one data item (broadcast into the .s lanes)
            ldr(W_TMP_0, ptr(mem_from));
            mov(vr_from.s, W_TMP_0);
        }
    }

    // Load 8bit data type (u8/s8)
    void load_8bit(const bool vectorize, const TReg &vr_from,
            const XReg &mem_from, bool is_signed) {

        // data type u8/s8 load as s32
        if (vectorize) {
            // load full TReg size
            ldr(QReg(t_tmp0.getIdx()), ptr(mem_from));
            // Spread each byte into its own 32-bit lane: the two zip1 passes
            // replicate every byte across a 32-bit slot, then sxtb/uxtb
            // extends the low byte of each lane to full s32.
            zip1(t_tmp0.b, t_tmp0.b, t_tmp0.b);
            zip1(t_tmp0.h, t_tmp0.h, t_tmp0.h);

            if (is_signed)
                sxtb(vr_from.s, p_all_one / T_m, t_tmp0.s);
            else
                uxtb(vr_from.s, p_all_one / T_m, t_tmp0.s);
        } else {
            // load exactly one data item
            ldurb(W_TMP_0, ptr(mem_from));
            uni_clear(vr_from);

            if (is_signed)
                sxtb(W_TMP_0, W_TMP_0);
            else
                uxtb(W_TMP_0, W_TMP_0);

            // NOTE(review): relies on W_TMP_0 being the 32-bit view of
            // X_TMP_0 (so the extended value, upper bits zeroed, is what
            // lands in lane 0) — confirm against jit_generator's register
            // definitions.
            mov(VReg(vr_from.getIdx()).d[0], X_TMP_0);
        }
    }

    // Load vregs with data from mem
    void load(const bool vectorize, const TReg &vr_from, const XReg &mem_from) {

        // Branching on data size
        if (is32bit())
            load_32bit(vectorize, vr_from, mem_from);
        else
            load_8bit(
                    vectorize, vr_from, mem_from, data_type() == data_type::s8);
    }

    // Processing (defined out of line; relu has an sve_512 specialization)
    void process_linear(const TReg &vr_to, const TReg &vr_from);
    void process_relu(const TReg &vr_to, const TReg &vr_from);

    // Store s32 for any isa
    void store_32bit(
            const bool vectorize, const XReg &mem_to, const TReg &vr_to) {
        if (vectorize) {
            // store full TReg size
            uni_str(vr_to, mem_to);
        } else {
            // store exactly one data item
            st1w(vr_to.s, p_vl1, ptr(mem_to));
        }
    }

    // Store 8 bit int - isa-dependent
    void store_8bit(const bool vectorize, const XReg &mem_to, const TReg &vr_to,
            bool is_signed);

    // Store results from vregs to mem
    void store(const bool vectorize, const XReg &mem_to, const TReg &vr_to) {
        // Branching on data size
        if (is32bit())
            store_32bit(vectorize, mem_to, vr_to);
        else
            store_8bit(vectorize, mem_to, vr_to, data_type() == data_type::s8);
    }

    // One load -> process -> store step over `uf` vector registers.
    void compute_step(bool vectorize, const size_t uf, const size_t shift,
            const alg_kind_t alg) {

        // Source values live in TReg(1..uf), results in TReg(uf+1..2*uf).
        auto vreg_from = [&](const size_t i) -> TReg { return TReg(i + 1); };
        auto vreg_to = [&](const size_t i) -> TReg { return TReg(uf + i + 1); };

        // 1. Load (vregs <- mem)
        // NOTE(review): add_imm advances reg_from/reg_to by i*shift
        // cumulatively; this is only correct because uf is always 1 (the
        // sole increment is i == 0, i.e. zero). Revisit before raising uf.
        for (size_t i = 0; i < uf; i++) {
            add_imm(reg_from, reg_from, i * shift, X_TMP_0);
            load(vectorize, vreg_from(i), reg_from);
        }

        // 2. Process (vregs <- vergs)
        switch (alg) {
            case alg_kind::eltwise_linear:
                for (size_t i = 0; i < uf; i++)
                    process_linear(vreg_to(i), vreg_from(i));
                break;
            case alg_kind::eltwise_relu:
                for (size_t i = 0; i < uf; i++)
                    process_relu(vreg_to(i), vreg_from(i));
                break;
            default: assert(!"unsupported alg");
        }

        // 3. Store (mem <- vregs)
        for (size_t i = 0; i < uf; i++) {
            add_imm(reg_to, reg_to, i * shift, X_TMP_0);
            store(vectorize, reg_to, vreg_to(i));
        }
    }
};
271
// linear: dst = alpha * src + beta, computed in f32 and saturated back into
// the destination integer range.
template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_linear(
        const TReg &vr_to, const TReg &vr_from) {

    // s32 -> f32, then fused multiply-add: vr_to = vr_to * alpha + beta.
    scvtf(vr_to.s, p_all_one / T_m, vr_from.s);
    fmad(vr_to.s, p_all_one / T_m, ts_alpha, ts_beta);

    // Saturate before converting from f32 to s32
    XReg reg_tmp = x10;

    // t_zero is re-zeroed and then presumably used as the saturation lower
    // bound by init_saturate_f32/saturate_f32 — confirm against the helper's
    // contract in jit_generator.
    uni_clear(t_zero);
    init_saturate_f32(
            t_zero, t_saturation_ubound, reg_tmp, data_type::f32, data_type());
    saturate_f32(vr_to, t_zero, t_saturation_ubound, data_type(), p_all_one);

    // Round to nearest representable integer, then convert back to s32.
    frinti(vr_to.s, p_all_one / T_m, vr_to.s);
    fcvtzs(vr_to.s, p_all_one / T_m, vr_to.s);
}
290
291 template <cpu_isa_t isa>
process_relu(const TReg & vr_to,const TReg & vr_from)292 void jit_uni_subkernel_int_t<isa>::process_relu(
293 const TReg &vr_to, const TReg &vr_from) {
294 assert(!"unsupported isa");
295 }
296
// relu specialization for sve_512: dst = src > 0 ? src : alpha * src,
// computed in f32 and converted back to s32.
template <>
void jit_uni_subkernel_int_t<sve_512>::process_relu(
        const TReg &vr_to, const TReg &vr_from) {

    // s32 -> f32 in place; note vr_from is clobbered by this conversion.
    scvtf(vr_from.s, p_all_one / T_m, vr_from.s);

    // Negative-slope branch: vr_to = alpha * src.
    fmul(vr_to.s, vr_from.s, ts_alpha);

    // p_mask = (src > 0), inactive lanes zeroed (T_z).
    fcmgt(p_mask.s, p_all_one / T_z, vr_from.s, t_zero.s);

    // Lanes where src > 0 take src; the rest keep alpha * src.
    sel(vr_to.s, p_mask / T_m, vr_from.s, vr_to.s);

    // Round to integer and convert back to s32.
    frinti(vr_to.s, p_all_one / T_m, vr_to.s);
    fcvtzs(vr_to.s, p_all_one / T_m, vr_to.s);
}
312
313 template <cpu_isa_t isa>
store_8bit(const bool vectorize,const XReg & mem_to,const TReg & vr_to,bool is_signed)314 void jit_uni_subkernel_int_t<isa>::store_8bit(const bool vectorize,
315 const XReg &mem_to, const TReg &vr_to, bool is_signed) {
316 assert(!"unsupported isa");
317 }
318
319 template <>
store_8bit(const bool vectorize,const XReg & mem_to,const TReg & vr_to,bool is_signed)320 void jit_uni_subkernel_int_t<sve_512>::store_8bit(const bool vectorize,
321 const XReg &mem_to, const TReg &vr_to, bool is_signed) {
322 if (vectorize) {
323 // store full TReg size
324 mov(t_tmp0.d, vr_to.d);
325 if (is_signed) {
326 smin(t_tmp0.s, 127);
327 smax(t_tmp0.s, -128);
328 } else {
329 umin(t_tmp0.s, 255);
330 }
331 st1b(t_tmp0.s, p_all_one, ptr(mem_to));
332 } else {
333 // store exactly one data item
334 // s32 save as s8/u8
335 mov(t_tmp0.d, vr_to.d);
336 if (is_signed) {
337 smin(t_tmp0.s, 127);
338 smax(t_tmp0.s, -128);
339 } else {
340 umin(t_tmp0.s, 255);
341 }
342 st1b(t_tmp0.s, p_mask_int8, ptr(mem_to));
343 }
344 }
345
346 } /* namespace */
347
348 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)349 status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
350 bool ok = mayiuse(isa)
351 && desc()->data_desc.data_type == d_type
352 // only relu and linear so far
353 && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
354 alg_kind::eltwise_linear)
355 && !has_zero_dim_memory()
356 && memory_desc_wrapper(src_md()).is_dense(true)
357 && attr()->has_default_values();
358
359 return ok ? status::success : status::unimplemented;
360 }
361
// Construction only records the pd; the kernel itself is built in init().
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::jit_uni_eltwise_int_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}
366
// Allocate the jit subkernel for this pd's descriptor and jit-compile it.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::init(engine_t *engine) {
    const auto &desc = *pd()->desc();
    CHECK(safe_ptr_assign(kernel_, new jit_uni_subkernel_int_t<isa>(desc)));
    return kernel_->create_kernel();
}
373
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::~jit_uni_eltwise_int_fwd_t() {
    // kernel_ is a raw owning pointer allocated in init().
    delete kernel_;
}
378
// Forward execution: split the flat (dense) buffer across threads in
// cache-line-sized chunks and invoke the jit kernel on each chunk.
template <cpu_isa_t isa, impl::data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->data_md());

    // Element count including padding: pd_t::init requires is_dense(true),
    // so the padded buffer can be processed as one linear range.
    const size_t nelems = data_d.nelems(true);

    src += data_d.offset0();
    dst += data_d.offset0();

    // Chunk granularity of one 64-byte cache line so threads do not write
    // to the same line.
    const int cache_line = 64 / data_d.data_type_size();
    parallel(0, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};

        // Balance in units of cache lines, then convert back to elements,
        // clamping so the last chunk does not run past nelems.
        balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
        start = nstl::min(nelems, start * cache_line);
        end = nstl::min(nelems, end * cache_line);

        auto arg = jit_args_t();
        arg.from = (const void *)&src[start];
        arg.for_comparison = (const void *)&src[start];
        arg.to = (const void *)&dst[start];
        arg.work_amount = end - start;
        if (arg.work_amount) (*kernel_)(&arg);
    });
    return status::success;
}
411
using namespace data_type;

// Supported configurations: forward int eltwise on SVE-512 for s32/s8/u8.
template struct jit_uni_eltwise_int_fwd_t<sve_512, s32>;
template struct jit_uni_eltwise_int_fwd_t<sve_512, data_type::s8>;
template struct jit_uni_eltwise_int_fwd_t<sve_512, u8>;
417
418 } // namespace aarch64
419 } // namespace cpu
420 } // namespace impl
421 } // namespace dnnl
422