/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/jit_generator.hpp"

#include "cpu/aarch64/jit_uni_eltwise_int.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

using namespace Xbyak_aarch64;

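// Argument block passed by pointer to the jitted kernel. The kernel reads the
// source pointer, the destination pointer, and the element count;
// for_comparison is set by the caller but not read by this kernel.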
struct jit_args_t {
    const void *from;
    const void *for_comparison;
    const void *to;
    size_t work_amount;
};

struct jit_uni_eltwise_int_kernel : public jit_generator {
    jit_uni_eltwise_int_kernel(const eltwise_desc_t &desc) : desc_(desc) {}

    void operator()(jit_args_t *p) { jit_generator::operator()(p); }

protected:
    data_type_t data_type() const { return desc_.data_desc.data_type; }
    int dtype_size() const { return types::data_type_size(data_type()); }

    const eltwise_desc_t &desc() const { return desc_; }

private:
    const eltwise_desc_t &desc_;
};

/* jit kernels */
namespace {
using namespace Xbyak_aarch64;

template <cpu_isa_t isa>
struct jit_uni_subkernel_int_t : public jit_uni_eltwise_int_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_subkernel_int)

    jit_uni_subkernel_int_t(const eltwise_desc_t &desc)
        : jit_uni_eltwise_int_kernel(desc) {
        using namespace data_type;

        // ReLU and linear for int types (s32, s8, u8); forward direction only
        assert(utils::one_of(desc.alg_kind, alg_kind::eltwise_relu,
                alg_kind::eltwise_linear));
        assert(utils::one_of(data_type(), s32, data_type::s8, u8));
        assert(isa == sve_512);
    }

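    // generate() emits the whole kernel: prologue, a vectorized main loop
    // that handles simd_w elements per iteration, a scalar tail loop for the
    // remainder, and the epilogue.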
    void generate() override {
        XReg param = abi_param1;

        const size_t vlen = cpu_isa_traits<isa>::vlen;
        const size_t simd_w = vlen / sizeof(float);
        const size_t loop_dec[] = {simd_w, 1};
        const size_t uf[] = {1, 1};
        const size_t shift[] = {dtype_size() * simd_w, (size_t)dtype_size()};
        const bool loop_vectorize[] = {true, false};

        preamble();

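        // Unpack the from/to pointers and the element count from the
        // jit_args_t block whose address arrives in abi_param1.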
#define GET_OFF(field) offsetof(jit_args_t, field)
        add_imm(X_TMP_0, param, GET_OFF(from), X_TMP_1);
        ldr(reg_from, ptr(X_TMP_0));

        add_imm(X_TMP_0, param, GET_OFF(to), X_TMP_1);
        ldr(reg_to, ptr(X_TMP_0));

        add_imm(X_TMP_0, param, GET_OFF(work_amount), X_TMP_1);
        ldr(reg_work_amount, ptr(X_TMP_0));
#undef GET_OFF

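        // Broadcast the f32 bit patterns of alpha and beta to all vector
        // lanes and zero out t_zero.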
        mov_imm(W_TMP_0, float2int(desc().alpha));
        mov_imm(W_TMP_1, float2int(desc().beta));
        dup(ts_alpha, W_TMP_0);
        dup(ts_beta, W_TMP_1);

        eor(t_zero.d, t_zero.d, t_zero.d);

        if (isa == sve_512) {
            ptrue(p_vl1.b, VL1);
            ptrue(p_all_one.b);
        }

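        // Two passes over the same loop skeleton: id == 0 is the vectorized
        // main loop (simd_w elements per step), id == 1 is the scalar tail
        // (one element per step). Each pass falls through to the next label
        // once fewer than uf[id] * loop_dec[id] elements remain.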
        Label loop_label[3];

        for (int id = 0; id < 2; id++) {
            L(loop_label[id]);
            mov_imm(X_TMP_0, uf[id] * loop_dec[id] - 1);
            cmp(reg_work_amount, X_TMP_0);

            b(LE, loop_label[id + 1]);

            compute_step(
                    loop_vectorize[id], uf[id], shift[id], desc().alg_kind);

            add_imm(reg_from, reg_from, uf[id] * shift[id], X_TMP_0);
            add_imm(reg_to, reg_to, uf[id] * shift[id], X_TMP_0);
            sub_imm(reg_work_amount, reg_work_amount, uf[id] * loop_dec[id],
                    X_TMP_0);
            b(loop_label[id]);
        }

        L(loop_label[2]);
        postamble();
    }

private:
    using TReg = typename cpu_isa_traits<isa>::TReg;
    using TRegS = typename cpu_isa_traits<isa>::TRegS;

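    // Fixed register assignment: x-registers carry the loop state, the high
    // vector registers (26-31) hold constants and scratch, and the predicate
    // registers control lane selection (full vector, single lane, and the
    // ReLU comparison mask).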
    const XReg reg_from = x1;
    const XReg reg_to = x8;
    const XReg reg_work_amount = x6;
    const XReg imm_addr64 = x3;

    const TReg t_tmp0 = TReg(31);

    const TReg t_saturation_ubound = TReg(26);
    const TRegS ts_alpha = TRegS(27);
    const TRegS ts_beta = TRegS(28);
    const TReg t_zero = TReg(29);

    const PReg p_vl1 = p0;
    const PReg p_mask = p1;
    const PReg p_mask_int8 = p_vl1; // Mask for storing a single byte (SVE-512)
    const PReg p_all_one = p3;

    bool is32bit() const { return data_type() == data_type::s32; }

    // Load 32bit data type (s32)
    void load_32bit(
            const bool vectorize, const TReg &vr_from, const XReg &mem_from) {

        if (vectorize) {
            // load full TReg size
            uni_ldr(vr_from, mem_from);
        } else {
            // load exactly one data item
            ldr(W_TMP_0, ptr(mem_from));
            mov(vr_from.s, W_TMP_0);
        }
    }

    // Load 8bit data type (u8/s8)
    void load_8bit(const bool vectorize, const TReg &vr_from,
            const XReg &mem_from, bool is_signed) {

        // u8/s8 data is widened to s32 on load
        if (vectorize) {
            // load full TReg size
            ldr(QReg(t_tmp0.getIdx()), ptr(mem_from));
            zip1(t_tmp0.b, t_tmp0.b, t_tmp0.b);
            zip1(t_tmp0.h, t_tmp0.h, t_tmp0.h);

            if (is_signed)
                sxtb(vr_from.s, p_all_one / T_m, t_tmp0.s);
            else
                uxtb(vr_from.s, p_all_one / T_m, t_tmp0.s);
        } else {
            // load exactly one data item
            ldurb(W_TMP_0, ptr(mem_from));
            uni_clear(vr_from);

            if (is_signed)
                sxtb(W_TMP_0, W_TMP_0);
            else
                uxtb(W_TMP_0, W_TMP_0);

            mov(VReg(vr_from.getIdx()).d[0], X_TMP_0);
        }
    }

    // Load vregs with data from mem
    void load(const bool vectorize, const TReg &vr_from, const XReg &mem_from) {

        // Branching on data size
        if (is32bit())
            load_32bit(vectorize, vr_from, mem_from);
        else
            load_8bit(
                    vectorize, vr_from, mem_from, data_type() == data_type::s8);
    }

    // Processing
    void process_linear(const TReg &vr_to, const TReg &vr_from);
    void process_relu(const TReg &vr_to, const TReg &vr_from);

    // Store s32 for any isa
    void store_32bit(
            const bool vectorize, const XReg &mem_to, const TReg &vr_to) {
        if (vectorize) {
            // store full TReg size
            uni_str(vr_to, mem_to);
        } else {
            // store exactly one data item
            st1w(vr_to.s, p_vl1, ptr(mem_to));
        }
    }

    // Store 8 bit int - isa-dependent
    void store_8bit(const bool vectorize, const XReg &mem_to, const TReg &vr_to,
            bool is_signed);

    // Store results from vregs to mem
    void store(const bool vectorize, const XReg &mem_to, const TReg &vr_to) {
        // Branching on data size
        if (is32bit())
            store_32bit(vectorize, mem_to, vr_to);
        else
            store_8bit(vectorize, mem_to, vr_to, data_type() == data_type::s8);
    }

    void compute_step(bool vectorize, const size_t uf, const size_t shift,
            const alg_kind_t alg) {

        auto vreg_from = [&](const size_t i) -> TReg { return TReg(i + 1); };
        auto vreg_to = [&](const size_t i) -> TReg { return TReg(uf + i + 1); };
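        // Source values live in TReg(1)..TReg(uf) and results in
        // TReg(uf + 1)..TReg(2 * uf), so an unrolled step never overwrites
        // an input it still needs.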

        // 1. Load (vregs <- mem)
        for (size_t i = 0; i < uf; i++) {
            add_imm(reg_from, reg_from, i * shift, X_TMP_0);
            load(vectorize, vreg_from(i), reg_from);
        }

        // 2. Process (vregs <- vregs)
        switch (alg) {
            case alg_kind::eltwise_linear:
                for (size_t i = 0; i < uf; i++)
                    process_linear(vreg_to(i), vreg_from(i));
                break;
            case alg_kind::eltwise_relu:
                for (size_t i = 0; i < uf; i++)
                    process_relu(vreg_to(i), vreg_from(i));
                break;
            default: assert(!"unsupported alg");
        }

        // 3. Store (mem <- vregs)
        for (size_t i = 0; i < uf; i++) {
            add_imm(reg_to, reg_to, i * shift, X_TMP_0);
            store(vectorize, reg_to, vreg_to(i));
        }
    }
};

template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_linear(
        const TReg &vr_to, const TReg &vr_from) {

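    // linear: y = alpha * x + beta. Widen s32 -> f32, apply a fused
    // multiply-add, saturate to the destination type's range, round, and
    // convert back to integer.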
    scvtf(vr_to.s, p_all_one / T_m, vr_from.s);
    fmad(vr_to.s, p_all_one / T_m, ts_alpha, ts_beta);

    // Saturate before converting from f32 to s32
    XReg reg_tmp = x10;

    uni_clear(t_zero);
    init_saturate_f32(
            t_zero, t_saturation_ubound, reg_tmp, data_type::f32, data_type());
    saturate_f32(vr_to, t_zero, t_saturation_ubound, data_type(), p_all_one);

    frinti(vr_to.s, p_all_one / T_m, vr_to.s);
    fcvtzs(vr_to.s, p_all_one / T_m, vr_to.s);
}

template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_relu(
        const TReg &vr_to, const TReg &vr_from) {
    assert(!"unsupported isa");
}

template <>
void jit_uni_subkernel_int_t<sve_512>::process_relu(
        const TReg &vr_to, const TReg &vr_from) {

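    // ReLU with negative slope: y = x > 0 ? x : alpha * x. Compute alpha * x
    // unconditionally, build a predicate for x > 0, then select the positive
    // inputs over the scaled ones.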
    scvtf(vr_from.s, p_all_one / T_m, vr_from.s);

    fmul(vr_to.s, vr_from.s, ts_alpha);

    fcmgt(p_mask.s, p_all_one / T_z, vr_from.s, t_zero.s);

    sel(vr_to.s, p_mask / T_m, vr_from.s, vr_to.s);

    frinti(vr_to.s, p_all_one / T_m, vr_to.s);
    fcvtzs(vr_to.s, p_all_one / T_m, vr_to.s);
}

template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::store_8bit(const bool vectorize,
        const XReg &mem_to, const TReg &vr_to, bool is_signed) {
    assert(!"unsupported isa");
}

template <>
void jit_uni_subkernel_int_t<sve_512>::store_8bit(const bool vectorize,
        const XReg &mem_to, const TReg &vr_to, bool is_signed) {
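    // Clamp the s32 lanes to the destination range (smin/smax for s8, umin
    // for u8), then narrow-store one byte per lane with st1b. The vectorized
    // path stores all lanes (p_all_one); the scalar path stores only the
    // first lane (p_mask_int8 == VL1).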
    if (vectorize) {
        // store full TReg size
        mov(t_tmp0.d, vr_to.d);
        if (is_signed) {
            smin(t_tmp0.s, 127);
            smax(t_tmp0.s, -128);
        } else {
            umin(t_tmp0.s, 255);
        }
        st1b(t_tmp0.s, p_all_one, ptr(mem_to));
    } else {
        // store exactly one data item
        // s32 stored as s8/u8
        mov(t_tmp0.d, vr_to.d);
        if (is_signed) {
            smin(t_tmp0.s, 127);
            smax(t_tmp0.s, -128);
        } else {
            umin(t_tmp0.s, 255);
        }
        st1b(t_tmp0.s, p_mask_int8, ptr(mem_to));
    }
}

} /* namespace */

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
    bool ok = mayiuse(isa)
            && desc()->data_desc.data_type == d_type
            // only relu and linear so far
            && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
                    alg_kind::eltwise_linear)
            && !has_zero_dim_memory()
            && memory_desc_wrapper(src_md()).is_dense(true)
            && attr()->has_default_values();

    return ok ? status::success : status::unimplemented;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::jit_uni_eltwise_int_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::init(engine_t *engine) {
    const auto &desc = *pd()->desc();
    CHECK(safe_ptr_assign(kernel_, new jit_uni_subkernel_int_t<isa>(desc)));
    return kernel_->create_kernel();
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::~jit_uni_eltwise_int_fwd_t() {
    delete kernel_;
}

template <cpu_isa_t isa, impl::data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->data_md());

    const size_t nelems = data_d.nelems(true);

    src += data_d.offset0();
    dst += data_d.offset0();

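    // Partition the elements into cache-line-sized chunks (64 bytes) and
    // balance the chunks across threads, so no two threads write into the
    // same cache line.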
    const int cache_line = 64 / data_d.data_type_size();
    parallel(0, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};

        balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
        start = nstl::min(nelems, start * cache_line);
        end = nstl::min(nelems, end * cache_line);

        auto arg = jit_args_t();
        arg.from = (const void *)&src[start];
        arg.for_comparison = (const void *)&src[start];
        arg.to = (const void *)&dst[start];
        arg.work_amount = end - start;
        if (arg.work_amount) (*kernel_)(&arg);
    });
    return status::success;
}

using namespace data_type;

template struct jit_uni_eltwise_int_fwd_t<sve_512, s32>;
template struct jit_uni_eltwise_int_fwd_t<sve_512, data_type::s8>;
template struct jit_uni_eltwise_int_fwd_t<sve_512, u8>;

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl