1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP 18 #define CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/memory_tracking.hpp" 22 23 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" 24 #include "cpu/x64/jit_generator.hpp" 25 #include "cpu/x64/jit_primitive_conf.hpp" 26 27 namespace dnnl { 28 namespace impl { 29 namespace cpu { 30 namespace x64 { 31 32 struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { 33 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_1x1_fwd_kernel_t) 34 35 jit_avx512_core_amx_1x1_fwd_kernel_t(const jit_conv_conf_t &ajcp, 36 const primitive_attr_t &attr, const memory_desc_t &dst_md); 37 38 static status_t init_conf(jit_conv_conf_t &jcp, 39 const convolution_desc_t &cd, memory_desc_t &src_pd, 40 memory_desc_t &weights_pd, memory_desc_t &dst_pd, 41 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); 42 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 43 const jit_conv_conf_t &jcp, const primitive_attr_t &attr); 44 45 // Tile-registers decomposition 46 enum { C_BASE = 0, W_BASE = 6, I_BASE = 4 }; 47 48 void tile_configure(char *tcgf_buff); 49 50 jit_conv_conf_t jcp; 51 const primitive_attr_t &attr_; 52 53 private: 54 constexpr static int isa_simd_width_ 55 = cpu_isa_traits<avx512_common>::vlen / sizeof(float); 56 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> 57 postops_injector_; 58 59 enum { 60 zmm_idx_limit_bf16 = 29, 61 zmm_idx_limit_int8 = 27, 62 }; 63 64 int row_count_ = 0; 65 int buf_count_ = 0; 66 bool is_store_done_ = false; 67 bool is_buffer_empty_ = true; 68 bool check_last_sb_ = false; 69 bool last_oc_block_flag_ = false; 70 71 /* data regs */ 72 const Xbyak::Reg64 &inp_ptr = r15; 73 const Xbyak::Reg64 &wei_ptr = r14; 74 const Xbyak::Reg64 &out_ptr = r13; 75 const Xbyak::Reg64 &wsp_ptr = r12; 76 77 const Xbyak::Reg64 ®_bias = r11; 78 const Xbyak::Reg64 ®_ptr_scales = r10; 79 const Xbyak::Reg64 ®_ptr_sum_scale = r9; 80 const Xbyak::Reg64 ®_ptr_sum_zp = rax; 81 const Xbyak::Reg64 &aux_reg_saturation = reg_ptr_sum_scale; 82 const Xbyak::Reg64 ®_last_h = r8; 83 84 const Xbyak::Reg64 &stride_seq = rbx; 85 const Xbyak::Reg64 &stride_nhwc = rsi; 86 const Xbyak::Reg64 ®_tmp = abi_not_param1; 87 88 const Xbyak::Reg64 ®_oc_blocks = rdx; 89 const Xbyak::Reg64 ®_is_osb = rsi; 90 const Xbyak::Reg64 ®_postop = abi_not_param1; 91 const Xbyak::Reg64 ®_scratch = reg_bias; 92 const Xbyak::Reg64 ®_tilebuff = reg_ptr_scales; 93 /* zero-point */ 94 const Xbyak::Reg64 ®_zp_compensation = reg_last_h; 95 const Xbyak::Reg64 ®_src_zero_point = reg_oc_blocks; 96 const Xbyak::Reg64 ®_dst_zero_point = rax; 97 98 const Xbyak::Zmm &zmm_bias = zmm31; 99 const Xbyak::Zmm &zmm_saturation = zmm_bias; 100 const Xbyak::Zmm &zmm_zero = zmm30; 101 const Xbyak::Zmm &zmm_prev_dst = zmm29; 102 const Xbyak::Zmm &zmm_sum_zp = zmm26; 103 /* zero-point */ 104 const Xbyak::Zmm &zmm_zp = zmm29; 105 const Xbyak::Zmm &zmm_src_zp = zmm28; 106 const Xbyak::Zmm &zmm_dst_zp = zmm27; 107 108 const Xbyak::Reg64 &bin_injector_helper_reg_1 = r14; 109 const Xbyak::Reg64 &bin_injector_helper_reg_2 = r15; 110 111 const Xbyak::Opmask &ktail_mask = k2; 112 113 bool is_bf16() const; 114 115 void init_runtime_counters(); 116 117 int get_out_tensor(int h, int i) const; 118 int get_inp_tensor(int h) const; 119 int get_wei_tensor(int i) const; 120 int get_ic_tail() const; 121 122 size_t out_h_shift() const; 123 size_t out_w_shift() const; 124 size_t inp_offset(int ih, int iw, int icb) const; 125 size_t out_row_offset(int h, int w, int ocb) const; 126 127 void prepare_output(); 128 129 void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in, 130 const Xbyak::Operand &op, bool mask_flag); zmm_outdnnl::impl::cpu::x64::jit_avx512_core_amx_1x1_fwd_kernel_t131 Xbyak::Zmm zmm_out(const int idx) { 132 const int upper_limit 133 = is_bf16() ? zmm_idx_limit_bf16 : zmm_idx_limit_int8; 134 assert(upper_limit > idx); 135 MAYBE_UNUSED(upper_limit); 136 return Xbyak::Zmm(idx); 137 } 138 Xbyak::Zmm zmm_mask( 139 const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false); 140 Xbyak::Ymm ymm_mask( 141 const Xbyak::Ymm &ymm_in, bool mask_flag, bool store = false); 142 143 void update_buffer_pointers(); 144 void interleave_store(); 145 void apply_sum(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, 146 const int32_t *p_sum_zp, const Xbyak::Address &addr, 147 const bool mask_flag); 148 void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, 149 const int32_t *p_sum_zp, const Xbyak::Address &addr, 150 const size_t off, const bool mask_flag); 151 static bool is_fast_postops(const jit_conv_conf_t &jcp); 152 void store_output_vectors_int8(int ocb, int osb); 153 void store_output_vector_int8( 154 const Xbyak::Zmm &zmm_out, int ocb, int h, int w); 155 void store_output_vectors_bf16(int ocb, int osb); 156 void store_output_vector_bf16( 157 const Xbyak::Zmm &zmm_out, int ocb, int h, int w); 158 void store_output_vectors(int ocb, int osb); 159 void store_output_vector(const Xbyak::Zmm &zmm_out, int ocb, int h, int w); 160 void store_output(bool do_store, bool is_tail); 161 void icb_loop(bool do_store); 162 void osb_loop(int nb_os = 1); 163 164 void generate() override; 165 }; 166 167 } // namespace x64 168 } // namespace cpu 169 } // namespace impl 170 } // namespace dnnl 171 172 #endif 173