1 /******************************************************************************* 2 * Copyright 2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_UNI_BINARY_KERNEL_HPP 18 #define CPU_X64_UNI_BINARY_KERNEL_HPP 19 20 #include <cassert> 21 22 #include "common/c_types_map.hpp" 23 #include "common/type_helpers.hpp" 24 #include "common/utils.hpp" 25 26 #include "cpu/x64/cpu_isa_traits.hpp" 27 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" 28 #include "cpu/x64/jit_generator.hpp" 29 #include "cpu/x64/jit_primitive_conf.hpp" 30 #include "cpu/x64/utils/jit_io_helper.hpp" 31 32 #include "cpu/cpu_binary_pd.hpp" 33 34 namespace dnnl { 35 namespace impl { 36 namespace cpu { 37 namespace x64 { 38 39 using namespace Xbyak; 40 41 struct binary_kernel_t : public jit_generator { 42 using op_t = binary_op_t; 43 using bcast_t = binary_bcast_t; 44 45 binary_kernel_t(const size_t vlen, const binary_pd_t *pd, 46 const jit_binary_conf_t conf, bool tail_kernel = false); 47 ~binary_kernel_t() override = default; 48 operator ()dnnl::impl::cpu::x64::binary_kernel_t49 void operator()(jit_binary_call_s *p) { jit_generator::operator()(p); } 50 simd_wdnnl::impl::cpu::x64::binary_kernel_t51 size_t simd_w() const noexcept { return simd_w_; } vlendnnl::impl::cpu::x64::binary_kernel_t52 size_t vlen() const noexcept { return vlen_; } 53 54 protected: 55 size_t get_tail_size() const; 56 57 const size_t vlen_; 58 const size_t simd_w_; 59 constexpr static int vmm_start_idx_ = 1; 60 const binary_pd_t *pd_; 61 const jit_binary_conf_t conf_; 62 const bool is_tail_kernel_; 63 const bool is_src1_outer_dims_tail_; 64 const size_t tail_size_; 65 const size_t padding_tail_size_; 66 }; 67 68 template <cpu_isa_t isa> 69 struct jit_uni_binary_kernel_t : public binary_kernel_t { 70 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_binary_kernel_t) 71 72 using Vmm = typename cpu_isa_traits<isa>::Vmm; 73 const AddressFrame &vmmword 74 = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword); 75 76 static constexpr bool is_avx512 77 = utils::one_of(isa, avx512_common, avx512_core, avx512_core_bf16); 78 static constexpr bool is_avx512_core 79 = utils::one_of(isa, avx512_core, avx512_core_bf16); 80 static constexpr bool is_avx512_common = isa == avx512_common; 81 const bool is_avx512_not_mic 82 = is_avx512_core || (is_avx512_common && !conf_.is_i8); 83 84 const Reg64 ®_param_ = abi_param1; 85 const Reg64 ®_src0_ = r8; 86 const Reg64 ®_src1_ = r9; 87 const Reg64 ®_dst_ = r10; 88 const Reg64 ®_offt_src0_ = r11; 89 const Reg64 ®_outer_dims_range_ = r12; 90 const Reg64 ®_offt_src1_ = rax; 91 const Reg64 ®_src1_stride_range_ = r15; 92 const Reg64 ®_reverse_src1_stride_range_ = rax; 93 const Reg64 ®_reverse_spat_offt_ = r13; 94 const Reg64 ®_tmp_ = r14; 95 const Reg64 ®_tmp1_ = abi_not_param1; 96 const Reg64 ®_elt_inj_table_ = r15; 97 const Reg64 ®_off_rhs_postops_ = rdx; 98 const Reg64 ®_scales_src0_ = rbx; 99 const Reg64 ®_scales_src1_ = rbp; 100 const Reg64 ®_offt_dst_ = rdx; 101 const Opmask &tail_opmask_ = k2; 102 const Opmask &cmp_mask = k3; 103 const Opmask &full_mask_ = k4; 104 const Vmm vmm_tail_vmask_ = Vmm(0); 105 const Vmm vreg_sum_scale_ = Vmm(is_avx512 ? 17 : 9); 106 const Xmm xreg_sum_scale_ = Xmm(9); 107 const Vmm vreg_zero_ = Vmm(is_avx512 ? 18 : 10); 108 const Vmm vreg_one_ = Vmm(is_avx512 ? 19 : 11); 109 const Vmm vreg_saturation_ubound_ = Vmm(is_avx512 ? 20 : 12); 110 const Vmm vreg_bcast_src1_ = Vmm(is_avx512_not_mic ? 21 : 13); 111 const Xmm xreg_bcast_src1_ = Xmm(13); 112 const Vmm vreg_scales_src0_ = Vmm(is_avx512 ? 22 : 14); 113 const Vmm vreg_scales_src1_ = Vmm(is_avx512 ? 23 : 15); 114 115 const Zmm vreg_bf16_emu_1_ = Zmm(26); 116 const Zmm vreg_bf16_emu_2_ = Zmm(27); 117 const Zmm vreg_bf16_emu_3_ = Zmm(28); 118 const Zmm vreg_bf16_emu_4_ = Zmm(29); 119 120 const Vmm vmm_full_mask_ = Vmm(is_avx512_not_mic ? 24 : 5); 121 const Vmm vmm_tmp_gather_ = Vmm(is_avx512_not_mic ? 25 : 6); 122 const Vmm vmm_indices_ = Vmm(is_avx512_not_mic ? 30 : 7); 123 const Vmm vmm_gathered_src_ = Vmm(is_avx512_not_mic ? 31 : 8); 124 125 const size_t unroll_regs_ = is_avx512 126 && IMPLICATION( 127 conf_.is_src_different_layouts, is_avx512_not_mic) 128 ? 8 129 : 4; 130 const size_t offt_src0_; 131 const size_t offt_src1_; 132 133 static constexpr cpu_isa_t inject_isa 134 = isa == avx512_core_bf16 ? avx512_core : isa; 135 io::jit_io_multi_dt_helper_t<Vmm> io_; 136 std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa>> 137 postops_injector_; 138 const Opmask &elt_inj_opmask_ = k1; 139 140 void init(); 141 void init_post_ops_injector(); 142 void apply_postops(int unroll, bool tail); 143 void load_kernel_params(); 144 Address src0_ptr(size_t offt = 0); 145 Address src1_ptr(size_t offt = 0); 146 Address dst_ptr(size_t offt = 0); 147 unsigned int cmp_predicate(alg_kind_t alg); 148 void perform_op( 149 const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1); 150 void prepare_isa_kernel(); 151 void compute_bcast(bool tail); 152 void load_src1(const Vmm &vreg_src1, const int offt, bool tail); 153 void compute_dst(int unroll, bool tail); 154 void forward(); 155 void forward_over_outer_dims(); 156 void generate() override; 157 158 jit_uni_binary_kernel_t(const binary_pd_t *pd, const jit_binary_conf_t conf, 159 bool tail_kernel = false); 160 ~jit_uni_binary_kernel_t() override = default; 161 162 std::map<data_type_t, io::io_saturation_conf_t> 163 create_saturation_vmm_map() const; 164 }; 165 166 } // namespace x64 167 } // namespace cpu 168 } // namespace impl 169 } // namespace dnnl 170 171 #endif