1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_UNI_LRN_KERNEL_HPP
18 #define CPU_X64_JIT_UNI_LRN_KERNEL_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/type_helpers.hpp"
22 
23 #include "cpu/x64/jit_generator.hpp"
24 
25 namespace dnnl {
26 namespace impl {
27 namespace cpu {
28 namespace x64 {
29 
// Forward declaration only. This header holds just a std::unique_ptr to it
// (jit_uni_lrn_kernel_t::bf16_emu_), so the full definition is not needed here.
struct bf16_emulation_t;
// Runtime argument block for one forward LRN kernel invocation.
// NOTE(review): the generated code presumably reads these fields by offset,
// so their order and types must stay in sync with the kernel sources.
struct jit_args_fwd_t {
    const void *src; // source data, read-only
    void *dst; // destination data
    void *scratch;
    void *bwd_intermediate_res;
};
35 
// Runtime argument block for one backward LRN kernel invocation.
// NOTE(review): the generated code presumably reads these fields by offset,
// so their order and types must stay in sync with the kernel sources.
struct jit_args_bwd_t {
    // Read-only inputs.
    const void *src;
    const void *diff_dst;
    const void *scratch;
    const void *bwd_intermediate_res;
    // Output.
    void *diff_src;
};
40 
// Configuration for the nChw8c "across channels" kernel variant.
struct nchw8c_across_t {
    int H;
    int W;
    // Which channels this kernel instance handles:
    //  -1: channels 0..7 (the first group);
    //   1: channels C-8 .. C-1 (the last group);
    //   0: any other channels;
    //   3: the channels handled here have no previous or next group.
    int version;

    nchw8c_across_t(int h, int w, int v) : H(h), W(w), version(v) {}
    // Kept user-provided so `const nchw8c_across_t` objects (the kernels hold
    // such members) can be default-initialized; everything starts at zero.
    nchw8c_across_t() : H(0), W(0), version(0) {}
};
52 
53 struct within_config_t {
54     const int H, W, C, size;
55     const format_tag_t dat_tag;
within_config_tdnnl::impl::cpu::x64::within_config_t56     within_config_t(int h, int w, int c, int s, format_tag_t dat_tag)
57         : H(h), W(w), C(c), size(s), dat_tag(dat_tag) {}
within_config_tdnnl::impl::cpu::x64::within_config_t58     within_config_t() : within_config_t(0, 0, 0, 0, dnnl_format_tag_undef) {}
59 };
60 
// Configuration for the plain-NCHW "across channels" kernel variant.
// NOTE(review): `tail` is presumably the HW remainder not covered by full
// vectors — confirm against the kernel sources.
struct nchw_across_t {
    int C;
    int HW;
    int tail;

    nchw_across_t(int c, int hw, int t) : C(c), HW(hw), tail(t) {}
    // User-provided so const members of this type can be default-initialized.
    nchw_across_t() : C(0), HW(0), tail(0) {}
};
66 
// Configuration for the NHWC "across channels" kernel variant: only the
// channel count is needed.
struct nhwc_across_t {
    int C;

    // NOTE(review): not `explicit`, so an int converts implicitly; kept that
    // way for compatibility with existing callers.
    nhwc_across_t(int c) : C(c) {}
    // User-provided so const members of this type can be default-initialized.
    nhwc_across_t() : C(0) {}
};
72 
// Tag that selects which code-generation strategy a kernel instance uses.
// The kernels' generate() switches on it; values a kernel does not handle
// (e.g. `none`) fall into that switch's assert.
enum class lrn_config_t {
    none = 0,
    nchw8c_across = 1,
    within_config = 2,
    nchw_across = 3,
    nhwc_across = 4,
};
80 
// Primary template: declared only. The partial specialization that follows,
// for Derived<isa, d_type>, supplies the actual base-kernel implementation.
template <class Derived>
class jit_uni_lrn_kernel_t; // primary template
83 
// Common base for the forward and backward LRN JIT kernels (CRTP: a derived
// kernel passes itself, e.g. jit_uni_lrn_fwd_kernel_t<isa, d_type>, as the
// template argument). Specializing on Derived<isa, d_type> lets the base
// recover `isa` and `d_type` from the derived kernel type.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
class jit_uni_lrn_kernel_t<Derived<isa, d_type>> : public jit_generator {
public:
    jit_uni_lrn_kernel_t(
            void *code_ptr = nullptr, size_t code_size = MAX_CODE_SIZE);
    jit_uni_lrn_kernel_t(const within_config_t &J, void *code_ptr = nullptr,
            size_t code_size = MAX_CODE_SIZE);

    // Declared here and defined out of line. This is required because
    // bf16_emu_ is a unique_ptr to bf16_emulation_t, which is incomplete in
    // this header.
    ~jit_uni_lrn_kernel_t();

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_kernel_t);
    // Number of 32-bit lanes per vector register: 16 for avx512_common, and
    // 8 for every other isa.
    static constexpr int VECTOR_LENGTH = (isa == avx512_common ? 16 : 8);

protected:
    // Vector register type: Ymm for avx2, Zmm for every other isa.
    // NOTE(review): this also resolves to Zmm for non-AVX isas (e.g. sse41);
    // presumably those paths use explicit Xmm/Ymm registers instead — confirm.
    using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
            Xbyak::Zmm>::type;

    // Code-emission helpers, defined out of line. From their signatures:
    // load_constant materializes `constant` into v_constant (x_constant is
    // its Xmm view); load_data/store_data move one vector between memory and
    // a register (presumably converting for bf16 d_type — confirm);
    // within_loop/within_body_reg_blocked drive the "within" variant.
    void load_constant(float constant, const Vmm &v_constant,
            const Xbyak::Xmm &x_constant);
    void load_data(const Vmm &reg, const Xbyak::Address &p);
    void store_data(const Xbyak::Address &p, const Vmm &reg);
    void within_loop(
            const within_config_t &config, int max_reg_blocks, prop_kind_t pk);
    void within_body_reg_blocked(int loop_count, int max_reg_block, int hoff,
            int Hoff, int woff, int Woff, int stride, prop_kind_t pk);

    // bf16 emulation state: zmm28-zmm31 are reserved for it and rax is its
    // scratch GPR. NOTE(review): both derived kernels also alias src_ to rax;
    // confirm the two uses never overlap in time.
    const bool emulate_bfloat_ = false;
    const Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28);
    const Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29);
    const Xbyak::Reg64 bf16_emu_scratch_ = this->rax;
    const Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30);
    const Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31);
    std::unique_ptr<bf16_emulation_t> bf16_emu_;
    // General-purpose registers shared by the derived kernels.
    const Xbyak::Reg64 h_ = this->r9;
    const Xbyak::Reg64 w_ = this->r10;
    const Xbyak::Reg64 imm_addr64_ = this->rbx;
    // Bytes occupied by one vector's worth (VECTOR_LENGTH) of d_type elements.
    int single_pixel_offset_
            = VECTOR_LENGTH * sizeof(typename prec_traits<d_type>::type);
    // NOTE(review): code-generation bookkeeping counters; their exact use is
    // not visible in this header.
    int tempIdx_ = 0;
    int reg_block_idx_ = 0;
};
126 
// Forward LRN JIT kernel. Each constructor overload takes a different
// configuration type. generate() dispatches on config_ to the matching
// generate(...) overload. The default code buffer sizes are 4x
// DEFAULT_MAX_CODE_SIZE for "within", 2x for nchw_across and 1x otherwise.
template <cpu_isa_t isa, data_type_t d_type>
class jit_uni_lrn_fwd_kernel_t
    : public jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>> {
public:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_t)

    // A and K are the LRN alpha and k parameters (presumably stored in
    // alpha_/k_); pk is the propagation kind.
    jit_uni_lrn_fwd_kernel_t(const within_config_t &J, float A, float K,
            prop_kind_t pk, void *code_ptr = nullptr,
            size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);
    jit_uni_lrn_fwd_kernel_t(const nchw8c_across_t &J, float A, float K,
            prop_kind_t pk, void *code_ptr = nullptr,
            size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
    jit_uni_lrn_fwd_kernel_t(const nhwc_across_t &J, float A, float K,
            prop_kind_t pk, void *code_ptr = nullptr,
            size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
    jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
            prop_kind_t pk, void *code_ptr = nullptr,
            size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE);
    ~jit_uni_lrn_fwd_kernel_t();

private:
    using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>>;

    // Emits code for whichever configuration config_ names. An unhandled
    // value (e.g. lrn_config_t::none) asserts in debug builds. In release
    // builds it returns without generating anything.
    void generate() override {
        switch (config_) {
            case lrn_config_t::within_config:
                generate(this->within_config_);
                return;
            case lrn_config_t::nchw8c_across:
                generate(this->nchw8c_across_);
                return;
            case lrn_config_t::nhwc_across:
                generate(this->nhwc_across_);
                return;
            case lrn_config_t::nchw_across:
                generate(this->nchw_across_);
                return;
            default: assert(!"Configuration not supported"); return;
        }
    }
    // Per-configuration generators, defined out of line.
    void generate(const within_config_t &config);
    void generate(const nchw8c_across_t &config);
    void generate(const nhwc_across_t &config);
    void generate(const nchw_across_t &config);

public:
    using Base::VECTOR_LENGTH;

private:
    // NOTE(review): presumably lets Base::within_loop/within_body_reg_blocked
    // call the private within_body below — confirm.
    friend Base;
    using typename Base::Vmm;

    void within_body(int hoff, int Hoff, int woff, int Woff, int stride,
            prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0);
    // nchw_body works on Ymm registers. The *_sse41 variants take lo/hi Xmm
    // pairs, presumably two 4-float halves of one 8-float vector.
    void nchw_body(int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask,
            Xbyak::Ymm ya, Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd,
            Xbyak::Ymm ye, Xbyak::Ymm ysum);
    void nchw_body_sse41(int tail, int HW, prop_kind_t pk, Xbyak::Xmm xe_lo,
            Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
    void nchw_tail_sse41(int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo,
            Xbyak::Xmm xtail_hi);
    void move_data_pointers(int pixel_count, prop_kind_t pk);

    // GPR assignment. src_ shares rax with Base::bf16_emu_scratch_.
    const Xbyak::Reg64 src_ = this->rax;
    const Xbyak::Reg64 dst_ = this->r8;
    const Xbyak::Reg64 scratch_ = this->r14;
    const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx;
    const Xbyak::Reg64 store_addr_ = this->rbp;

    // Vector register 0 holds alpha and register 1 holds k. The Xmm, Ymm and
    // Vmm members are different-width views of the same physical registers.
    const Xbyak::Xmm xalpha_ = this->xmm0;
    const Xbyak::Xmm xk_ = this->xmm1;
    const Xbyak::Ymm yk_ = this->ymm1;
    const Vmm valpha_ = Vmm(0);
    const Vmm vk_ = Vmm(1);

    // Selected strategy plus one slot per configuration type. Only the slot
    // that matches config_ is read by generate().
    lrn_config_t config_;
    const nchw8c_across_t nchw8c_across_;
    const within_config_t within_config_;
    const nchw_across_t nchw_across_;
    const nhwc_across_t nhwc_across_;
    float alpha_;
    float k_;
    prop_kind_t pk_;
    // 11 * 4 * sizeof(float) + 16 = 192 bytes of stack.
    // NOTE(review): presumably spill space for 11 four-float registers plus
    // 16 bytes extra; the use is not visible here.
    static constexpr int stack_space_needed_ = 11 * 4 * sizeof(float) + 16;
};
212 
// Backward LRN JIT kernel. It supports only the nchw8c_across and "within"
// configurations. generate() dispatches on config_ to the matching
// generate(...) overload.
template <cpu_isa_t isa, data_type_t d_type>
class jit_uni_lrn_bwd_kernel_t
    : public jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>> {
public:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_t)

    // A and B are presumably the LRN alpha and beta; nalphabeta_ presumably
    // stores a value derived from them (e.g. -alpha*beta) — confirm.
    jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t &J, float A, float B,
            int use_h_parallel, void *code_ptr = nullptr,
            size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
    jit_uni_lrn_bwd_kernel_t(const within_config_t &J, float A, float B,
            void *code_ptr = nullptr,
            size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);

private:
    using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>>;

    // Emits code for config_. Any other value, including nchw_across,
    // nhwc_across and none, asserts in debug builds. In release builds it
    // returns without generating anything.
    void generate() override {
        switch (config_) {
            case lrn_config_t::nchw8c_across:
                generate(this->nchw8c_across_);
                return;
            case lrn_config_t::within_config:
                generate(this->within_config_);
                return;
            default: assert(!"Configuration not supported"); return;
        }
    }
    // Per-configuration generators, defined out of line.
    void generate(const nchw8c_across_t &config);
    void generate(const within_config_t &config);

public:
    using Base::VECTOR_LENGTH;

private:
    // NOTE(review): presumably lets Base's "within" helpers call the private
    // within_body below — confirm.
    friend Base;
    using typename Base::Vmm;

    void within_body(int hoff, int Hoff, int woff, int Woff, int stride,
            prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0);
    void move_data_pointers(int pixel_count, prop_kind_t pk);

    lrn_config_t config_;
    const nchw8c_across_t nchw8c_across_;
    const within_config_t within_config_;
    // Always backward for this kernel.
    prop_kind_t pk_ = prop_kind::backward;

    float nalphabeta_;
    // NOTE(review): presumably nonzero when the caller parallelizes over H
    // (set from use_h_parallel) — confirm.
    int use_h_parallelizm_;
    // GPR assignment. src_ shares rax with Base::bf16_emu_scratch_.
    const Xbyak::Reg64 src_ = this->rax;
    const Xbyak::Reg64 diffsrc_ = this->r13;
    const Xbyak::Reg64 diffdst_ = this->r14;
    const Xbyak::Reg64 scratch_ = this->r15;
    const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx;
    // Xmm and Vmm views of vector register 0, which holds nalphabeta_.
    const Xbyak::Xmm xnalphabeta_ = this->xmm0;
    const Vmm vnalphabeta_ = Vmm(0);
};
269 
270 } // namespace x64
271 } // namespace cpu
272 } // namespace impl
273 } // namespace dnnl
274 
275 #endif
276 
277 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
278