1 /******************************************************************************* 2 * Copyright 2016-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_UNI_LRN_KERNEL_HPP 18 #define CPU_X64_JIT_UNI_LRN_KERNEL_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/type_helpers.hpp" 22 23 #include "cpu/x64/jit_generator.hpp" 24 25 namespace dnnl { 26 namespace impl { 27 namespace cpu { 28 namespace x64 { 29 30 struct bf16_emulation_t; 31 struct jit_args_fwd_t { 32 const void *src; 33 void *dst, *scratch, *bwd_intermediate_res; 34 }; 35 36 struct jit_args_bwd_t { 37 const void *src, *diff_dst, *scratch, *bwd_intermediate_res; 38 void *diff_src; 39 }; 40 41 struct nchw8c_across_t { 42 /* version: 43 * -1: channels 0..7, 44 * 1: channels C-8 .. C-1, 45 * 0: other channels 46 * 3: channels only for this kernel(without prev and next) 47 */ 48 int H, W, version; nchw8c_across_tdnnl::impl::cpu::x64::nchw8c_across_t49 nchw8c_across_t(int h, int w, int v) : H(h), W(w), version(v) {} nchw8c_across_tdnnl::impl::cpu::x64::nchw8c_across_t50 nchw8c_across_t() : nchw8c_across_t(0, 0, 0) {} 51 }; 52 53 struct within_config_t { 54 const int H, W, C, size; 55 const format_tag_t dat_tag; within_config_tdnnl::impl::cpu::x64::within_config_t56 within_config_t(int h, int w, int c, int s, format_tag_t dat_tag) 57 : H(h), W(w), C(c), size(s), dat_tag(dat_tag) {} within_config_tdnnl::impl::cpu::x64::within_config_t58 within_config_t() : within_config_t(0, 0, 0, 0, dnnl_format_tag_undef) {} 59 }; 60 61 struct nchw_across_t { 62 int C, HW, tail; nchw_across_tdnnl::impl::cpu::x64::nchw_across_t63 nchw_across_t(int c, int hw, int t) : C(c), HW(hw), tail(t) {} nchw_across_tdnnl::impl::cpu::x64::nchw_across_t64 nchw_across_t() : nchw_across_t(0, 0, 0) {} 65 }; 66 67 struct nhwc_across_t { 68 int C; nhwc_across_tdnnl::impl::cpu::x64::nhwc_across_t69 nhwc_across_t(int c) : C(c) {} nhwc_across_tdnnl::impl::cpu::x64::nhwc_across_t70 nhwc_across_t() : nhwc_across_t(0) {} 71 }; 72 73 enum class lrn_config_t { 74 none = 0, 75 nchw8c_across, 76 within_config, 77 nchw_across, 78 nhwc_across, 79 }; 80 81 template <class Derived> 82 class jit_uni_lrn_kernel_t; // primary template 83 84 template <template <cpu_isa_t isa, data_type_t d_type> class Derived, 85 cpu_isa_t isa, data_type_t d_type> 86 class jit_uni_lrn_kernel_t<Derived<isa, d_type>> : public jit_generator { 87 public: 88 jit_uni_lrn_kernel_t( 89 void *code_ptr = nullptr, size_t code_size = MAX_CODE_SIZE); 90 jit_uni_lrn_kernel_t(const within_config_t &J, void *code_ptr = nullptr, 91 size_t code_size = MAX_CODE_SIZE); 92 93 ~jit_uni_lrn_kernel_t(); 94 95 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_kernel_t); 96 static constexpr int VECTOR_LENGTH = (isa == avx512_common ? 16 : 8); 97 98 protected: 99 using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm, 100 Xbyak::Zmm>::type; 101 102 void load_constant(float constant, const Vmm &v_constant, 103 const Xbyak::Xmm &x_constant); 104 void load_data(const Vmm ®, const Xbyak::Address &p); 105 void store_data(const Xbyak::Address &p, const Vmm ®); 106 void within_loop( 107 const within_config_t &config, int max_reg_blocks, prop_kind_t pk); 108 void within_body_reg_blocked(int loop_count, int max_reg_block, int hoff, 109 int Hoff, int woff, int Woff, int stride, prop_kind_t pk); 110 111 const bool emulate_bfloat_ = false; 112 const Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28); 113 const Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29); 114 const Xbyak::Reg64 bf16_emu_scratch_ = this->rax; 115 const Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30); 116 const Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31); 117 std::unique_ptr<bf16_emulation_t> bf16_emu_; 118 const Xbyak::Reg64 h_ = this->r9; 119 const Xbyak::Reg64 w_ = this->r10; 120 const Xbyak::Reg64 imm_addr64_ = this->rbx; 121 int single_pixel_offset_ 122 = VECTOR_LENGTH * sizeof(typename prec_traits<d_type>::type); 123 int tempIdx_ = 0; 124 int reg_block_idx_ = 0; 125 }; 126 127 template <cpu_isa_t isa, data_type_t d_type> 128 class jit_uni_lrn_fwd_kernel_t 129 : public jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>> { 130 public: 131 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_t) 132 133 jit_uni_lrn_fwd_kernel_t(const within_config_t &J, float A, float K, 134 prop_kind_t pk, void *code_ptr = nullptr, 135 size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); 136 jit_uni_lrn_fwd_kernel_t(const nchw8c_across_t &J, float A, float K, 137 prop_kind_t pk, void *code_ptr = nullptr, 138 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); 139 jit_uni_lrn_fwd_kernel_t(const nhwc_across_t &J, float A, float K, 140 prop_kind_t pk, void *code_ptr = nullptr, 141 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); 142 jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K, 143 prop_kind_t pk, void *code_ptr = nullptr, 144 size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); 145 ~jit_uni_lrn_fwd_kernel_t(); 146 147 private: 148 using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>>; 149 generate()150 void generate() override { 151 switch (config_) { 152 case lrn_config_t::within_config: 153 generate(this->within_config_); 154 return; 155 case lrn_config_t::nchw8c_across: 156 generate(this->nchw8c_across_); 157 return; 158 case lrn_config_t::nhwc_across: 159 generate(this->nhwc_across_); 160 return; 161 case lrn_config_t::nchw_across: 162 generate(this->nchw_across_); 163 return; 164 default: assert(!"Configuration not supported"); return; 165 } 166 } 167 void generate(const within_config_t &config); 168 void generate(const nchw8c_across_t &config); 169 void generate(const nhwc_across_t &config); 170 void generate(const nchw_across_t &config); 171 172 public: 173 using Base::VECTOR_LENGTH; 174 175 private: 176 friend Base; 177 using typename Base::Vmm; 178 179 void within_body(int hoff, int Hoff, int woff, int Woff, int stride, 180 prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0); 181 void nchw_body(int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, 182 Xbyak::Ymm ya, Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, 183 Xbyak::Ymm ye, Xbyak::Ymm ysum); 184 void nchw_body_sse41(int tail, int HW, prop_kind_t pk, Xbyak::Xmm xe_lo, 185 Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); 186 void nchw_tail_sse41(int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, 187 Xbyak::Xmm xtail_hi); 188 void move_data_pointers(int pixel_count, prop_kind_t pk); 189 190 const Xbyak::Reg64 src_ = this->rax; 191 const Xbyak::Reg64 dst_ = this->r8; 192 const Xbyak::Reg64 scratch_ = this->r14; 193 const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx; 194 const Xbyak::Reg64 store_addr_ = this->rbp; 195 196 const Xbyak::Xmm xalpha_ = this->xmm0; 197 const Xbyak::Xmm xk_ = this->xmm1; 198 const Xbyak::Ymm yk_ = this->ymm1; 199 const Vmm valpha_ = Vmm(0); 200 const Vmm vk_ = Vmm(1); 201 202 lrn_config_t config_; 203 const nchw8c_across_t nchw8c_across_; 204 const within_config_t within_config_; 205 const nchw_across_t nchw_across_; 206 const nhwc_across_t nhwc_across_; 207 float alpha_; 208 float k_; 209 prop_kind_t pk_; 210 static constexpr int stack_space_needed_ = 11 * 4 * sizeof(float) + 16; 211 }; 212 213 template <cpu_isa_t isa, data_type_t d_type> 214 class jit_uni_lrn_bwd_kernel_t 215 : public jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>> { 216 public: 217 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_t) 218 219 jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t &J, float A, float B, 220 int use_h_parallel, void *code_ptr = nullptr, 221 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); 222 jit_uni_lrn_bwd_kernel_t(const within_config_t &J, float A, float B, 223 void *code_ptr = nullptr, 224 size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); 225 226 private: 227 using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>>; 228 generate()229 void generate() override { 230 switch (config_) { 231 case lrn_config_t::nchw8c_across: 232 generate(this->nchw8c_across_); 233 return; 234 case lrn_config_t::within_config: 235 generate(this->within_config_); 236 return; 237 default: assert(!"Configuration not supported"); return; 238 } 239 } 240 void generate(const nchw8c_across_t &config); 241 void generate(const within_config_t &config); 242 243 public: 244 using Base::VECTOR_LENGTH; 245 246 private: 247 friend Base; 248 using typename Base::Vmm; 249 250 void within_body(int hoff, int Hoff, int woff, int Woff, int stride, 251 prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0); 252 void move_data_pointers(int pixel_count, prop_kind_t pk); 253 254 lrn_config_t config_; 255 const nchw8c_across_t nchw8c_across_; 256 const within_config_t within_config_; 257 prop_kind_t pk_ = prop_kind::backward; 258 259 float nalphabeta_; 260 int use_h_parallelizm_; 261 const Xbyak::Reg64 src_ = this->rax; 262 const Xbyak::Reg64 diffsrc_ = this->r13; 263 const Xbyak::Reg64 diffdst_ = this->r14; 264 const Xbyak::Reg64 scratch_ = this->r15; 265 const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx; 266 const Xbyak::Xmm xnalphabeta_ = this->xmm0; 267 const Vmm vnalphabeta_ = Vmm(0); 268 }; 269 270 } // namespace x64 271 } // namespace cpu 272 } // namespace impl 273 } // namespace dnnl 274 275 #endif 276 277 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 278