1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #ifndef CPU_AARCH64_JIT_UNI_ELTWISE_INJECTOR_HPP
19 #define CPU_AARCH64_JIT_UNI_ELTWISE_INJECTOR_HPP
20 
#include <assert.h>
#include <map>

#include "common/c_types_map.hpp"
#include "common/primitive_attr.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/injectors/injector_utils.hpp"
#include "cpu/aarch64/jit_generator.hpp"
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace aarch64 {
35 
36 namespace eltwise_injector {
37 struct static_params_t {
38 
static_params_tdnnl::impl::cpu::aarch64::eltwise_injector::static_params_t39     static_params_t(bool save_state = true,
40             Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0),
41             Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1),
42             Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4),
43             Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7),
44             bool is_fwd = true, bool use_dst = false)
45         : save_state(save_state)
46         , x_table(x_table)
47         , p_mask(p_mask)
48         , p_tmp0(p_tmp0)
49         , p_all(p_all)
50         , is_fwd(is_fwd)
51         , use_dst(use_dst) {}
52 
53     bool save_state;
54     Xbyak_aarch64::XReg x_table;
55     Xbyak_aarch64::PReg p_mask;
56     Xbyak_aarch64::PReg p_tmp0;
57     Xbyak_aarch64::PReg p_all;
58     bool is_fwd;
59     bool use_dst;
60 };
61 } // namespace eltwise_injector
62 
63 template <cpu_isa_t isa>
64 struct jit_uni_eltwise_injector_f32 {
65     using TReg = typename cpu_isa_traits<isa>::TReg;
66     using TRegS = typename cpu_isa_traits<isa>::TRegS;
67 
68     // Arguments description:
69     // host - jit generator which is filled with instructions
70     // alg, alpha, beta, scale - user eltwise arguments
71     // save_state - when true, preserves on stack vmm_aux registers preventing
72     //   results spoiling. Restores them when done in injector_postamble().
73     // p_table - GPR where table label is stored to get access for pre-defined
74     //   constants used in alg codes.
75     // k_mask - k_register to operate with masks in alg codes.
76     // is_fwd - when true, computes d = alg(s), otherwise, computes ds = alg'(s)
77     //   - algorithm derivative.
78     // use_dst - defines whether source or destination point is passed to alg
79     //   code. Depends on algorithm. See `_use_dst_for_bwd` algs definition.
jit_uni_eltwise_injector_f32dnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f3280     jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
81             float alpha, float beta, float scale, bool save_state = true,
82             Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0),
83             Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1),
84             Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4),
85             Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7),
86             bool is_fwd = true, bool use_dst = false)
87         : alg_(alg)
88         , alpha_(alpha)
89         , beta_(beta)
90         , scale_(scale)
91         , h(host)
92         , save_state_(save_state)
93         , x_table(x_table)
94         , p_mask(p_mask)
95         , p_tmp0(p_tmp0)
96         , p_all(p_all)
97         , is_fwd_(is_fwd)
98         , use_dst_(use_dst)
99 
100     {
101         using namespace alg_kind;
102         assert(utils::one_of(isa, sve_512));
103         assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
104                 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
105                 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
106                 eltwise_exp, eltwise_gelu_tanh, eltwise_swish, eltwise_log,
107                 eltwise_clip, eltwise_clip_v2, eltwise_gelu_erf, eltwise_round,
108                 eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
109                 eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
110                 eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
111                 eltwise_clip_v2_use_dst_for_bwd));
112         register_table_entries();
113     }
114 
jit_uni_eltwise_injector_f32dnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32115     jit_uni_eltwise_injector_f32(jit_generator *host,
116             const post_ops_t::entry_t::eltwise_t &eltwise,
117             bool save_state = true,
118             Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0),
119             Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1),
120             Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4),
121             Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7),
122             bool is_fwd = true, bool use_dst = false)
123         : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
124                 eltwise.beta, eltwise.scale, save_state, x_table, p_mask,
125                 p_tmp0, p_all, is_fwd, use_dst) {}
126 
127     void compute_vector_range(size_t start_idx, size_t end_idx);
128     void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs);
compute_vectordnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32129     void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
130     void prepare_table(bool gen_table = true);
load_table_addrdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32131     void load_table_addr() { h->adr(x_table, l_table); }
132 
133 private:
134     const alg_kind_t alg_;
135     const float alpha_;
136     const float beta_;
137     const float scale_;
138 
139     jit_generator *const h;
140 
141     const bool save_state_;
142     const Xbyak_aarch64::XReg x_table;
143     const Xbyak_aarch64::PReg p_mask;
144     const Xbyak_aarch64::PReg p_tmp0;
145     const Xbyak_aarch64::PReg p_all;
146     const bool is_fwd_;
147     const bool use_dst_;
148 
149     Xbyak_aarch64::Label l_table;
150 
151     // if only the injector was inherited from jit_generator...
152     enum {
153         _cmp_eq_oq = jit_generator::_cmp_eq_oq,
154         _cmp_lt_os = jit_generator::_cmp_lt_os,
155         _cmp_le_os = jit_generator::_cmp_le_os,
156         _cmp_ge_os = jit_generator::_cmp_nlt_us,
157         _cmp_gt_os = jit_generator::_cmp_nle_us,
158         _op_floor = jit_generator::_op_floor,
159         _op_mxcsr = jit_generator::_op_mxcsr
160     };
161 
162     static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
163     static constexpr size_t preserved_vecs_max = 9;
164     static constexpr size_t preserved_gprs_max = 4;
165     static constexpr size_t vecs_count = 32;
166     static constexpr int n_mantissa_bits = 23;
167     static constexpr int k_mask_size = 8;
168 
169     size_t vecs_to_preserve = 0;
170     size_t preserved_vecs_count = 0;
171     size_t preserved_vec_idxs[preserved_vecs_max] = {0};
172     size_t preserved_gpr_idxs[preserved_gprs_max] = {0};
173     injector_utils::vmm_index_set_iterator_t start_idx_tail;
174 
175     /* These vector register must be assigned proper index. */
176     TRegS vmm_mask {0}, vmm_aux0 {0}, vmm_aux1 {0}, vmm_aux2 {0}, vmm_aux3 {0},
177             vmm_aux4 {0}, vmm_aux5 {0}, vmm_aux6 {0}, vmm_aux7 {0}, vmm_tmp {0};
178     /* Default tempooral index. Chose a SVE register
179      not to be same as jit_uni_eltwise.(cpp|hpp).
180      This index is changed by assign_regs() in case of eltwise injection.
181   */
182     TRegS z_tmp {31};
183 
184     size_t aux_vecs_count();
185     size_t aux_gprs_count();
186 
187     void compute_body(
188             const injector_utils::vmm_index_set_iterator_t &start_idx_it,
189             const injector_utils::vmm_index_set_iterator_t &end_idx_it);
190     void injector_preamble(const injector_utils::vmm_index_set_t &vmm_idxs);
191     void injector_preamble_tail(
192             const injector_utils::vmm_index_set_iterator_t start_idx_it);
193     void injector_postamble();
194     void assign_regs();
195     void set_coef_to_regs();
196     void compute_cmp_mask(
197             const TRegS &vmm_src, const TRegS &vmm_cmpare, int cmp_predicate);
198     void blend_with_mask(const TRegS &vmm_dst, const TRegS &src);
199     void test_mask();
200 
201     void exp_compute_vector_fwd(const TRegS &vmm_src);
202     void relu_compute_vector_fwd(const TRegS &vmm_src);
203     void relu_zero_ns_compute_vector_fwd(const TRegS &vmm_src);
204     void elu_compute_vector_fwd(const TRegS &vmm_src);
205     void tanh_compute_vector_fwd(const TRegS &vmm_src);
206     void square_compute_vector_fwd(const TRegS &vmm_src);
207     void abs_compute_vector_fwd(const TRegS &vmm_src);
208     void sqrt_compute_vector_fwd(const TRegS &vmm_src);
209     void linear_compute_vector_fwd(const TRegS &vmm_src);
210     void bounded_relu_compute_vector_fwd(const TRegS &vmm_src);
211     void soft_relu_compute_vector_fwd(const TRegS &vmm_src);
212     void logistic_compute_vector_fwd(const TRegS &vmm_src);
213     void gelu_tanh_compute_vector_fwd(const TRegS &vmm_src);
214     void swish_compute_vector_fwd(const TRegS &vmm_src);
215     void log_compute_vector_fwd(const TRegS &vmm_src);
216     void clip_compute_vector_fwd(const TRegS &vmm_src);
217     void gelu_erf_compute_vector_fwd(const TRegS &vmm_src);
218     void round_compute_vector_fwd(const TRegS &vmm_src);
219 
220     void exp_compute_vector_bwd(const TRegS &vmm_src);
221     void relu_compute_vector_bwd(const TRegS &vmm_src);
222     void elu_compute_vector_bwd(const TRegS &vmm_src);
223     void tanh_compute_vector_bwd(const TRegS &vmm_src);
224     void square_compute_vector_bwd(const TRegS &vmm_src);
225     void abs_compute_vector_bwd(const TRegS &vmm_src);
226     void sqrt_compute_vector_bwd(const TRegS &vmm_src);
227     void linear_compute_vector_bwd(const TRegS &vmm_src);
228     void bounded_relu_compute_vector_bwd(const TRegS &vmm_src);
229     void soft_relu_compute_vector_bwd(const TRegS &vmm_src);
230     void logistic_compute_vector_bwd(const TRegS &vmm_src);
231     void gelu_tanh_compute_vector_bwd(const TRegS &vmm_src);
232     void swish_compute_vector_bwd(const TRegS &vmm_src);
233     void log_compute_vector_bwd(const TRegS &vmm_src);
234     void clip_compute_vector_bwd(const TRegS &vmm_src);
235     void gelu_erf_compute_vector_bwd(const TRegS &vmm_src);
236 
237     enum key_t {
238         scale = 0, // scale argument
239         alpha, // alpha argument
240         beta, // beta argument
241         zero, // 0.f
242         half, // 0.5f
243         one, // 1.f  or  mask for exponent bits
244         two, // 2.f
245         minus_one, // -1.f  or  changes sign to opposite
246         minus_two, // -2.f
247         ln2f, // 0.69314718f
248         positive_mask, // changes sign to positive
249         sign_mask, // gets sign value
250         exponent_bias, // (127 = 2^7 - 1), gets exponent bits
251         exp_log2ef, // 1.44269502f - formula-based for approx
252         exp_ln_flt_max_f, // logf(FLT_MAX) - max normal value
253         exp_ln_flt_min_f, // logf(FLT_MIN) - min normal value
254         exp_pol, // see correspondent table for float values
255         exp_coeff1, // 0.6931473921 (0x3f31721c)
256         exp_coeff2, // 0.2413862043 (0x3e772df2)
257         exp_not_mask17, // ~((1u << 17) - 1)
258         tanh_range, // tanh(x) = x - x^3/3 for |x| < tanh_range
259         tanh_m1d3, // -1/3
260         soft_relu_one_twenty_six, // 126.f
261         soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign
262         soft_relu_pol, // see correspondent table for float values
263         gelu_tanh_fitting_const, // 0.044715f
264         gelu_tanh_fitting_const_times_three, // 0.134145f
265         gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f
266         gelu_erf_approx_const, // 0.3275911f - implementation based for approx
267         gelu_erf_one_over_sqrt_two, // 1.f / sqrtf(2.f)
268         gelu_erf_one_over_sqrt_pi, // 1.f / sqrtf(pi) = 0.564190f
269         gelu_erf_pol, // see correspondent table for float values
270         log_minus_inf, // -inf
271         log_qnan, // qnan
272         log_mantissa_mask, // gets mantissa bits
273         log_full_k_reg_mask, // sets k_register with all bits of 1
274         log_full_vector_reg_mask, // sets vector register will all bits of 1
275         log_five_bit_offset, // 5 bits off (31 = 2^5 - 1)
276         log_pol, // see correspondent table for float values
277         log_predefined_vals, // see correspondent table for float values
278         log_i127shl23,
279         log_x7fffff,
280         log_log2,
281         log_log1p5,
282         log_f2div3,
283         log_coeffTbl,
284         undef_key,
285     };
286 
table_offdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32287     size_t table_off(key_t key, size_t key_off_val_shift = 0) {
288         // assumption: all table entries sharing the same key also
289         // share their broadcast property
290         // TODO: enforce through data structure
291         const auto it = entry_map_.find(key); // search an entry for a key
292         assert(it != entry_map_.end());
293         const auto &te = (*it).second;
294         const auto scale = te.bcast ? vlen : sizeof(table_entry_val_t);
295         return te.off + key_off_val_shift * scale;
296     }
297 
table_valdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32298     TRegS table_val(key_t key, TRegS zreg, size_t key_off_val_shift = 0) {
299         Xbyak_aarch64::XReg x_addr(h->X_DEFAULT_ADDR);
300         auto off = table_off(key, key_off_val_shift);
301 
302         if (off) {
303             h->add_imm(x_addr, x_table, off, h->X_TMP_0);
304         } else {
305             x_addr = x_table;
306         }
307 
308         h->ldr(TReg(zreg.getIdx()), ptr(x_addr));
309         return zreg;
310     }
311 
312     // we accept only 32bit hexadecimal table values to avoid any rounding
313     using table_entry_val_t = uint32_t;
314     using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table
315     using table_entry_bcast_t = bool; // true => bcast value
316 
317     struct table_entry_t {
318         table_entry_val_t val;
319         table_entry_bcast_t bcast;
320     };
321     struct mapped_table_entry_t {
322         table_entry_offset_t off;
323         table_entry_val_t val;
324         table_entry_bcast_t bcast;
325     };
326 
327     using table_t = std::multimap<key_t, table_entry_t>;
328     using mapped_table_t = std::multimap<key_t, mapped_table_entry_t>;
329 
330     void register_table_entries();
331     mapped_table_t entry_map_;
332 };
333 
334 } // namespace aarch64
335 } // namespace cpu
336 } // namespace impl
337 } // namespace dnnl
338 #endif
339