1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #include "common/bfloat16.hpp"
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/nstl.hpp"
22 #include "common/utils.hpp"
23 
24 #include "cpu/aarch64/jit_generator.hpp"
25 
26 #include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp"
27 #include "cpu/aarch64/jit_uni_eltwise.hpp"
28 
29 #define GET_OFF(field) offsetof(jit_args_t, field)
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace aarch64 {
35 
36 using namespace Xbyak_aarch64;
37 
/// Argument block handed to the generated kernel. The JIT code reads every
/// field through its GET_OFF() byte offset, so the field order defines the
/// layout the kernel relies on.
///
/// All fields are default-initialized so that a freshly declared instance
/// never carries indeterminate pointers or counts.
struct jit_args_t {
    const void *src = nullptr; // fwd: src;  bwd: src/dst based on alg;
    const void *dst = nullptr; // fwd: dst;  bwd: diff_src;
    const void *diff_dst = nullptr; // fwd: nullptr;  bwd: diff_dst;
    size_t work_amount = 0; // number of elements the kernel has to process
};
44 
45 struct jit_uni_eltwise_kernel : public jit_generator {
jit_uni_eltwise_kerneldnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel46     jit_uni_eltwise_kernel(const eltwise_pd_t *pd) : pd_(pd) {}
47 
operator ()dnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel48     void operator()(jit_args_t *p) { jit_generator::operator()(p); }
49 
50 protected:
51     const eltwise_pd_t *pd_;
52 
data_typednnl::impl::cpu::aarch64::jit_uni_eltwise_kernel53     data_type_t data_type() const { return pd_->src_md()->data_type; }
is_bf16dnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel54     bool is_bf16() const { return data_type() == data_type::bf16; }
dtype_sizednnl::impl::cpu::aarch64::jit_uni_eltwise_kernel55     int dtype_size() const { return types::data_type_size(data_type()); }
56 };
57 
58 // jit kernels
59 namespace {
60 
61 template <cpu_isa_t isa>
62 struct jit_uni_kernel_t : public jit_uni_eltwise_kernel {
63     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel)
64 
jit_uni_kernel_tdnnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t65     jit_uni_kernel_t(const eltwise_pd_t *pd) : jit_uni_eltwise_kernel(pd) {
66         const auto &desc = *pd_->desc();
67         // there's no auxiliary vregs on fwd path
68         const bool is_fwd = pd_->is_fwd();
69         const bool save_state = is_fwd ? false : true;
70         eltwise_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
71                 desc.alg_kind, desc.alpha, desc.beta, 1.f, save_state,
72                 reg_injector_table, injector_mask, injector_p_tmp0,
73                 injector_p_all, is_fwd, pd_->use_dst()));
74     }
75 
generatednnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t76     void generate() override {
77         const bool is_fwd = pd_->is_fwd();
78         preamble();
79 
80         XReg param = param1;
81         add_imm(X_TMP_0, param, GET_OFF(src), X_TMP_1);
82         ldr(reg_src, ptr(X_TMP_0));
83         add_imm(X_TMP_0, param, GET_OFF(dst), X_TMP_1);
84         ldr(reg_dst, ptr(X_TMP_0));
85         if (!is_fwd) {
86             add_imm(X_TMP_0, param, GET_OFF(diff_dst), X_TMP_1);
87             ldr(reg_diff_dst, ptr(X_TMP_0));
88         }
89         add_imm(X_TMP_0, param, GET_OFF(work_amount), X_TMP_1);
90         ldr(reg_work_amount, ptr(X_TMP_0));
91         eltwise_injector_->load_table_addr();
92         ptrue(p_512.b);
93 
94         Label reminder_loop_start, reminder_loop_end;
95         Label vectorized_loop_start, vectorized_loop_end;
96 
97         cmp(reg_work_amount, simd_w());
98         b(LT, reminder_loop_start);
99 
100         L(vectorized_loop_start);
101 
102         // TODO: consider improving.
103         // This piece of code is responsible for the preserve_zero function
104         // being a natural restriction of this implementation. It works with any
105         // dense and blocked layout, but the problem raises when blocking
106         // dimension is not divisible by block size. For such case, the code
107         // below should save the mask, where zero padding should be preserved
108         // and apply it on register before storing into dst memory. Until
109         // there's a restriction on certain blocked layouts, when this behavior
110         // can be relevantly easy controlled, this will cost much from code
111         // perspective and will complicate the compute logic significantly.
112         ldr(vmm_src, ptr(reg_src));
113         eltwise_injector_->compute_vector(vmm_src.getIdx());
114         if (!is_fwd) {
115             ldr(ZReg(vmm_diff_dst.getIdx()), ptr(reg_diff_dst));
116             fmul(vmm_src.s, vmm_src.s, vmm_diff_dst);
117         }
118         str(vmm_src, ptr(reg_dst));
119 
120         const auto shift = cpu_isa_traits<isa>::vlen;
121         add_imm(reg_src, reg_src, shift, X_TMP_0);
122         add_imm(reg_dst, reg_dst, shift, X_TMP_0);
123         if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, shift, X_TMP_0);
124 
125         sub_imm(reg_work_amount, reg_work_amount, simd_w(), X_TMP_0);
126         cmp(reg_work_amount, simd_w());
127         b(GE, vectorized_loop_start);
128 
129         L(vectorized_loop_end);
130 
131         L(reminder_loop_start);
132 
133         cmp(reg_work_amount, 0);
134         b(LE, reminder_loop_end);
135 
136         ld1(xmm_src[0], ptr(reg_src));
137         eltwise_injector_->compute_vector(xmm_src.getIdx());
138         if (!is_fwd) {
139             ld1(xmm_diff_dst[0], ptr(reg_diff_dst));
140             fmul(xmm_src, xmm_src, xmm_diff_dst);
141         }
142         st1(xmm_src[0], ptr(reg_dst));
143         add_imm(reg_src, reg_src, dtype_size(), X_TMP_0);
144         add_imm(reg_dst, reg_dst, dtype_size(), X_TMP_0);
145         if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, dtype_size(), X_TMP_0);
146 
147         subs(reg_work_amount, reg_work_amount, 1);
148         b(reminder_loop_start);
149 
150         L(reminder_loop_end);
151 
152         postamble();
153 
154         eltwise_injector_->prepare_table();
155     }
156 
157 private:
158     using TReg = typename cpu_isa_traits<isa>::TReg;
159     using TRegS = typename cpu_isa_traits<isa>::TRegS;
160 
simd_wdnnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t161     int simd_w() {
162         int simd_w = cpu_isa_traits<isa>::vlen / dtype_size();
163         /* Return value is used for CMP (immediate). */
164         assert(simd_w < (1 << 12));
165         return simd_w;
166     }
167 
168     XReg reg_src = x11;
169     XReg reg_dst = x8;
170     XReg reg_injector_table = x9;
171     XReg reg_diff_dst = x10;
172     XReg reg_work_amount = x6;
173     XReg imm_addr64 = x3;
174     PReg injector_mask = p1;
175     PReg injector_p_tmp0 = p4;
176     PReg injector_p_all = p7;
177 
178     VReg4S xmm_src {1};
179     TReg vmm_src {1};
180     VReg4S xmm_diff_dst {2};
181     TRegS vmm_diff_dst {2};
182     std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> eltwise_injector_;
183 
184     PReg p_512 {7}; /* Index is temporal. */
185     PReg p_tmp0 {4}; /* Index is temporal. */
186 };
187 
188 } // namespace
189 
190 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)191 status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
192     using namespace alg_kind;
193 
194     const memory_desc_wrapper data_d(src_md());
195 
196     bool ok = mayiuse(isa) && is_fwd() && src_md()->data_type == d_type
197             && !has_zero_dim_memory()
198             && data_d.is_dense(true)
199             // refer to a comment in jit_uni_kernel why this is needed
200             && IMPLICATION(!data_d.is_dense(), is_zero_preserved())
201             && attr()->has_default_values();
202 
203     ok &= utils::one_of(desc_.alg_kind, eltwise_relu_use_dst_for_bwd,
204             eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
205             eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
206             eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
207             eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
208             eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
209             eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
210             eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf,
211             eltwise_round);
212 
213     return ok ? status::success : status::unimplemented;
214 }
215 
216 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t(const pd_t * apd)217 jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd)
218     : primitive_t(apd) {}
219 
220 template <cpu_isa_t isa, data_type_t d_type>
221 jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t() = default;
222 
223 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)224 status_t jit_uni_eltwise_fwd_t<isa, d_type>::init(engine_t *engine) {
225     CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
226     return kernel_->create_kernel();
227 }
228 
229 template <cpu_isa_t isa, data_type_t d_type>
execute(const exec_ctx_t & ctx) const230 status_t jit_uni_eltwise_fwd_t<isa, d_type>::execute(
231         const exec_ctx_t &ctx) const {
232     status_t status = status::success;
233     auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
234     auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
235     CHECK(status);
236 
237     const memory_desc_wrapper data_d(pd()->src_md());
238     const auto nelems = data_d.nelems(true);
239     const int simd_w = 64 / data_d.data_type_size();
240 
241     src += data_d.offset0();
242     dst += data_d.offset0();
243 
244     parallel(0, [&](const int ithr, const int nthr) {
245         dim_t start {0}, end {0};
246 
247         balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
248         start = nstl::min(nelems, start * simd_w);
249         end = nstl::min(nelems, end * simd_w);
250         if (start == end) return;
251 
252         jit_args_t args;
253         args.src = src + start;
254         args.dst = dst + start;
255         args.diff_dst = nullptr;
256         args.work_amount = end - start;
257         (*kernel_)(&args);
258     });
259 
260     return status::success;
261 }
262 
263 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)264 status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
265     using namespace alg_kind;
266 
267     const memory_desc_wrapper data_d(src_md());
268 
269     bool ok = mayiuse(isa) && !is_fwd()
270             && utils::everyone_is(
271                     d_type, src_md()->data_type, diff_src_md()->data_type)
272             && !has_zero_dim_memory() && set_default_formats_common()
273             && data_d.is_dense(true)
274             // refer to a comment in jit_uni_kernel why this is needed
275             && IMPLICATION(!data_d.is_dense(), is_zero_preserved())
276             && data_d == memory_desc_wrapper(diff_dst_md())
277             && attr()->has_default_values();
278 
279     ok &= utils::one_of(desc_.alg_kind, eltwise_relu_use_dst_for_bwd,
280             eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
281             eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
282             eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
283             eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
284             eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
285             eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
286             eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf);
287 
288     return ok ? status::success : status::unimplemented;
289 }
290 
291 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t(const pd_t * apd)292 jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd)
293     : primitive_t(apd) {}
294 
295 template <cpu_isa_t isa, data_type_t d_type>
296 jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t() = default;
297 
298 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)299 status_t jit_uni_eltwise_bwd_t<isa, d_type>::init(engine_t *engine) {
300     CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
301     return kernel_->create_kernel();
302 }
303 
304 template <cpu_isa_t isa, data_type_t d_type>
execute(const exec_ctx_t & ctx) const305 status_t jit_uni_eltwise_bwd_t<isa, d_type>::execute(
306         const exec_ctx_t &ctx) const {
307     status_t status = status::success;
308     auto src = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
309                                : CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
310     auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
311     auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
312     CHECK(status);
313 
314     const memory_desc_wrapper data_d(pd()->src_md());
315     const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
316     const auto nelems = data_d.nelems(true);
317     const int simd_w = 64 / data_d.data_type_size();
318 
319     src += data_d.offset0();
320     diff_dst += diff_data_d.offset0();
321     diff_src += diff_data_d.offset0();
322 
323     parallel(0, [&](const int ithr, const int nthr) {
324         dim_t start {0}, end {0};
325 
326         balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
327         start = nstl::min(nelems, start * simd_w);
328         end = nstl::min(nelems, end * simd_w);
329         if (start == end) return;
330 
331         jit_args_t args;
332         args.src = src + start;
333         args.dst = diff_src + start;
334         args.diff_dst = diff_dst + start;
335         args.work_amount = end - start;
336         (*kernel_)(&args);
337     });
338 
339     return status::success;
340 }
341 
342 template struct jit_uni_eltwise_fwd_t<sve_512, data_type::f32>;
343 template struct jit_uni_eltwise_bwd_t<sve_512, data_type::f32>;
344 
345 } // namespace aarch64
346 } // namespace cpu
347 } // namespace impl
348 } // namespace dnnl
349