1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include "common/bfloat16.hpp"
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/nstl.hpp"
22 #include "common/utils.hpp"
23
24 #include "cpu/aarch64/jit_generator.hpp"
25
26 #include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp"
27 #include "cpu/aarch64/jit_uni_eltwise.hpp"
28
// Byte offset of a jit_args_t field; generate() adds it to param1 to load
// each kernel argument from the argument struct.
#define GET_OFF(field) offsetof(jit_args_t, field)
30
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace aarch64 {
35
36 using namespace Xbyak_aarch64;
37
/* Argument block handed (by pointer) to the generated kernel.
 * Field order and types are part of the kernel ABI: generate() reads every
 * field at the byte offset given by GET_OFF(), so do not reorder them. */
struct jit_args_t {
    // fwd: src; bwd: src or dst, depending on the algorithm.
    const void *src;
    // fwd: dst; bwd: diff_src.
    const void *dst;
    // fwd: nullptr; bwd: diff_dst.
    const void *diff_dst;
    // Number of elements (not bytes) this call processes.
    size_t work_amount;
};
44
45 struct jit_uni_eltwise_kernel : public jit_generator {
jit_uni_eltwise_kerneldnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel46 jit_uni_eltwise_kernel(const eltwise_pd_t *pd) : pd_(pd) {}
47
operator ()dnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel48 void operator()(jit_args_t *p) { jit_generator::operator()(p); }
49
50 protected:
51 const eltwise_pd_t *pd_;
52
data_typednnl::impl::cpu::aarch64::jit_uni_eltwise_kernel53 data_type_t data_type() const { return pd_->src_md()->data_type; }
is_bf16dnnl::impl::cpu::aarch64::jit_uni_eltwise_kernel54 bool is_bf16() const { return data_type() == data_type::bf16; }
dtype_sizednnl::impl::cpu::aarch64::jit_uni_eltwise_kernel55 int dtype_size() const { return types::data_type_size(data_type()); }
56 };
57
58 // jit kernels
59 namespace {
60
61 template <cpu_isa_t isa>
62 struct jit_uni_kernel_t : public jit_uni_eltwise_kernel {
63 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel)
64
jit_uni_kernel_tdnnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t65 jit_uni_kernel_t(const eltwise_pd_t *pd) : jit_uni_eltwise_kernel(pd) {
66 const auto &desc = *pd_->desc();
67 // there's no auxiliary vregs on fwd path
68 const bool is_fwd = pd_->is_fwd();
69 const bool save_state = is_fwd ? false : true;
70 eltwise_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
71 desc.alg_kind, desc.alpha, desc.beta, 1.f, save_state,
72 reg_injector_table, injector_mask, injector_p_tmp0,
73 injector_p_all, is_fwd, pd_->use_dst()));
74 }
75
generatednnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t76 void generate() override {
77 const bool is_fwd = pd_->is_fwd();
78 preamble();
79
80 XReg param = param1;
81 add_imm(X_TMP_0, param, GET_OFF(src), X_TMP_1);
82 ldr(reg_src, ptr(X_TMP_0));
83 add_imm(X_TMP_0, param, GET_OFF(dst), X_TMP_1);
84 ldr(reg_dst, ptr(X_TMP_0));
85 if (!is_fwd) {
86 add_imm(X_TMP_0, param, GET_OFF(diff_dst), X_TMP_1);
87 ldr(reg_diff_dst, ptr(X_TMP_0));
88 }
89 add_imm(X_TMP_0, param, GET_OFF(work_amount), X_TMP_1);
90 ldr(reg_work_amount, ptr(X_TMP_0));
91 eltwise_injector_->load_table_addr();
92 ptrue(p_512.b);
93
94 Label reminder_loop_start, reminder_loop_end;
95 Label vectorized_loop_start, vectorized_loop_end;
96
97 cmp(reg_work_amount, simd_w());
98 b(LT, reminder_loop_start);
99
100 L(vectorized_loop_start);
101
102 // TODO: consider improving.
103 // This piece of code is responsible for the preserve_zero function
104 // being a natural restriction of this implementation. It works with any
105 // dense and blocked layout, but the problem raises when blocking
106 // dimension is not divisible by block size. For such case, the code
107 // below should save the mask, where zero padding should be preserved
108 // and apply it on register before storing into dst memory. Until
109 // there's a restriction on certain blocked layouts, when this behavior
110 // can be relevantly easy controlled, this will cost much from code
111 // perspective and will complicate the compute logic significantly.
112 ldr(vmm_src, ptr(reg_src));
113 eltwise_injector_->compute_vector(vmm_src.getIdx());
114 if (!is_fwd) {
115 ldr(ZReg(vmm_diff_dst.getIdx()), ptr(reg_diff_dst));
116 fmul(vmm_src.s, vmm_src.s, vmm_diff_dst);
117 }
118 str(vmm_src, ptr(reg_dst));
119
120 const auto shift = cpu_isa_traits<isa>::vlen;
121 add_imm(reg_src, reg_src, shift, X_TMP_0);
122 add_imm(reg_dst, reg_dst, shift, X_TMP_0);
123 if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, shift, X_TMP_0);
124
125 sub_imm(reg_work_amount, reg_work_amount, simd_w(), X_TMP_0);
126 cmp(reg_work_amount, simd_w());
127 b(GE, vectorized_loop_start);
128
129 L(vectorized_loop_end);
130
131 L(reminder_loop_start);
132
133 cmp(reg_work_amount, 0);
134 b(LE, reminder_loop_end);
135
136 ld1(xmm_src[0], ptr(reg_src));
137 eltwise_injector_->compute_vector(xmm_src.getIdx());
138 if (!is_fwd) {
139 ld1(xmm_diff_dst[0], ptr(reg_diff_dst));
140 fmul(xmm_src, xmm_src, xmm_diff_dst);
141 }
142 st1(xmm_src[0], ptr(reg_dst));
143 add_imm(reg_src, reg_src, dtype_size(), X_TMP_0);
144 add_imm(reg_dst, reg_dst, dtype_size(), X_TMP_0);
145 if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, dtype_size(), X_TMP_0);
146
147 subs(reg_work_amount, reg_work_amount, 1);
148 b(reminder_loop_start);
149
150 L(reminder_loop_end);
151
152 postamble();
153
154 eltwise_injector_->prepare_table();
155 }
156
157 private:
158 using TReg = typename cpu_isa_traits<isa>::TReg;
159 using TRegS = typename cpu_isa_traits<isa>::TRegS;
160
simd_wdnnl::impl::cpu::aarch64::__anon94fa034b0111::jit_uni_kernel_t161 int simd_w() {
162 int simd_w = cpu_isa_traits<isa>::vlen / dtype_size();
163 /* Return value is used for CMP (immediate). */
164 assert(simd_w < (1 << 12));
165 return simd_w;
166 }
167
168 XReg reg_src = x11;
169 XReg reg_dst = x8;
170 XReg reg_injector_table = x9;
171 XReg reg_diff_dst = x10;
172 XReg reg_work_amount = x6;
173 XReg imm_addr64 = x3;
174 PReg injector_mask = p1;
175 PReg injector_p_tmp0 = p4;
176 PReg injector_p_all = p7;
177
178 VReg4S xmm_src {1};
179 TReg vmm_src {1};
180 VReg4S xmm_diff_dst {2};
181 TRegS vmm_diff_dst {2};
182 std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> eltwise_injector_;
183
184 PReg p_512 {7}; /* Index is temporal. */
185 PReg p_tmp0 {4}; /* Index is temporal. */
186 };
187
188 } // namespace
189
190 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)191 status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
192 using namespace alg_kind;
193
194 const memory_desc_wrapper data_d(src_md());
195
196 bool ok = mayiuse(isa) && is_fwd() && src_md()->data_type == d_type
197 && !has_zero_dim_memory()
198 && data_d.is_dense(true)
199 // refer to a comment in jit_uni_kernel why this is needed
200 && IMPLICATION(!data_d.is_dense(), is_zero_preserved())
201 && attr()->has_default_values();
202
203 ok &= utils::one_of(desc_.alg_kind, eltwise_relu_use_dst_for_bwd,
204 eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
205 eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
206 eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
207 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
208 eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
209 eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
210 eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf,
211 eltwise_round);
212
213 return ok ? status::success : status::unimplemented;
214 }
215
216 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t(const pd_t * apd)217 jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd)
218 : primitive_t(apd) {}
219
220 template <cpu_isa_t isa, data_type_t d_type>
221 jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t() = default;
222
223 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)224 status_t jit_uni_eltwise_fwd_t<isa, d_type>::init(engine_t *engine) {
225 CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
226 return kernel_->create_kernel();
227 }
228
229 template <cpu_isa_t isa, data_type_t d_type>
execute(const exec_ctx_t & ctx) const230 status_t jit_uni_eltwise_fwd_t<isa, d_type>::execute(
231 const exec_ctx_t &ctx) const {
232 status_t status = status::success;
233 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
234 auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
235 CHECK(status);
236
237 const memory_desc_wrapper data_d(pd()->src_md());
238 const auto nelems = data_d.nelems(true);
239 const int simd_w = 64 / data_d.data_type_size();
240
241 src += data_d.offset0();
242 dst += data_d.offset0();
243
244 parallel(0, [&](const int ithr, const int nthr) {
245 dim_t start {0}, end {0};
246
247 balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
248 start = nstl::min(nelems, start * simd_w);
249 end = nstl::min(nelems, end * simd_w);
250 if (start == end) return;
251
252 jit_args_t args;
253 args.src = src + start;
254 args.dst = dst + start;
255 args.diff_dst = nullptr;
256 args.work_amount = end - start;
257 (*kernel_)(&args);
258 });
259
260 return status::success;
261 }
262
263 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)264 status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
265 using namespace alg_kind;
266
267 const memory_desc_wrapper data_d(src_md());
268
269 bool ok = mayiuse(isa) && !is_fwd()
270 && utils::everyone_is(
271 d_type, src_md()->data_type, diff_src_md()->data_type)
272 && !has_zero_dim_memory() && set_default_formats_common()
273 && data_d.is_dense(true)
274 // refer to a comment in jit_uni_kernel why this is needed
275 && IMPLICATION(!data_d.is_dense(), is_zero_preserved())
276 && data_d == memory_desc_wrapper(diff_dst_md())
277 && attr()->has_default_values();
278
279 ok &= utils::one_of(desc_.alg_kind, eltwise_relu_use_dst_for_bwd,
280 eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
281 eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
282 eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
283 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
284 eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
285 eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
286 eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf);
287
288 return ok ? status::success : status::unimplemented;
289 }
290
291 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t(const pd_t * apd)292 jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd)
293 : primitive_t(apd) {}
294
295 template <cpu_isa_t isa, data_type_t d_type>
296 jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t() = default;
297
298 template <cpu_isa_t isa, data_type_t d_type>
init(engine_t * engine)299 status_t jit_uni_eltwise_bwd_t<isa, d_type>::init(engine_t *engine) {
300 CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
301 return kernel_->create_kernel();
302 }
303
304 template <cpu_isa_t isa, data_type_t d_type>
execute(const exec_ctx_t & ctx) const305 status_t jit_uni_eltwise_bwd_t<isa, d_type>::execute(
306 const exec_ctx_t &ctx) const {
307 status_t status = status::success;
308 auto src = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
309 : CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
310 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
311 auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
312 CHECK(status);
313
314 const memory_desc_wrapper data_d(pd()->src_md());
315 const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
316 const auto nelems = data_d.nelems(true);
317 const int simd_w = 64 / data_d.data_type_size();
318
319 src += data_d.offset0();
320 diff_dst += diff_data_d.offset0();
321 diff_src += diff_data_d.offset0();
322
323 parallel(0, [&](const int ithr, const int nthr) {
324 dim_t start {0}, end {0};
325
326 balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
327 start = nstl::min(nelems, start * simd_w);
328 end = nstl::min(nelems, end * simd_w);
329 if (start == end) return;
330
331 jit_args_t args;
332 args.src = src + start;
333 args.dst = diff_src + start;
334 args.diff_dst = diff_dst + start;
335 args.work_amount = end - start;
336 (*kernel_)(&args);
337 });
338
339 return status::success;
340 }
341
// Only the SVE-512 / f32 variants of the primitives are instantiated here.
template struct jit_uni_eltwise_fwd_t<sve_512, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<sve_512, data_type::f32>;
344
345 } // namespace aarch64
346 } // namespace cpu
347 } // namespace impl
348 } // namespace dnnl
349