1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * Copyright 2021 FUJITSU LIMITED 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 *******************************************************************************/ 17 18 #ifndef CPU_AARCH64_JIT_UNI_ELTWISE_INJECTOR_HPP 19 #define CPU_AARCH64_JIT_UNI_ELTWISE_INJECTOR_HPP 20 21 #include <assert.h> 22 23 #include "common/c_types_map.hpp" 24 #include "common/primitive_attr.hpp" 25 #include "common/type_helpers.hpp" 26 #include "common/utils.hpp" 27 28 #include "cpu/aarch64/injectors/injector_utils.hpp" 29 #include "cpu/aarch64/jit_generator.hpp" 30 31 namespace dnnl { 32 namespace impl { 33 namespace cpu { 34 namespace aarch64 { 35 36 namespace eltwise_injector { 37 struct static_params_t { 38 static_params_tdnnl::impl::cpu::aarch64::eltwise_injector::static_params_t39 static_params_t(bool save_state = true, 40 Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0), 41 Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1), 42 Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4), 43 Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7), 44 bool is_fwd = true, bool use_dst = false) 45 : save_state(save_state) 46 , x_table(x_table) 47 , p_mask(p_mask) 48 , p_tmp0(p_tmp0) 49 , p_all(p_all) 50 , is_fwd(is_fwd) 51 , use_dst(use_dst) {} 52 53 bool save_state; 54 Xbyak_aarch64::XReg x_table; 55 Xbyak_aarch64::PReg p_mask; 56 Xbyak_aarch64::PReg p_tmp0; 57 Xbyak_aarch64::PReg p_all; 58 bool is_fwd; 59 bool use_dst; 60 }; 61 } // namespace eltwise_injector 62 63 template <cpu_isa_t isa> 64 struct jit_uni_eltwise_injector_f32 { 65 using TReg = typename cpu_isa_traits<isa>::TReg; 66 using TRegS = typename cpu_isa_traits<isa>::TRegS; 67 68 // Arguments description: 69 // host - jit generator which is filled with instructions 70 // alg, alpha, beta, scale - user eltwise arguments 71 // save_state - when true, preserves on stack vmm_aux registers preventing 72 // results spoiling. Restores them when done in injector_postamble(). 73 // p_table - GPR where table label is stored to get access for pre-defined 74 // constants used in alg codes. 75 // k_mask - k_register to operate with masks in alg codes. 76 // is_fwd - when true, computes d = alg(s), otherwise, computes ds = alg'(s) 77 // - algorithm derivative. 78 // use_dst - defines whether source or destination point is passed to alg 79 // code. Depends on algorithm. See `_use_dst_for_bwd` algs definition. jit_uni_eltwise_injector_f32dnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f3280 jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, 81 float alpha, float beta, float scale, bool save_state = true, 82 Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0), 83 Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1), 84 Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4), 85 Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7), 86 bool is_fwd = true, bool use_dst = false) 87 : alg_(alg) 88 , alpha_(alpha) 89 , beta_(beta) 90 , scale_(scale) 91 , h(host) 92 , save_state_(save_state) 93 , x_table(x_table) 94 , p_mask(p_mask) 95 , p_tmp0(p_tmp0) 96 , p_all(p_all) 97 , is_fwd_(is_fwd) 98 , use_dst_(use_dst) 99 100 { 101 using namespace alg_kind; 102 assert(utils::one_of(isa, sve_512)); 103 assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, 104 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, 105 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic, 106 eltwise_exp, eltwise_gelu_tanh, eltwise_swish, eltwise_log, 107 eltwise_clip, eltwise_clip_v2, eltwise_gelu_erf, eltwise_round, 108 eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, 109 eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, 110 eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, 111 eltwise_clip_v2_use_dst_for_bwd)); 112 register_table_entries(); 113 } 114 jit_uni_eltwise_injector_f32dnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32115 jit_uni_eltwise_injector_f32(jit_generator *host, 116 const post_ops_t::entry_t::eltwise_t &eltwise, 117 bool save_state = true, 118 Xbyak_aarch64::XReg x_table = Xbyak_aarch64::XReg(0), 119 Xbyak_aarch64::PReg p_mask = Xbyak_aarch64::PReg(1), 120 Xbyak_aarch64::PReg p_tmp0 = Xbyak_aarch64::PReg(4), 121 Xbyak_aarch64::PReg p_all = Xbyak_aarch64::PReg(7), 122 bool is_fwd = true, bool use_dst = false) 123 : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, 124 eltwise.beta, eltwise.scale, save_state, x_table, p_mask, 125 p_tmp0, p_all, is_fwd, use_dst) {} 126 127 void compute_vector_range(size_t start_idx, size_t end_idx); 128 void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs); compute_vectordnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32129 void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } 130 void prepare_table(bool gen_table = true); load_table_addrdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32131 void load_table_addr() { h->adr(x_table, l_table); } 132 133 private: 134 const alg_kind_t alg_; 135 const float alpha_; 136 const float beta_; 137 const float scale_; 138 139 jit_generator *const h; 140 141 const bool save_state_; 142 const Xbyak_aarch64::XReg x_table; 143 const Xbyak_aarch64::PReg p_mask; 144 const Xbyak_aarch64::PReg p_tmp0; 145 const Xbyak_aarch64::PReg p_all; 146 const bool is_fwd_; 147 const bool use_dst_; 148 149 Xbyak_aarch64::Label l_table; 150 151 // if only the injector was inherited from jit_generator... 152 enum { 153 _cmp_eq_oq = jit_generator::_cmp_eq_oq, 154 _cmp_lt_os = jit_generator::_cmp_lt_os, 155 _cmp_le_os = jit_generator::_cmp_le_os, 156 _cmp_ge_os = jit_generator::_cmp_nlt_us, 157 _cmp_gt_os = jit_generator::_cmp_nle_us, 158 _op_floor = jit_generator::_op_floor, 159 _op_mxcsr = jit_generator::_op_mxcsr 160 }; 161 162 static constexpr size_t vlen = cpu_isa_traits<isa>::vlen; 163 static constexpr size_t preserved_vecs_max = 9; 164 static constexpr size_t preserved_gprs_max = 4; 165 static constexpr size_t vecs_count = 32; 166 static constexpr int n_mantissa_bits = 23; 167 static constexpr int k_mask_size = 8; 168 169 size_t vecs_to_preserve = 0; 170 size_t preserved_vecs_count = 0; 171 size_t preserved_vec_idxs[preserved_vecs_max] = {0}; 172 size_t preserved_gpr_idxs[preserved_gprs_max] = {0}; 173 injector_utils::vmm_index_set_iterator_t start_idx_tail; 174 175 /* These vector register must be assigned proper index. */ 176 TRegS vmm_mask {0}, vmm_aux0 {0}, vmm_aux1 {0}, vmm_aux2 {0}, vmm_aux3 {0}, 177 vmm_aux4 {0}, vmm_aux5 {0}, vmm_aux6 {0}, vmm_aux7 {0}, vmm_tmp {0}; 178 /* Default tempooral index. Chose a SVE register 179 not to be same as jit_uni_eltwise.(cpp|hpp). 180 This index is changed by assign_regs() in case of eltwise injection. 181 */ 182 TRegS z_tmp {31}; 183 184 size_t aux_vecs_count(); 185 size_t aux_gprs_count(); 186 187 void compute_body( 188 const injector_utils::vmm_index_set_iterator_t &start_idx_it, 189 const injector_utils::vmm_index_set_iterator_t &end_idx_it); 190 void injector_preamble(const injector_utils::vmm_index_set_t &vmm_idxs); 191 void injector_preamble_tail( 192 const injector_utils::vmm_index_set_iterator_t start_idx_it); 193 void injector_postamble(); 194 void assign_regs(); 195 void set_coef_to_regs(); 196 void compute_cmp_mask( 197 const TRegS &vmm_src, const TRegS &vmm_cmpare, int cmp_predicate); 198 void blend_with_mask(const TRegS &vmm_dst, const TRegS &src); 199 void test_mask(); 200 201 void exp_compute_vector_fwd(const TRegS &vmm_src); 202 void relu_compute_vector_fwd(const TRegS &vmm_src); 203 void relu_zero_ns_compute_vector_fwd(const TRegS &vmm_src); 204 void elu_compute_vector_fwd(const TRegS &vmm_src); 205 void tanh_compute_vector_fwd(const TRegS &vmm_src); 206 void square_compute_vector_fwd(const TRegS &vmm_src); 207 void abs_compute_vector_fwd(const TRegS &vmm_src); 208 void sqrt_compute_vector_fwd(const TRegS &vmm_src); 209 void linear_compute_vector_fwd(const TRegS &vmm_src); 210 void bounded_relu_compute_vector_fwd(const TRegS &vmm_src); 211 void soft_relu_compute_vector_fwd(const TRegS &vmm_src); 212 void logistic_compute_vector_fwd(const TRegS &vmm_src); 213 void gelu_tanh_compute_vector_fwd(const TRegS &vmm_src); 214 void swish_compute_vector_fwd(const TRegS &vmm_src); 215 void log_compute_vector_fwd(const TRegS &vmm_src); 216 void clip_compute_vector_fwd(const TRegS &vmm_src); 217 void gelu_erf_compute_vector_fwd(const TRegS &vmm_src); 218 void round_compute_vector_fwd(const TRegS &vmm_src); 219 220 void exp_compute_vector_bwd(const TRegS &vmm_src); 221 void relu_compute_vector_bwd(const TRegS &vmm_src); 222 void elu_compute_vector_bwd(const TRegS &vmm_src); 223 void tanh_compute_vector_bwd(const TRegS &vmm_src); 224 void square_compute_vector_bwd(const TRegS &vmm_src); 225 void abs_compute_vector_bwd(const TRegS &vmm_src); 226 void sqrt_compute_vector_bwd(const TRegS &vmm_src); 227 void linear_compute_vector_bwd(const TRegS &vmm_src); 228 void bounded_relu_compute_vector_bwd(const TRegS &vmm_src); 229 void soft_relu_compute_vector_bwd(const TRegS &vmm_src); 230 void logistic_compute_vector_bwd(const TRegS &vmm_src); 231 void gelu_tanh_compute_vector_bwd(const TRegS &vmm_src); 232 void swish_compute_vector_bwd(const TRegS &vmm_src); 233 void log_compute_vector_bwd(const TRegS &vmm_src); 234 void clip_compute_vector_bwd(const TRegS &vmm_src); 235 void gelu_erf_compute_vector_bwd(const TRegS &vmm_src); 236 237 enum key_t { 238 scale = 0, // scale argument 239 alpha, // alpha argument 240 beta, // beta argument 241 zero, // 0.f 242 half, // 0.5f 243 one, // 1.f or mask for exponent bits 244 two, // 2.f 245 minus_one, // -1.f or changes sign to opposite 246 minus_two, // -2.f 247 ln2f, // 0.69314718f 248 positive_mask, // changes sign to positive 249 sign_mask, // gets sign value 250 exponent_bias, // (127 = 2^7 - 1), gets exponent bits 251 exp_log2ef, // 1.44269502f - formula-based for approx 252 exp_ln_flt_max_f, // logf(FLT_MAX) - max normal value 253 exp_ln_flt_min_f, // logf(FLT_MIN) - min normal value 254 exp_pol, // see correspondent table for float values 255 exp_coeff1, // 0.6931473921 (0x3f31721c) 256 exp_coeff2, // 0.2413862043 (0x3e772df2) 257 exp_not_mask17, // ~((1u << 17) - 1) 258 tanh_range, // tanh(x) = x - x^3/3 for |x| < tanh_range 259 tanh_m1d3, // -1/3 260 soft_relu_one_twenty_six, // 126.f 261 soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign 262 soft_relu_pol, // see correspondent table for float values 263 gelu_tanh_fitting_const, // 0.044715f 264 gelu_tanh_fitting_const_times_three, // 0.134145f 265 gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f 266 gelu_erf_approx_const, // 0.3275911f - implementation based for approx 267 gelu_erf_one_over_sqrt_two, // 1.f / sqrtf(2.f) 268 gelu_erf_one_over_sqrt_pi, // 1.f / sqrtf(pi) = 0.564190f 269 gelu_erf_pol, // see correspondent table for float values 270 log_minus_inf, // -inf 271 log_qnan, // qnan 272 log_mantissa_mask, // gets mantissa bits 273 log_full_k_reg_mask, // sets k_register with all bits of 1 274 log_full_vector_reg_mask, // sets vector register will all bits of 1 275 log_five_bit_offset, // 5 bits off (31 = 2^5 - 1) 276 log_pol, // see correspondent table for float values 277 log_predefined_vals, // see correspondent table for float values 278 log_i127shl23, 279 log_x7fffff, 280 log_log2, 281 log_log1p5, 282 log_f2div3, 283 log_coeffTbl, 284 undef_key, 285 }; 286 table_offdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32287 size_t table_off(key_t key, size_t key_off_val_shift = 0) { 288 // assumption: all table entries sharing the same key also 289 // share their broadcast property 290 // TODO: enforce through data structure 291 const auto it = entry_map_.find(key); // search an entry for a key 292 assert(it != entry_map_.end()); 293 const auto &te = (*it).second; 294 const auto scale = te.bcast ? vlen : sizeof(table_entry_val_t); 295 return te.off + key_off_val_shift * scale; 296 } 297 table_valdnnl::impl::cpu::aarch64::jit_uni_eltwise_injector_f32298 TRegS table_val(key_t key, TRegS zreg, size_t key_off_val_shift = 0) { 299 Xbyak_aarch64::XReg x_addr(h->X_DEFAULT_ADDR); 300 auto off = table_off(key, key_off_val_shift); 301 302 if (off) { 303 h->add_imm(x_addr, x_table, off, h->X_TMP_0); 304 } else { 305 x_addr = x_table; 306 } 307 308 h->ldr(TReg(zreg.getIdx()), ptr(x_addr)); 309 return zreg; 310 } 311 312 // we accept only 32bit hexadecimal table values to avoid any rounding 313 using table_entry_val_t = uint32_t; 314 using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table 315 using table_entry_bcast_t = bool; // true => bcast value 316 317 struct table_entry_t { 318 table_entry_val_t val; 319 table_entry_bcast_t bcast; 320 }; 321 struct mapped_table_entry_t { 322 table_entry_offset_t off; 323 table_entry_val_t val; 324 table_entry_bcast_t bcast; 325 }; 326 327 using table_t = std::multimap<key_t, table_entry_t>; 328 using mapped_table_t = std::multimap<key_t, mapped_table_entry_t>; 329 330 void register_table_entries(); 331 mapped_table_t entry_map_; 332 }; 333 334 } // namespace aarch64 335 } // namespace cpu 336 } // namespace impl 337 } // namespace dnnl 338 #endif 339