1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <array>
18 #include <cmath>
19 #include "common/c_types_map.hpp"
20 #include "common/nstl.hpp"
21 #include "common/utils.hpp"
22 #include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
23 #include "cpu/x64/lrn/jit_uni_lrn_kernel.hpp"
24 
25 namespace dnnl {
26 namespace impl {
27 namespace cpu {
28 namespace x64 {
29 
30 using namespace dnnl::impl::format_tag;
31 
// IRB_LOOP(statement): emits `statement` once per pixel of a register block.
//
// Must be expanded where `reg_block` (int) and `vsum` (a container with
// size()) are in scope, inside a member function whose object has
// `reg_block_idx_` and `single_pixel_offset_`. Inside `statement`:
//  * reg_block == 1: runs once with irb = reg_block_idx_ % vsum.size()
//    (rotating register set) and irb_off = 0;
//  * otherwise: runs for irb = 0 .. reg_block - 1 with
//    irb_off = irb * single_pixel_offset_ (address offset of the irb-th pixel).
// The expansion is wrapped in do { } while (0) so that `IRB_LOOP(...);` is a
// single statement and stays correct inside an unbraced if/else.
#define IRB_LOOP(statement) \
    do { \
        if (1 == reg_block) { \
            const int irb_off = 0; \
            const int irb = this->reg_block_idx_ % vsum.size(); \
            statement; \
            MAYBE_UNUSED(irb_off); \
        } else { \
            for (int irb = 0; irb < reg_block; irb++) { \
                const int irb_off = irb * this->single_pixel_offset_; \
                statement; \
                MAYBE_UNUSED(irb_off); \
            } \
        } \
    } while (0)
45 
46 using namespace Xbyak;
47 
// Common constructor shared by the LRN JIT kernels.
// bf16 is emulated only for avx512_common kernels on bf16 data when the CPU
// lacks native avx512_core_bf16. In that case a bf16_emulation_t helper is
// created on the bf16_emu_* registers reserved by this class.
// NOTE(review): emulate_bfloat_ is read while initializing bf16_emu_, which
// relies on emulate_bfloat_ being declared before bf16_emu_ in the class --
// confirm against the header.
// NOTE(review): init_vcvtneps2bf16() is called here, before any derived
// generate() has emitted preamble(). If it emits instructions, they will sit
// ahead of the function prologue in the code buffer. Consider calling it
// right after preamble() instead -- TODO confirm with bf16_emulation_t.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
        void *code_ptr, size_t code_size)
    : jit_generator(code_ptr, code_size, true, isa)
    , emulate_bfloat_(isa == avx512_common
              && d_type == dnnl::impl::data_type::bf16
              && !mayiuse(avx512_core_bf16))
    , bf16_emu_(
              emulate_bfloat_ ? utils::make_unique<bf16_emulation_t>(this,
                      bf16_emu_reserv_1_, bf16_emu_reserv_2_,
                      bf16_emu_reserv_3_, bf16_emu_scratch_, bf16_emu_reserv_4_)
                              : nullptr) {

    if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16();
}
64 
65 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
66         cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t(const within_config_t & config,void * code_ptr,size_t code_size)67 jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
68         const within_config_t &config, void *code_ptr, size_t code_size)
69     : jit_uni_lrn_kernel_t(code_ptr, code_size) {
70     if (config.dat_tag == nhwc)
71         single_pixel_offset_
72                 = config.C * sizeof(typename prec_traits<d_type>::type);
73 }
74 
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined in this .cpp, rather than the header, so
// that the std::unique_ptr holding bf16_emulation_t is destroyed where that
// type is complete (this file includes jit_avx512_core_bf16cvt.hpp) --
// confirm against the header.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::~jit_uni_lrn_kernel_t() = default;
78 
79 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
80         cpu_isa_t isa, data_type_t d_type>
within_loop(const within_config_t & config,int max_reg_blocks,prop_kind_t pk)81 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_loop(
82         const within_config_t &config, int max_reg_blocks, prop_kind_t pk) {
83     const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
84 
85     const int lower_bound = (config.size - 1) / 2,
86               upper_bound = config.size - lower_bound - 1;
87 
88     int pixel_count = 0;
89 
90     for (int i = 0; i < lower_bound; ++i) {
91         pixel_count = 0;
92         for (int j = 0; j < lower_bound; ++j)
93             derived_ptr->within_body(-i, upper_bound, -j, upper_bound, config.W,
94                     pk, 1, pixel_count++ * this->single_pixel_offset_);
95         derived_ptr->move_data_pointers(pixel_count, pk);
96 
97         within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, -i,
98                 upper_bound, -lower_bound, upper_bound, config.W, pk);
99 
100         pixel_count = 0;
101         for (int j = config.W - upper_bound; j < config.W; ++j)
102             derived_ptr->within_body(-i, upper_bound, -lower_bound,
103                     config.W - 1 - j, config.W, pk, 1,
104                     pixel_count++ * this->single_pixel_offset_);
105         derived_ptr->move_data_pointers(pixel_count, pk);
106     }
107 
108     this->mov(h_, config.H - config.size + 1);
109     Label lrn_loop_h;
110     this->L(lrn_loop_h);
111     pixel_count = 0;
112     for (int j = 0; j < lower_bound; ++j)
113         derived_ptr->within_body(-lower_bound, upper_bound, -j, upper_bound,
114                 config.W, pk, 1, pixel_count++ * this->single_pixel_offset_);
115     derived_ptr->move_data_pointers(pixel_count, pk);
116 
117     within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
118             -lower_bound, upper_bound, -lower_bound, upper_bound, config.W, pk);
119 
120     pixel_count = 0;
121     for (int j = config.W - upper_bound; j < config.W; ++j)
122         derived_ptr->within_body(-lower_bound, upper_bound, -lower_bound,
123                 config.W - 1 - j, config.W, pk, 1,
124                 pixel_count++ * this->single_pixel_offset_);
125     derived_ptr->move_data_pointers(pixel_count, pk);
126 
127     this->dec(h_);
128     this->cmp(h_, 0);
129     this->jne(lrn_loop_h, this->T_NEAR);
130 
131     for (int i = config.H - upper_bound; i < config.H; ++i) {
132         pixel_count = 0;
133         for (int j = 0; j < lower_bound; ++j)
134             derived_ptr->within_body(-lower_bound, config.H - 1 - i, -j,
135                     upper_bound, config.W, pk, 1,
136                     pixel_count++ * this->single_pixel_offset_);
137         derived_ptr->move_data_pointers(pixel_count, pk);
138 
139         within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
140                 -lower_bound, config.H - 1 - i, -lower_bound, upper_bound,
141                 config.W, pk);
142 
143         pixel_count = 0;
144         for (int j = config.W - upper_bound; j < config.W; ++j)
145             derived_ptr->within_body(-lower_bound, config.H - 1 - i,
146                     -lower_bound, config.W - 1 - j, config.W, pk, 1,
147                     pixel_count++ * this->single_pixel_offset_);
148         derived_ptr->move_data_pointers(pixel_count, pk);
149     }
150 }
151 
152 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
153         cpu_isa_t isa, data_type_t d_type>
within_body_reg_blocked(int loop_count,int max_reg_blocks,int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk)154 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_body_reg_blocked(
155         int loop_count, int max_reg_blocks, int hoff, int Hoff, int woff,
156         int Woff, int stride, prop_kind_t pk) {
157 
158     const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
159     Label reg_block_compute_loop;
160 
161     const auto res = std::div(loop_count, max_reg_blocks);
162     if (res.quot) {
163         this->mov(this->w_, res.quot);
164         this->L(reg_block_compute_loop);
165         derived_ptr->within_body(
166                 hoff, Hoff, woff, Woff, stride, pk, max_reg_blocks, 0);
167         derived_ptr->move_data_pointers(max_reg_blocks, pk);
168         this->dec(this->w_);
169         this->cmp(this->w_, 0);
170         this->jne(reg_block_compute_loop, this->T_NEAR);
171     }
172     if (res.rem) {
173         derived_ptr->within_body(
174                 hoff, Hoff, woff, Woff, stride, pk, res.rem, 0);
175         derived_ptr->move_data_pointers(res.rem, pk);
176     }
177 }
178 
// Default (f32) load: unaligned move of a full vector from memory into reg.
// The bf16 kernels override this with explicit specializations.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_data(
        const Vmm &reg, const Xbyak::Address &p) {
    this->uni_vmovups(reg, p);
}
185 
// Emits a bf16 -> f32 widening load.
// Each 16-bit element at `src` is zero-extended to 32 bits (vpmovzxwd) and
// then shifted left by 16 (vpslld), which places the bf16 bits in the upper
// half of an f32 lane.
// Gen is a pointer-like handle to the code generator emitting the code.
template <typename Gen, typename Reg, typename Addr>
void load_bf16_data(Gen generator, const Reg &dst, const Addr &src) {
    const int bf16_to_f32_shift = 16;
    generator->vpmovzxwd(dst, src);
    generator->vpslld(dst, dst, bf16_to_f32_shift);
}
191 
// Forward kernel, avx512_common + bf16: loads bf16 data and widens it to f32
// via load_bf16_data().
template <>
void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>>::load_data(const Vmm &reg,
        const Xbyak::Address &p) {
    load_bf16_data(this, reg, p);
}
198 
// Backward kernel, avx512_common + bf16: loads bf16 data and widens it to f32
// via load_bf16_data().
template <>
void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>>::load_data(const Vmm &reg,
        const Xbyak::Address &p) {
    load_bf16_data(this, reg, p);
}
205 
// Default (f32) store: unaligned move of a full vector from reg to memory.
// The bf16 kernels override this with explicit specializations.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::store_data(
        const Xbyak::Address &addr, const Vmm &reg) {
    this->uni_vmovups(addr, reg);
}
212 
213 template <typename Gen, typename Bf16Emu>
store_bf16_data(Gen generator,Bf16Emu emu,const Xbyak::Address & addr,const Zmm & zr)214 void store_bf16_data(
215         Gen generator, Bf16Emu emu, const Xbyak::Address &addr, const Zmm &zr) {
216     const Ymm yr = Ymm(zr.getIdx());
217     if (mayiuse(avx512_core_bf16))
218         generator->vcvtneps2bf16(yr, zr);
219     else
220         emu->vcvtneps2bf16(yr, zr);
221     generator->vmovdqu16(addr, yr);
222 }
223 
// Forward kernel, avx512_common + bf16: converts f32 to bf16 and stores it
// via store_bf16_data(). The source register is clobbered by the conversion.
template <>
void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
        const Zmm &zr) {
    store_bf16_data(this, bf16_emu_.get(), addr, zr);
}
230 
// Backward kernel, avx512_common + bf16: converts f32 to bf16 and stores it
// via store_bf16_data(). The source register is clobbered by the conversion.
template <>
void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
        const Zmm &zr) {
    store_bf16_data(this, bf16_emu_.get(), addr, zr);
}
237 
238 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
239         cpu_isa_t isa, data_type_t d_type>
load_constant(float constant,const Vmm & v_constant,const Xbyak::Xmm & x_constant)240 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_constant(
241         float constant, const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
242     this->mov(this->imm_addr64_, float2int(constant));
243     this->uni_vmovq(x_constant, this->imm_addr64_);
244     this->vbroadcastss(v_constant, x_constant);
245 }
246 
247 template <>
248 void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<sse41,
load_constant(float constant,const Vmm & v_constant,const Xbyak::Xmm & x_constant)249         dnnl::impl::data_type::f32>>::load_constant(float constant,
250         const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
251     this->mov(this->imm_addr64_, float2int(constant));
252     this->uni_vmovq(x_constant, this->imm_addr64_);
253     this->shufps(x_constant, x_constant, 0);
254 }
255 
256 //////////////////////////////////////////////////////////////////////////////
257 // forward kernel
258 template <cpu_isa_t isa, data_type_t d_type>
within_body(int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk,const int reg_block,int pixel_offset)259 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
260         int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
261         int pixel_offset) {
262 
263     static const std::array<Vmm, 3> vsum {{Vmm(2), Vmm(11), Vmm(20)}};
264     static const std::array<Vmm, 3> vsum2 {{Vmm(3), Vmm(12), Vmm(21)}};
265     static const std::array<Vmm, 3> vdst {{Vmm(4), Vmm(13), Vmm(22)}};
266     static const std::array<std::array<Vmm, 6u>, 3u> vtmp {
267             {{{Vmm(5), Vmm(6), Vmm(7), Vmm(8), Vmm(9), Vmm(14)}},
268                     {{Vmm(18), Vmm(15), Vmm(16), Vmm(17), Vmm(29), Vmm(30)}},
269                     {{Vmm(23), Vmm(24), Vmm(25), Vmm(26), Vmm(28), Vmm(31)}}}};
270     static const std::array<Vmm, 3> vscratch = {{Vmm(10), Vmm(19), Vmm(27)}};
271     static const std::size_t used_tmp_regs
272             = this->emulate_bfloat_ ? vtmp[0].size() - 2 : vtmp[0].size();
273 
274     IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
275     for (int i = hoff; i <= Hoff; ++i) {
276         for (int j = woff; j <= Woff; ++j) {
277             if (i == 0 && j == 0) {
278                 IRB_LOOP(this->load_data(
279                         vdst[irb], this->ptr[src_ + pixel_offset + irb_off]));
280                 IRB_LOOP(this->vfmadd231ps(vsum[irb], vdst[irb], vdst[irb]));
281             } else {
282                 const auto idx = this->tempIdx_ % used_tmp_regs;
283                 IRB_LOOP(this->load_data(vtmp[irb][idx],
284                         this->ptr[(src_ + pixel_offset + irb_off)
285                                 + (i * stride + j)
286                                         * this->single_pixel_offset_]));
287                 IRB_LOOP(this->vfmadd231ps(
288                         vsum[irb], vtmp[irb][idx], vtmp[irb][idx]));
289                 ++(this->tempIdx_);
290             }
291         }
292     }
293 
294     this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
295 
296     IRB_LOOP(this->vfmadd132ps(
297             vsum[irb], vk_, valpha_)); // ysum <- ysum*valpha_+yk_
298     IRB_LOOP(this->vmovaps(vscratch[irb], vsum[irb]));
299 
300     IRB_LOOP(this->vmulps(vsum2[irb], vsum[irb], vsum[irb]));
301     IRB_LOOP(this->vmulps(
302             vsum[irb], vsum[irb], vsum2[irb])); // ysum = (ysum*valpha_+yk_)^3;
303     IRB_LOOP(this->vsqrtps(vsum[irb], vsum[irb]));
304     IRB_LOOP(this->vsqrtps(
305             vsum[irb], vsum[irb])); // ysum = (ysum*valpha_+yk_)^0.75
306     IRB_LOOP(this->vdivps(
307             vdst[irb], vdst[irb], vsum[irb])); // ydst <- ydst / ysum
308 
309     if (pk_ != prop_kind::forward_inference) {
310         IRB_LOOP(this->store_data(
311                 this->ptr[scratch_ + pixel_offset + irb_off], vsum[irb]));
312         IRB_LOOP(this->vdivps(vscratch[irb], vdst[irb], vscratch[irb]));
313         IRB_LOOP(this->store_data(
314                 this->ptr[bwd_intermediate_res_ + pixel_offset + irb_off],
315                 vscratch[irb]));
316     }
317 
318     IRB_LOOP(this->store_data(
319             this->ptr[dst_ + pixel_offset + irb_off], vdst[irb]));
320 
321     if (isa == avx512_common)
322         this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
323 }
324 
// sse41/f32 specialization of the forward "within channel" body.
// An xmm holds 4 floats, so each pixel is handled as a lo/hi pair of
// registers, with the hi half at +4 * sizeof(float). Exactly one pixel at
// (pointer + pixel_offset) is processed; reg_block and pk are not read.
//   sum  = sum of squares over rows [hoff, Hoff] x cols [woff, Woff]
//   base = sum * alpha + k
//   dst  = src / sqrt(sqrt(base^3))
// NOTE(review): unlike the generic within_body, this stores `base` (not
// base^0.75) to scratch_, and writes nothing to bwd_intermediate_res_. The
// generic move_data_pointers() still advances bwd_intermediate_res_ in
// training. Confirm this matches what consumers of the workspace expect.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::within_body(
        int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk,
        int reg_block, int pixel_offset) {

    const Xbyak::Xmm &xtmp_lo = this->xmm2;
    const Xbyak::Xmm &xtmp_hi = this->xmm3;
    const Xbyak::Xmm &xsum_lo = this->xmm4;
    const Xbyak::Xmm &xsum_hi = this->xmm5;
    const Xbyak::Xmm &xdst_lo = this->xmm6;
    const Xbyak::Xmm &xdst_hi = this->xmm7;
    const Xbyak::Xmm &xsum2_lo = this->xmm8;
    const Xbyak::Xmm &xsum2_hi = this->xmm9;

    // Accumulate squares over the clipped window.
    xorps(xsum_lo, xsum_lo);
    xorps(xsum_hi, xsum_hi);
    for (int i = hoff; i <= Hoff; ++i) {
        for (int j = woff; j <= Woff; ++j) {
            if (i == 0 && j == 0) {
                // Center pixel: squared in place, so src is reloaded below
                // before the final division.
                movups(xdst_lo, ptr[src_ + pixel_offset]);
                movups(xdst_hi, ptr[src_ + pixel_offset + 4 * sizeof(float)]);
                mulps(xdst_lo, xdst_lo);
                mulps(xdst_hi, xdst_hi);
                addps(xsum_lo, xdst_lo);
                addps(xsum_hi, xdst_hi);
            } else {
                movups(xtmp_lo,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_]);
                movups(xtmp_hi,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_
                                + 4 * sizeof(float)]);
                this->mulps(xtmp_lo, xtmp_lo);
                this->mulps(xtmp_hi, xtmp_hi);
                this->addps(xsum_lo, xtmp_lo);
                this->addps(xsum_hi, xtmp_hi);
            }
        }
    }
    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_
    // Training: save base (k + alpha * sum) to scratch.
    this->movaps(xtmp_lo, xsum_lo);
    this->movaps(xtmp_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_ + pixel_offset], xtmp_lo);
        this->movups(this->ptr[scratch_ + pixel_offset + 4 * sizeof(float)],
                xtmp_hi);
    }
    this->movaps(xsum2_lo, xsum_lo);
    this->movaps(xsum2_hi, xsum_hi);
    this->mulps(xsum2_lo, xsum_lo);
    this->mulps(xsum2_hi, xsum_hi);
    this->mulps(xsum_lo, xsum2_lo);
    this->mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha_+xk_)^3;

    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha_+xk_)^0.75

    // Reload the (unsquared) center pixel as the numerator.
    this->movups(xdst_lo, this->ptr[src_ + pixel_offset]);
    this->movups(xdst_hi, this->ptr[src_ + pixel_offset + 4 * sizeof(float)]);
    this->divps(xdst_lo, xsum_lo);
    this->divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum

    this->movups(this->ptr[dst_ + pixel_offset], xdst_lo);
    this->movups(this->ptr[dst_ + pixel_offset + 4 * sizeof(float)], xdst_hi);
}
396 
397 template <cpu_isa_t isa, data_type_t d_type>
move_data_pointers(int pixel_count,prop_kind_t pk)398 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::move_data_pointers(
399         int pixel_count, prop_kind_t pk) {
400 
401     const int pixel_offset = this->single_pixel_offset_ * pixel_count;
402     this->add(src_, pixel_offset);
403     this->add(dst_, pixel_offset);
404     if (pk_ != prop_kind::forward_inference) {
405         this->add(scratch_, pixel_offset);
406         this->add(bwd_intermediate_res_, pixel_offset);
407     }
408 }
409 
// Forward kernel for the "within channel" configuration.
// Stores the spatial config, alpha (A), k (K) and the propagation kind.
// Code is emitted later by generate(const within_config_t &).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const within_config_t &config, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(config, code_ptr, code_size)
    , config_(lrn_config_t::within_config)
    , within_config_(config)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
420 
421 template <cpu_isa_t isa, data_type_t d_type>
generate(const within_config_t & config)422 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(
423         const within_config_t &config) {
424     this->preamble();
425 
426 #define GET_OFF(field) offsetof(jit_args_fwd_t, field)
427     this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
428     this->mov(dst_, this->ptr[this->param1 + GET_OFF(dst)]);
429     if (pk_ != prop_kind::forward_inference) {
430         this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
431         this->mov(bwd_intermediate_res_,
432                 this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
433     }
434 #undef GET_OFF
435 
436     this->load_constant(alpha_, valpha_, xalpha_);
437     this->load_constant(k_, vk_, xk_);
438 
439     static const int max_reg_blocks = isa == avx512_common ? 3 : 1;
440     this->within_loop(config, max_reg_blocks, pk_);
441 
442     this->postamble();
443 }
444 
// Forward kernel for across-channel LRN on the nchw8c layout.
// Stores the config, alpha (A), k (K) and the propagation kind.
// Code is emitted later by generate(const nchw8c_across_t &).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nchw8c_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
455 
// Emits the across-channel forward kernel for nchw8c (8-channel blocks,
// 32 bytes per pixel per block) using ymm registers.
// For every one of the H*W pixels of the current channel block, it builds a
// 16-float window on a 64-byte stack area at t (= rsp):
//   t+0  : upper 4 channels of the previous block (zeros if J.version == -1)
//   t+16 : the 8 channels of the current block
//   t+48 : lower 4 channels of the next block     (zeros if J.version == +1)
// Unaligned loads at t+8, t+12, t+20 and t+24 then give the channel
// neighbours c-2, c-1, c+1 and c+2 of all 8 channels at once.
//   base = k + alpha * (sum of the 5 squares);  dst = src / base^0.75
// When pk_ is not forward_inference, base is stored to scratch_.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {
    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_prev = this->xmm2;
    // ysrc and yc intentionally alias the same register (the center value).
    const Xbyak::Ymm &ysrc = this->ymm3;
    const Xbyak::Ymm &yc = this->ymm3;
    const Xbyak::Xmm &xsrc_next = this->xmm4;
    const Xbyak::Ymm &ya = this->ymm5;
    const Xbyak::Ymm &yb = this->ymm6;
    const Xbyak::Ymm &yd = this->ymm7;
    const Xbyak::Ymm &ye = this->ymm8;
    const Xbyak::Ymm &ysum = this->ymm9;
    const Xbyak::Ymm &ysum2 = this->ymm10;
    const Xbyak::Ymm &ydst = this->ymm11;
    const Xbyak::Ymm &ybase = this->ymm12;

    this->preamble();

    // Arguments: src at +0, dst at +8, scratch at +16 of param1.
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    // Reserve the 64-byte window on the stack.
    this->sub(t, 64);
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(xalpha_, this->imm_addr64_);
    this->vbroadcastss(valpha_, xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(xk_, this->imm_addr64_);
    this->vbroadcastss(yk_, xk_);

    // First / last channel block: the missing neighbour block is zero-padded
    // once, outside the loop.
    if (J.version == -1) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);

    Label lrn_loop;
    this->L(lrn_loop);

    // Neighbouring channel blocks are J.H * J.W * 32 bytes apart.
    if (J.version != -1)
        this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->vmovups(ysrc, this->ptr[src_]);
    if (J.version != +1)
        this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->vmovups(this->ptr[t + 0], xsrc_prev);
    this->vmovups(this->ptr[t + 16], ysrc);
    if (J.version != +1) this->vmovups(this->ptr[t + 48], xsrc_next);

    // Shifted views of the window: channels c-2, c-1, c+1, c+2.
    this->vmovups(ya, this->ptr[t + 16 - 8]);
    this->vmovups(yb, this->ptr[t + 16 - 4]);
    this->vmovups(yd, this->ptr[t + 16 + 4]);
    this->vmovups(ye, this->ptr[t + 16 + 8]);
    this->vmulps(ysum, yc, yc);
    this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
    this->vfmadd231ps(ysum, yb, yb);
    this->vfmadd231ps(ysum, yd, yd);
    this->vfmadd231ps(ysum, ye, ye);
    this->vfmadd132ps(ysum, yk_, valpha_); // ysum <- ysum*valpha_+yk_

    this->vmovaps(ybase, ysum);
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase);
    this->vmulps(ysum2, ysum, ysum);
    this->vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
    this->vsqrtps(ysum, ysum);
    this->vsqrtps(ysum, ysum); // ysum = ybase^0.75
    this->vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
    this->vmovups(this->ptr[dst_], ydst);

    // Next pixel: 8 floats = 32 bytes.
    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64);
    this->postamble();
}
543 
// sse41/f32 explicit specialization of the nchw8c across-channel
// constructor. Its body is identical to the generic template constructor for
// nchw8c_across_t. NOTE(review): presumably required because the
// specialization is declared elsewhere -- confirm before removing it.
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const struct nchw8c_across_t &J, float A,
                float K, prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
554 
// sse41/f32 variant of the nchw8c across-channel forward kernel.
// It uses the same 64-byte stack window as the ymm version:
//   t+0 : prev block's upper 4 channels, t+16 : current 8 channels,
//   t+48 : next block's lower 4 channels
// Each 8-float value is split into lo/hi xmm halves. The algorithm is
//   base = k + alpha * (sum of 5 channel squares);  dst = src / base^0.75
// When pk_ is not forward_inference, base is stored to scratch_.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
        const nchw8c_across_t &J) {

    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_lo = this->xmm2;
    const Xbyak::Xmm &xsrc_hi = this->xmm3;
    const Xbyak::Xmm &xc_lo = this->xmm4;
    const Xbyak::Xmm &xc_hi = this->xmm5;
    // xsum aliases xc: the sum is accumulated on top of the squared center.
    const Xbyak::Xmm &xsum_lo = xc_lo;
    const Xbyak::Xmm &xsum_hi = xc_hi;
    const Xbyak::Xmm &xsrc_prev = this->xmm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xa_lo = this->xmm8;
    const Xbyak::Xmm &xa_hi = this->xmm9;
    const Xbyak::Xmm &xb_lo = this->xmm10;
    const Xbyak::Xmm &xb_hi = this->xmm11;
    const Xbyak::Xmm &xd_lo = this->xmm12;
    const Xbyak::Xmm &xd_hi = this->xmm13;
    // xbase reuses xe's registers; xe is dead by the time xbase is written.
    const Xbyak::Xmm &xe_lo = this->xmm14;
    const Xbyak::Xmm &xe_hi = this->xmm15;
    const Xbyak::Xmm &xbase_lo = this->xmm14;
    const Xbyak::Xmm &xbase_hi = this->xmm15;

    this->preamble();

    // Arguments: src at +0, dst at +8, scratch at +16 of param1.
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    this->sub(t, 64);
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->movq(xalpha_, this->imm_addr64_);
    this->shufps(xalpha_, xalpha_, 0);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->movq(xk_, this->imm_addr64_);
    this->shufps(xk_, xk_, 0);

    // Zero-pad the missing neighbour block once for the first/last block.
    if (J.version == -1) {
        this->xorps(xsrc_prev, xsrc_prev);
        this->movups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->xorps(xsrc_next, xsrc_next);
        this->movups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);
    Label lrn_loop;
    L(lrn_loop);

    if (J.version != -1)
        this->movups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->movups(xsrc_lo, this->ptr[src_]);
    this->movups(xsrc_hi, this->ptr[src_ + 4 * sizeof(float)]);
    if (J.version != +1)
        this->movups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->movups(this->ptr[t + 0], xsrc_prev);
    this->movups(this->ptr[t + 16], xsrc_lo);
    this->movups(this->ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
    if (J.version != +1) this->movups(this->ptr[t + 48], xsrc_next);

    // Shifted views of the window: channels c-2, c-1, c+1, c+2.
    this->movups(xa_lo, this->ptr[t + 16 - 8]);
    this->movups(xa_hi, this->ptr[t + 16 - 8 + 4 * sizeof(float)]);
    this->movups(xb_lo, this->ptr[t + 16 - 4]);
    this->movups(xb_hi, this->ptr[t + 16 - 4 + 4 * sizeof(float)]);
    this->movups(xd_lo, this->ptr[t + 16 + 4]);
    this->movups(xd_hi, this->ptr[t + 16 + 4 + 4 * sizeof(float)]);
    this->movups(xe_lo, this->ptr[t + 16 + 8]);
    this->movups(xe_hi, this->ptr[t + 16 + 8 + 4 * sizeof(float)]);
    this->movaps(xc_lo, xsrc_lo);
    this->movaps(xc_hi, xsrc_hi);
    this->mulps(xsum_lo, xc_lo);
    this->mulps(xsum_hi, xc_hi);
    this->mulps(xa_lo, xa_lo);
    this->mulps(xa_hi, xa_hi);
    this->addps(xsum_lo, xa_lo);
    this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
    this->mulps(xb_lo, xb_lo);
    this->mulps(xb_hi, xb_hi);
    this->addps(xsum_lo, xb_lo);
    this->addps(xsum_hi, xb_hi);
    this->mulps(xd_lo, xd_lo);
    this->mulps(xd_hi, xd_hi);
    this->addps(xsum_lo, xd_lo);
    this->addps(xsum_hi, xd_hi);
    this->mulps(xe_lo, xe_lo);
    this->mulps(xe_hi, xe_hi);
    this->addps(xsum_lo, xe_lo);
    this->addps(xsum_hi, xe_hi);

    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_

    this->movaps(xbase_lo, xsum_lo);
    this->movaps(xbase_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_], xbase_lo);
        this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
    }
    this->mulps(xsum_lo, xsum_lo);
    this->mulps(xsum_hi, xsum_hi);
    this->mulps(xsum_lo, xbase_lo);
    this->mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
    this->divps(xsrc_lo, xsum_lo);
    this->divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
    this->movups(this->ptr[dst_], xsrc_lo);
    this->movups(this->ptr[dst_ + 4 * sizeof(float)], xsrc_hi);

    // Next pixel: 8 floats = 32 bytes.
    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64);
    this->postamble();
}
683 
// Forward kernel for across-channel LRN on the nhwc layout.
// Stores the config, alpha (A), k (K) and the propagation kind.
// Code is emitted later by generate(const nhwc_across_t &).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nhwc_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nhwc_across)
    , nhwc_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
694 
// Emits the forward across-channel LRN kernel for the nhwc layout using ymm
// registers and FMA. Each loop iteration produces 8 consecutive channels of
// one pixel:
//     dst = src / (k_ + alpha_ * sum)^0.75
// where, for channel c, sum is the sum of squares over channels [c-2, c+2]
// (the five loads at byte offsets -8, -4, 0, +4, +8 around src_). When pk_ is
// not forward_inference the base (k_ + alpha_ * sum) is also written to
// scratch_. Kernel arguments are read from param1: +0 src, +8 dst,
// +16 scratch.
// NOTE(review): the loop counter starts at J.C / 8 - 1 and is only tested
// after 'dec', so this presumably requires J.C to be a multiple of 8 and at
// least 16 (with J.C == 8 the counter would wrap) -- confirm against the
// primitive descriptor checks.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nhwc_across_t &J) {
    // vmaskmovps loads only lanes whose mask sign bit is set and zeroes the
    // rest. The 8-dword windows starting at mask[0] / mask[1] drop the lanes
    // that would fall before channel 0; those starting at mask[2] / mask[3]
    // drop the lanes that would fall past the last channel.
    static const uint32_t mask[] = {0, 0, 0x80000000, 0x80000000, 0x80000000,
            0x80000000, 0x80000000, 0x80000000, 0x80000000, 0, 0};

    const Xbyak::Reg64 &c = this->r9;
    const Xbyak::Ymm &ya = this->ymm2;
    const Xbyak::Ymm &yb = this->ymm3;
    const Xbyak::Ymm &yc = this->ymm4;
    const Xbyak::Ymm &yd = this->ymm5;
    const Xbyak::Ymm &ye = this->ymm6;
    const Xbyak::Ymm &ysum = this->ymm7;
    const Xbyak::Ymm &ydst = this->ymm8;
    const Xbyak::Ymm &ybase = this->ymm9;
    const Xbyak::Ymm &ymask = this->ymm10;

    this->preamble();

    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    // Broadcast alpha_ and k_ to all 8 lanes.
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->movq(xalpha_, this->imm_addr64_);
    this->vbroadcastss(valpha_, xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->movq(xk_, this->imm_addr64_);
    this->vbroadcastss(yk_, xk_);

    this->vxorps(ysum, ysum, ysum);

    // Prologue: channels c-2 and c-1 for the first block; the two lanes that
    // would precede channel 0 are masked off (loaded as zero).
    this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[0]));
    this->vmovups(ymask, this->ptr[this->imm_addr64_]);
    this->vmaskmovps(ya, ymask, this->ptr[src_ - 8]);
    this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2

    this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[1]));
    this->vmovups(ymask, this->ptr[this->imm_addr64_]);
    this->vmaskmovps(yb, ymask, this->ptr[src_ - 4]);
    this->vfmadd231ps(ysum, yb, yb);

    // Main loop: every 8-channel block except the last one.
    this->mov(c, J.C / 8 - 1);
    Label lrn_loop;
    this->L(lrn_loop);

    this->vmovups(yc, this->ptr[src_]);
    this->vmovups(yd, this->ptr[src_ + 4]);
    this->vmovups(ye, this->ptr[src_ + 8]);
    this->vfmadd231ps(ysum, yc, yc);
    this->vfmadd231ps(ysum, yd, yd);
    this->vfmadd231ps(ysum, ye, ye);

    this->vmovups(ydst, ysum);
    this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_

    this->vmovaps(ybase, ydst);
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase);
    this->vmulps(ydst, ydst, ydst);
    this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
    this->vsqrtps(ydst, ydst);
    this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75

    this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
    this->vmovups(this->ptr[dst_], ydst);

    this->vxorps(ysum, ysum, ysum);

    // Advance one 8-channel block (32 bytes) on every stream.
    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);

    // Channels c-2 and c-1 of the next block are in bounds: plain loads.
    this->vmovups(ya, this->ptr[src_ - 8]);
    this->vfmadd231ps(ysum, ya, ya);
    this->vmovups(yb, this->ptr[src_ - 4]);
    this->vfmadd231ps(ysum, yb, yb);

    this->dec(c);
    this->cmp(c, 0);
    this->jne(lrn_loop, this->T_NEAR);

    // Epilogue: last block. Channels c+1 / c+2 that would run past the end
    // are masked off.
    this->vmovups(yc, this->ptr[src_]);
    this->vfmadd231ps(ysum, yc, yc);

    this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[2]));
    this->vmovups(ymask, this->ptr[this->imm_addr64_]);
    this->vmaskmovps(yd, ymask, this->ptr[src_ + 4]);
    this->vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2

    this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[3]));
    this->vmovups(ymask, this->ptr[this->imm_addr64_]);
    this->vmaskmovps(ye, ymask, this->ptr[src_ + 8]);
    this->vfmadd231ps(ysum, ye, ye);

    this->vmovups(ydst, ysum);
    this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_

    this->vmovaps(ybase, ydst);
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase);
    this->vmulps(ydst, ydst, ydst);
    this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
    this->vsqrtps(ydst, ydst);
    this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
    this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75

    this->vmovups(this->ptr[dst_], ydst);

    this->postamble();
}
806 
807 template <>
808 jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
jit_uni_lrn_fwd_kernel_t(const struct nhwc_across_t & J,float A,float K,prop_kind_t pk,void * code_ptr,size_t code_size)809         jit_uni_lrn_fwd_kernel_t(const struct nhwc_across_t &J, float A,
810                 float K, prop_kind_t pk, void *code_ptr, size_t code_size)
811     : Base(code_ptr, code_size)
812     , config_(lrn_config_t::nhwc_across)
813     , nhwc_across_(J)
814     , alpha_(A)
815     , k_(K)
816     , pk_(pk) {}
817 
818 template <>
generate(const nhwc_across_t & J)819 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
820         const nhwc_across_t &J) {
821     static uint32_t store[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
822     const Xbyak::Reg64 c = this->r9;
823 
824     const Xbyak::Xmm &xdst_lo = this->xmm0;
825     const Xbyak::Xmm &xdst_hi = this->xmm1;
826     const Xbyak::Xmm &xa_lo = this->xmm2;
827     const Xbyak::Xmm &xa_hi = this->xmm3;
828     const Xbyak::Xmm &xb_lo = this->xmm2;
829     const Xbyak::Xmm &xb_hi = this->xmm3;
830     const Xbyak::Xmm &xc_lo = this->xmm4;
831     const Xbyak::Xmm &xc_hi = this->xmm5;
832     const Xbyak::Xmm &xd_lo = this->xmm6;
833     const Xbyak::Xmm &xd_hi = this->xmm7;
834     const Xbyak::Xmm &xe_lo = this->xmm8;
835     const Xbyak::Xmm &xe_hi = this->xmm9;
836     const Xbyak::Xmm &xsum_lo = this->xmm10;
837     const Xbyak::Xmm &xsum_hi = this->xmm11;
838     // unused: xmm12, xmm13;
839     const Xbyak::Xmm &xbase_lo = this->xmm14;
840     const Xbyak::Xmm &xbase_hi = this->xmm15;
841 
842     this->preamble();
843 
844     this->mov(src_, this->ptr[this->param1 + 0]);
845     this->mov(dst_, this->ptr[this->param1 + 8]);
846     if (pk_ != prop_kind::forward_inference)
847         mov(scratch_, this->ptr[this->param1 + 16]);
848     this->mov(this->imm_addr64_, float2int(this->alpha_));
849     this->movq(xalpha_, this->imm_addr64_);
850     this->shufps(xalpha_, xalpha_, 0);
851 
852     this->mov(this->imm_addr64_, float2int(this->k_));
853     this->movq(xk_, this->imm_addr64_);
854     this->shufps(xk_, xk_, 0);
855 
856     this->mov(store_addr_, reinterpret_cast<size_t>(&store[0]));
857     this->and_(store_addr_, -15);
858     this->movups(this->ptr[store_addr_], xalpha_);
859     this->movups(this->ptr[store_addr_ + 4 * sizeof(float)], xk_);
860 
861     this->xorps(xsum_lo, xsum_lo);
862     this->xorps(xsum_hi, xsum_hi);
863 
864     /* load the 2 first blocks of channels
865      * block:         | -- low -- | -- hi --  |
866      * C:             [c1,c2,c3,c4,c5,c6,c7,c8]
867      * xa_lo << 2 [0,0,c1,c2]
868      * xa_hi                [c3,c4,c5,c6]
869      * xb_lo << 1   [0,c1,c2,c3]
870      * xb_hi                   [c4,c5,c6,c7]
871      *                | --  data  --     (...)
872      *                ^ memory boundary
873      */
874     this->movups(xa_lo, this->ptr[src_]);
875     this->movups(xa_hi, this->ptr[src_ + 2 * sizeof(float)]);
876     this->pslldq(xa_lo, 2 * sizeof(float));
877     this->mulps(xa_lo, xa_lo);
878     this->mulps(xa_hi, xa_hi);
879     this->addps(xsum_lo, xa_lo);
880     this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
881 
882     this->movups(xb_lo, this->ptr[src_]);
883     this->movups(xb_hi, this->ptr[src_ + 3 * sizeof(float)]);
884     this->pslldq(xb_lo, 1 * sizeof(float));
885     this->mulps(xb_lo, xb_lo);
886     this->mulps(xb_hi, xb_hi);
887     this->addps(xsum_lo, xb_lo);
888     this->addps(xsum_hi, xb_hi);
889 
890     this->mov(c, J.C / 8 - 1);
891     Label lrn_loop;
892     this->L(lrn_loop);
893 
894     this->movups(xc_lo, this->ptr[src_]);
895     this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
896     this->movups(xd_lo, this->ptr[src_ + 4]);
897     this->movups(xd_hi, this->ptr[src_ + 4 + 4 * sizeof(float)]);
898     this->movups(xe_lo, this->ptr[src_ + 8]);
899     this->movups(xe_hi, this->ptr[src_ + 8 + 4 * sizeof(float)]);
900     this->mulps(xc_lo, xc_lo);
901     this->mulps(xc_hi, xc_hi);
902     this->addps(xsum_lo, xc_lo);
903     this->addps(xsum_hi, xc_hi);
904     this->mulps(xd_lo, xd_lo);
905     this->mulps(xd_hi, xd_hi);
906     this->addps(xsum_lo, xd_lo);
907     this->addps(xsum_hi, xd_hi);
908     this->mulps(xe_lo, xe_lo);
909     this->mulps(xe_hi, xe_hi);
910     this->addps(xsum_lo, xe_lo);
911     this->addps(xsum_hi, xe_hi);
912 
913     this->movaps(xdst_lo, xsum_lo);
914     this->movaps(xdst_hi, xsum_hi);
915     // xdst <- xsum*xalpha_+xk_
916     this->mulps(xdst_lo, this->ptr[store_addr_]);
917     this->mulps(xdst_hi, this->ptr[store_addr_]);
918     this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
919     this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
920 
921     this->movaps(xbase_lo, xdst_lo);
922     this->movaps(xbase_hi, xdst_hi);
923     if (pk_ != prop_kind::forward_inference) {
924         this->movups(this->ptr[scratch_], xbase_lo);
925         this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
926     }
927     this->mulps(xdst_lo, xdst_lo);
928     this->mulps(xdst_hi, xdst_hi);
929     this->mulps(xdst_lo, xbase_lo);
930     this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
931     this->sqrtps(xdst_lo, xdst_lo);
932     this->sqrtps(xdst_hi, xdst_hi);
933     this->sqrtps(xdst_lo, xdst_lo);
934     this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
935 
936     this->movups(xc_lo, this->ptr[src_]);
937     this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
938     this->divps(xc_lo, xdst_lo);
939     this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
940     this->movups(this->ptr[dst_], xc_lo);
941     this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
942 
943     this->xorps(xsum_lo, xsum_lo);
944     this->xorps(xsum_hi, xsum_hi);
945 
946     this->add(src_, 32);
947     this->add(dst_, 32);
948     if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
949 
950     this->movups(xa_lo, this->ptr[src_ - 8]);
951     this->movups(xa_hi, this->ptr[src_ - 8 + 4 * sizeof(float)]);
952     this->mulps(xa_lo, xa_lo);
953     this->mulps(xa_hi, xa_hi);
954     this->addps(xsum_lo, xa_lo);
955     this->addps(xsum_hi, xa_hi);
956     this->movups(xb_lo, this->ptr[src_ - 4]);
957     this->movups(xb_hi, this->ptr[src_ - 4 + 4 * sizeof(float)]);
958     this->mulps(xb_lo, xb_lo);
959     this->mulps(xb_hi, xb_hi);
960     this->addps(xsum_lo, xb_lo);
961     this->addps(xsum_hi, xb_hi);
962 
963     this->dec(c);
964     this->cmp(c, 0);
965     this->jne(lrn_loop, this->T_NEAR);
966 
967     /* compute last 3 blocks of channels:
968      * block:       | -- low -- | -- hi --  |
969      * C:           [c1,c2,c3,c4,c5,c6,c7,c8]
970      * xc_lo|xc_hi  [c1,c2,c3,c4|c5,c6,c7,c8]
971      * xd_lo           [c2,c3,c4,c5]
972      * xd_hi >> 1                  [c6,c7,c8, 0]
973      * xe_lo              [c3,c4,c5,c6]
974      * xe_hi >> 2                     [c7,c8, 0, 0]
975      *                  (...) --  data  --  | -- illegal reading -- (...)
976      *                                      ^ memory boundary
977      */
978     this->movups(xc_lo, this->ptr[src_]);
979     this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
980     this->mulps(xc_lo, xc_lo);
981     this->mulps(xc_hi, xc_hi);
982     this->addps(xsum_lo, xc_lo);
983     this->addps(xsum_hi, xc_hi);
984 
985     this->movups(xd_lo, this->ptr[src_ + 1 * sizeof(float)]);
986     this->movups(xd_hi, this->ptr[src_ + 4 * sizeof(float)]);
987     this->psrldq(xd_hi, 1 * sizeof(float));
988     this->mulps(xd_lo, xd_lo);
989     this->mulps(xd_hi, xd_hi);
990     this->addps(xsum_lo, xd_lo);
991     this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
992 
993     this->movups(xe_lo, this->ptr[src_ + 2 * sizeof(float)]);
994     this->movups(xe_hi, this->ptr[src_ + 4 * sizeof(float)]);
995     this->psrldq(xe_hi, 2 * sizeof(float));
996     this->mulps(xe_lo, xe_lo);
997     this->mulps(xe_hi, xe_hi);
998     this->addps(xsum_lo, xe_lo);
999     this->addps(xsum_hi, xe_hi);
1000 
1001     this->movups(xdst_lo, xsum_lo);
1002     this->movups(xdst_hi, xsum_hi);
1003     // xdst <- xsum*xalpha_+xk_
1004     this->mulps(xdst_lo, this->ptr[store_addr_]);
1005     this->mulps(xdst_hi, this->ptr[store_addr_]);
1006     this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
1007     this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
1008 
1009     this->movaps(xbase_lo, xdst_lo);
1010     this->movaps(xbase_hi, xdst_hi);
1011     if (pk_ != prop_kind::forward_inference) {
1012         this->movups(this->ptr[scratch_], xbase_lo);
1013         this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
1014     }
1015     this->mulps(xdst_lo, xdst_lo);
1016     this->mulps(xdst_hi, xdst_hi);
1017     this->mulps(xdst_lo, xbase_lo);
1018     this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
1019     this->sqrtps(xdst_lo, xdst_lo);
1020     this->sqrtps(xdst_hi, xdst_hi);
1021     this->sqrtps(xdst_lo, xdst_lo);
1022     this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
1023     this->movups(xc_lo, this->ptr[src_]);
1024     this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
1025     this->divps(xc_lo, xdst_lo);
1026     this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
1027 
1028     this->movups(this->ptr[dst_], xc_lo);
1029     this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
1030 
1031     this->postamble();
1032 }
1033 
1034 template <>
nchw_body(int tail,int HW,prop_kind_t pk,Xbyak::Ymm ymask,Xbyak::Ymm ya,Xbyak::Ymm yb,Xbyak::Ymm yc,Xbyak::Ymm yd,Xbyak::Ymm ye,Xbyak::Ymm ysum)1035 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::nchw_body(
1036         int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
1037         Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
1038         Xbyak::Ymm ysum) {}
1039 
1040 template <cpu_isa_t isa, data_type_t d_type>
nchw_body(int tail,int HW,prop_kind_t pk,Xbyak::Ymm ymask,Xbyak::Ymm ya,Xbyak::Ymm yb,Xbyak::Ymm yc,Xbyak::Ymm yd,Xbyak::Ymm ye,Xbyak::Ymm ysum)1041 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body(int tail, int HW,
1042         prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya, Xbyak::Ymm yb,
1043         Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye, Xbyak::Ymm ysum) {
1044     const Xbyak::Ymm &ydst = this->ymm14;
1045     const Xbyak::Ymm &ybase = this->ymm15;
1046 
1047     this->vfmadd231ps(ysum, ye, ye);
1048 
1049     this->vmovups(ydst, ysum);
1050     this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
1051 
1052     this->vmovaps(ybase, ydst);
1053     if (pk_ != prop_kind::forward_inference) {
1054         if (tail != 0)
1055             this->vmaskmovps(this->ptr[scratch_], ymask, ybase);
1056         else
1057             this->vmovups(this->ptr[scratch_], ybase);
1058     }
1059     this->vmulps(ydst, ydst, ydst);
1060     this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
1061     this->vsqrtps(ydst, ydst);
1062     this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
1063     this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
1064 
1065     if (tail != 0)
1066         this->vmaskmovps(this->ptr[dst_], ymask, ydst);
1067     else
1068         this->vmovups(this->ptr[dst_], ydst);
1069 
1070     this->vfnmadd231ps(ysum, ya, ya);
1071     this->vmovups(ya, yb);
1072     this->vmovups(yb, yc);
1073     this->vmovups(yc, yd);
1074     this->vmovups(yd, ye);
1075 }
1076 
1077 template <cpu_isa_t isa, data_type_t d_type>
nchw_tail_sse41(int tail,Xbyak::Reg64 reg_dst,Xbyak::Xmm xtail_lo,Xbyak::Xmm xtail_hi)1078 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_tail_sse41(int tail,
1079         Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {}
1080 
1081 template <>
1082 void jit_uni_lrn_fwd_kernel_t<sse41,
nchw_tail_sse41(int tail,Xbyak::Reg64 reg_dst,Xbyak::Xmm xtail_lo,Xbyak::Xmm xtail_hi)1083         dnnl::impl::data_type::f32>::nchw_tail_sse41(int tail,
1084         Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {
1085     Xbyak::Xmm xmm_tmp = xmm10;
1086     this->movaps(xmm_tmp, xtail_hi);
1087 
1088     if (tail > 3) {
1089         /* Store upper-half directly */
1090         this->movups(this->ptr[reg_dst + (tail - 4) * sizeof(float)], xtail_hi);
1091         this->movaps(xmm_tmp, xtail_lo);
1092         tail -= 4;
1093     }
1094     if (tail > 0) {
1095         /* Store on a single-element basis when 'tail' overlaps
1096          * with 'src_' */
1097         this->psrldq(xmm_tmp, (4 - tail) * sizeof(float));
1098         this->movss(this->ptr[reg_dst], xmm_tmp);
1099 
1100         for (int i = 1; i < tail; i++) {
1101             this->psrldq(xmm_tmp, sizeof(float));
1102             this->movss(this->ptr[reg_dst + i * sizeof(float)], xmm_tmp);
1103         }
1104     }
1105 }
1106 
1107 template <>
1108 void jit_uni_lrn_fwd_kernel_t<sse41,
nchw_body_sse41(int tail,int HW,prop_kind_t pk,Xbyak::Xmm xe_lo,Xbyak::Xmm xe_hi,Xbyak::Xmm xsum_lo,Xbyak::Xmm xsum_hi)1109         dnnl::impl::data_type::f32>::nchw_body_sse41(int tail, int HW,
1110         prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
1111         Xbyak::Xmm xsum_hi) {
1112     const Xbyak::Xmm &xdst_lo = this->xmm0;
1113     const Xbyak::Xmm &xdst_hi = this->xmm1;
1114     const Xbyak::Xmm &xbase_lo = this->xmm6;
1115     const Xbyak::Xmm &xbase_hi = this->xmm7;
1116     const Xbyak::Xmm &xtmp_lo = this->xmm8;
1117     const Xbyak::Xmm &xtmp_hi = this->xmm9;
1118     const Xbyak::Xmm &xa_lo = this->xmm6;
1119     const Xbyak::Xmm &xa_hi = this->xmm7;
1120     const Xbyak::Xmm &xb_lo = this->xmm8;
1121     const Xbyak::Xmm &xb_hi = this->xmm9;
1122     const Xbyak::Xmm &xc_lo = this->xmm10;
1123     const Xbyak::Xmm &xc_hi = this->xmm11;
1124     const Xbyak::Xmm &xd_lo = this->xmm12;
1125     const Xbyak::Xmm &xd_hi = this->xmm13;
1126 
1127     // store xe
1128     this->movaps(this->ptr[store_addr_ + 10 * 4 * sizeof(float)], xe_lo);
1129     this->movaps(this->ptr[store_addr_ + 11 * 4 * sizeof(float)], xe_hi);
1130 
1131     this->mulps(xe_lo, xe_lo);
1132     this->mulps(xe_hi, xe_hi);
1133     this->addps(xsum_lo, xe_lo);
1134     this->addps(xsum_hi, xe_hi);
1135 
1136     // xdst <- xsum*xalpha_+xk_
1137     this->movaps(xdst_lo, xsum_lo);
1138     this->movaps(xdst_hi, xsum_hi);
1139     this->mulps(xdst_lo, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
1140     this->mulps(xdst_hi, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
1141     this->addps(xdst_lo, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);
1142     this->addps(xdst_hi, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);
1143 
1144     this->movaps(xbase_lo, xdst_lo);
1145     this->movaps(xbase_hi, xdst_hi);
1146     if (pk_ != prop_kind::forward_inference) {
1147         if (tail != 0) {
1148             nchw_tail_sse41(tail, scratch_, xbase_lo, xbase_hi);
1149         } else {
1150             this->movups(this->ptr[scratch_], xbase_lo);
1151             this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
1152         }
1153     }
1154     this->mulps(xdst_lo, xdst_lo);
1155     this->mulps(xdst_hi, xdst_hi);
1156     this->mulps(xdst_lo, xbase_lo);
1157     this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
1158     this->sqrtps(xdst_lo, xdst_lo);
1159     this->sqrtps(xdst_hi, xdst_hi);
1160     this->sqrtps(xdst_lo, xdst_lo);
1161     this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
1162     this->movaps(xtmp_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
1163     this->movaps(xtmp_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
1164     this->divps(xtmp_lo, xdst_lo);
1165     this->divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
1166     this->movaps(xdst_lo, xtmp_lo);
1167     this->movaps(xdst_hi, xtmp_hi);
1168 
1169     if (tail != 0) {
1170         nchw_tail_sse41(tail, dst_, xdst_lo, xdst_hi);
1171     } else {
1172         this->movups(this->ptr[dst_], xdst_lo);
1173         this->movups(this->ptr[dst_ + 4 * sizeof(float)], xdst_hi);
1174     }
1175 
1176     this->movaps(xa_lo, this->ptr[store_addr_ + 2 * 4 * sizeof(float)]);
1177     this->movaps(xa_hi, this->ptr[store_addr_ + 3 * 4 * sizeof(float)]);
1178     this->mulps(xa_lo, xa_lo);
1179     this->mulps(xa_hi, xa_hi);
1180     this->subps(xsum_lo, xa_lo);
1181     this->subps(xsum_hi, xa_hi);
1182 
1183     // xa <- xb
1184     this->movaps(xb_lo, this->ptr[store_addr_ + 4 * 4 * sizeof(float)]);
1185     this->movaps(xb_hi, this->ptr[store_addr_ + 5 * 4 * sizeof(float)]);
1186     this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xb_lo);
1187     this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xb_hi);
1188 
1189     // xb <- xc
1190     this->movaps(xc_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
1191     this->movaps(xc_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
1192     this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xc_lo);
1193     this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xc_hi);
1194 
1195     // xc <- xd
1196     this->movaps(xd_lo, this->ptr[store_addr_ + 8 * 4 * sizeof(float)]);
1197     this->movaps(xd_hi, this->ptr[store_addr_ + 9 * 4 * sizeof(float)]);
1198     this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xd_lo);
1199     this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xd_hi);
1200 
1201     // xd <- xe
1202     this->movaps(xe_lo, this->ptr[store_addr_ + 10 * 4 * sizeof(float)]);
1203     this->movaps(xe_hi, this->ptr[store_addr_ + 11 * 4 * sizeof(float)]);
1204     this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xe_lo);
1205     this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xe_hi);
1206 }
1207 
1208 template <cpu_isa_t isa, data_type_t d_type>
nchw_body_sse41(int tail,int HW,prop_kind_t pk,Xbyak::Xmm xe_lo,Xbyak::Xmm xe_hi,Xbyak::Xmm xsum_lo,Xbyak::Xmm xsum_hi)1209 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body_sse41(int tail, int HW,
1210         prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
1211         Xbyak::Xmm xsum_hi) {}
1212 
1213 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t(const nchw_across_t & J,float A,float K,prop_kind_t pk,void * code_ptr,size_t code_size)1214 jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
1215         const nchw_across_t &J, float A, float K, prop_kind_t pk,
1216         void *code_ptr, size_t code_size)
1217     : Base(code_ptr, code_size)
1218     , config_(lrn_config_t::nchw_across)
1219     , nchw_across_(J)
1220     , alpha_(A)
1221     , k_(K)
1222     , pk_(pk) {}
1223 
1224 template <cpu_isa_t isa, data_type_t d_type>
generate(const nchw_across_t & J)1225 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw_across_t &J) {
1226     static const uint32_t mask[]
1227             = {0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
1228                     0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0};
1229     const Xbyak::Reg64 &c = this->r10;
1230     const Xbyak::Ymm &ymask = this->ymm2;
1231     const Xbyak::Ymm &ye = this->ymm3;
1232     const Xbyak::Ymm &ya = this->ymm4;
1233     const Xbyak::Ymm &yb = this->ymm5;
1234     const Xbyak::Ymm &yc = this->ymm6;
1235     const Xbyak::Ymm &yd = this->ymm7;
1236     const Xbyak::Ymm &ysum = this->ymm8;
1237 
1238     this->preamble();
1239 
1240     if (J.tail != 0) {
1241         this->mov(
1242                 this->imm_addr64_, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1243         this->vmovups(ymask, this->ptr[this->imm_addr64_]);
1244     }
1245     this->mov(this->imm_addr64_, float2int(this->alpha_));
1246     this->vmovq(xalpha_, this->imm_addr64_);
1247     this->vbroadcastss(valpha_, xalpha_);
1248 
1249     this->mov(this->imm_addr64_, float2int(this->k_));
1250     this->vmovq(xk_, this->imm_addr64_);
1251     this->vbroadcastss(yk_, xk_);
1252 
1253     this->mov(src_, this->ptr[this->param1 + 0]);
1254     this->mov(dst_, this->ptr[this->param1 + 8]);
1255     if (pk_ != prop_kind::forward_inference)
1256         this->mov(scratch_, this->ptr[this->param1 + 16]);
1257 
1258     this->vxorps(ya, ya, ya);
1259     this->vxorps(yb, yb, yb);
1260     if (J.tail != 0)
1261         this->vmaskmovps(yc, ymask, this->ptr[src_ + J.HW * 0]);
1262     else
1263         this->vmovups(yc, this->ptr[src_ + J.HW * 0]);
1264     if (J.tail != 0)
1265         this->vmaskmovps(yd, ymask, this->ptr[src_ + J.HW * 4]);
1266     else
1267         this->vmovups(yd, this->ptr[src_ + J.HW * 4]);
1268 
1269     this->vxorps(ysum, ysum, ysum);
1270     this->vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
1271     this->vfmadd231ps(ysum, yd, yd);
1272 
1273     this->mov(c, J.C - 2);
1274     Label lrn_loop;
1275     this->L(lrn_loop);
1276 
1277     if (J.tail != 0)
1278         this->vmaskmovps(ye, ymask, this->ptr[src_ + J.HW * 8]);
1279     else
1280         this->vmovups(ye, this->ptr[src_ + J.HW * 8]);
1281 
1282     nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1283 
1284     this->add(src_, J.HW * 4);
1285     this->add(dst_, J.HW * 4);
1286     if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1287     this->dec(c);
1288     this->cmp(c, 0);
1289     this->jne(lrn_loop, this->T_NEAR);
1290 
1291     this->vxorps(ye, ye, ye);
1292 
1293     nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1294     this->add(src_, J.HW * 4);
1295     this->add(dst_, J.HW * 4);
1296     if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1297 
1298     nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1299 
1300     this->postamble();
1301 }
1302 
// Out-of-line defaulted destructor of the forward kernel (declared in the
// class definition).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::~jit_uni_lrn_fwd_kernel_t() = default;
1305 
1306 template <>
1307 jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
jit_uni_lrn_fwd_kernel_t(const nchw_across_t & J,float A,float K,prop_kind_t pk,void * code_ptr,size_t code_size)1308         jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
1309                 prop_kind_t pk, void *code_ptr, size_t code_size)
1310     : Base(code_ptr, code_size)
1311     , config_(lrn_config_t::nchw_across)
1312     , nchw_across_(J)
1313     , alpha_(A)
1314     , k_(K)
1315     , pk_(pk) {}
1316 
1317 template <>
generate(const nchw_across_t & J)1318 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
1319         const nchw_across_t &J) {
1320 
1321     /* Load from within the memory boundary of 'src_' and apply a zero-mask to
1322      * the 'x_hi' register:
1323      *  block:       src_  |tail = 3
1324      *  src_:      [x,x,x,x|a,b,c]
1325      *  x_hi:           [x,a,b,c]
1326      *  mask:           [0,1,1,1]
1327      *      (...) --  data  --  | -- illegal reading -- (...)
1328      *                          ^ memory boundary
1329      *
1330      * 'x_lo' is loaded with the elements between 'src_' and 'x_hi' when
1331      * tail.size is between [5:7]. The register is then left-shifted to
1332      * clear the overlapping elements with 'x_hi'.
1333      *  block: - src_ - |  tail = 7
1334      *  src_:  (...) [x,|a,b,c,d,e,f,g]
1335      *  x_hi                 [d,e,f,g]
1336      *  x_lo           [a,b,c,d]
1337      *    x_lo >> 1: [0,a,b,c]
1338      *           (...) --  data  --  | -- illegal reading -- (...)
1339      *                               ^ memory boundary
1340      *
1341      *  - seg-fault happens if read occurs anywhere outside the
1342      *  memory boundary.
1343      * */
1344     static const uint32_t mask[]
1345             = {0, 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
1346     assert(J.HW > 3);
1347 
1348     const Xbyak::Reg64 &c = r10;
1349 
1350     // unused: xmm2
1351     const Xbyak::Xmm &xmask_hi = this->xmm3;
1352     const Xbyak::Xmm &xsum_lo = this->xmm4;
1353     const Xbyak::Xmm &xsum_hi = this->xmm5;
1354     const Xbyak::Xmm &xa_lo = this->xmm6;
1355     const Xbyak::Xmm &xa_hi = this->xmm7;
1356     const Xbyak::Xmm &xb_lo = this->xmm8;
1357     const Xbyak::Xmm &xb_hi = this->xmm9;
1358     const Xbyak::Xmm &xc_lo = this->xmm10;
1359     const Xbyak::Xmm &xc_hi = this->xmm11;
1360     const Xbyak::Xmm &xd_lo = this->xmm12;
1361     const Xbyak::Xmm &xd_hi = this->xmm13;
1362     const Xbyak::Xmm &xe_lo = this->xmm14;
1363     const Xbyak::Xmm &xe_hi = this->xmm15;
1364 
1365     const int vlen = cpu_isa_traits<sse41>::vlen / sizeof(float);
1366 
1367     bool compute_tail = J.tail != 0;
1368     bool load_lo = J.tail == 0 || J.tail > 4;
1369 
1370     size_t h_offset = vlen;
1371     size_t l_shift = 0;
1372 
1373     this->preamble();
1374 
1375     this->mov(src_, this->ptr[this->param1 + 0]);
1376     this->mov(dst_, this->ptr[this->param1 + 8]);
1377     if (pk_ != prop_kind::forward_inference)
1378         this->mov(scratch_, this->ptr[this->param1 + 16]);
1379 
1380     this->sub(rsp, stack_space_needed_);
1381     this->mov(store_addr_, rsp);
1382     this->and_(store_addr_, -15);
1383 
1384     this->mov(this->imm_addr64_, float2int(this->alpha_));
1385     this->movq(xalpha_, this->imm_addr64_);
1386     this->shufps(xalpha_, xalpha_, 0);
1387 
1388     this->mov(this->imm_addr64_, float2int(this->k_));
1389     this->movq(xk_, this->imm_addr64_);
1390     this->shufps(xk_, xk_, 0);
1391 
1392     // put alpha_ and k_ into store (free up regs)
1393     this->movaps(this->ptr[store_addr_ + 0 * 4 * sizeof(float)], xalpha_);
1394     this->movaps(this->ptr[store_addr_ + 1 * 4 * sizeof(float)], xk_);
1395 
1396     if (compute_tail) {
1397         assert(J.tail > 0 && J.tail < 2 * vlen);
1398         h_offset = J.tail - vlen;
1399         l_shift = nstl::min(2 * vlen - J.tail, vlen);
1400 
1401         /* if 'tail' is between [1:3], need to zero-mask for underflow */
1402         size_t m_off = nstl::min(J.tail - 1, 3);
1403         this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[m_off]));
1404         this->movups(xmask_hi, this->ptr[this->imm_addr64_]);
1405     }
1406     // init xa, xb
1407     this->xorps(xa_lo, xa_lo);
1408     this->xorps(xa_hi, xa_hi);
1409     this->xorps(xb_lo, xb_lo);
1410     this->xorps(xb_hi, xb_hi);
1411 
1412     // read xc, xd
1413     if (load_lo) this->movups(xc_lo, this->ptr[src_ + J.HW * 0]);
1414     this->movups(xc_hi, this->ptr[src_ + J.HW * 0 + h_offset * sizeof(float)]);
1415     if (compute_tail) {
1416         this->pslldq(xc_lo, l_shift * sizeof(float));
1417         this->andps(xc_hi, xmask_hi);
1418     }
1419 
1420     if (load_lo) this->movups(xd_lo, this->ptr[src_ + J.HW * 4]);
1421     this->movups(xd_hi, this->ptr[src_ + J.HW * 4 + h_offset * sizeof(float)]);
1422     if (compute_tail) {
1423         this->pslldq(xd_lo, l_shift * sizeof(float));
1424         this->andps(xd_hi, xmask_hi);
1425     }
1426 
1427     // put xa, xb, xc, xd into store to free-up regs
1428     this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xa_lo);
1429     this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xa_hi);
1430     this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xb_lo);
1431     this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xb_hi);
1432     this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xc_lo);
1433     this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xc_hi);
1434     this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xd_lo);
1435     this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xd_hi);
1436 
1437     this->xorps(xsum_lo, xsum_lo);
1438     this->xorps(xsum_hi, xsum_hi);
1439     this->mulps(xc_lo, xc_lo);
1440     this->mulps(xc_hi, xc_hi);
1441     this->addps(xsum_lo, xc_lo);
1442     this->addps(xsum_hi, xc_hi);
1443     this->mulps(xd_lo, xd_lo);
1444     this->mulps(xd_hi, xd_hi);
1445     this->addps(xsum_lo, xd_lo);
1446     this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
1447 
1448     this->mov(c, J.C - 2);
1449     Label lrn_loop;
1450     this->L(lrn_loop);
1451 
1452     if (load_lo) this->movups(xe_lo, this->ptr[src_ + J.HW * 8]);
1453     this->movups(xe_hi, this->ptr[src_ + J.HW * 8 + h_offset * sizeof(float)]);
1454     if (compute_tail) {
1455         this->pslldq(xe_lo, l_shift * sizeof(float));
1456         this->andps(xe_hi, xmask_hi);
1457     }
1458 
1459     nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1460 
1461     this->add(src_, J.HW * 4);
1462     this->add(dst_, J.HW * 4);
1463     if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1464     this->dec(c);
1465     this->cmp(c, 0);
1466     this->jne(lrn_loop, this->T_NEAR);
1467 
1468     this->xorps(xe_lo, xe_lo);
1469     this->xorps(xe_hi, xe_hi);
1470 
1471     nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1472     this->add(src_, J.HW * 4);
1473     this->add(dst_, J.HW * 4);
1474     if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1475 
1476     nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1477 
1478     this->add(rsp, stack_space_needed_);
1479 
1480     this->postamble();
1481 }
1482 
1483 //////////////////////////////////////////////////////////////////////////////
1484 // backward kernel
1485 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t & J,float A,float B,int use_h_parallel,void * code_ptr,size_t code_size)1486 jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
1487         const nchw8c_across_t &J, float A, float B, int use_h_parallel,
1488         void *code_ptr, size_t code_size)
1489     : Base(code_ptr, code_size)
1490     , config_(lrn_config_t::nchw8c_across)
1491     , nchw8c_across_(J)
1492     , nalphabeta_(-2 * A * B)
1493     , use_h_parallelizm_(use_h_parallel) {}
1494 
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {
    // Emits the backward LRN kernel for the nchw8c "across channels" layout.
    // Each loop iteration handles one pixel of one 8-float channel block
    // (all data pointers advance by 32 bytes per iteration) and writes
    //   diff_src[c] = diff_dst[c] / ws[c]^(3/4)
    //               + (-2*A*B) * src[c]
    //                 * sum_{k=-2..2} (diff_dst * src / ws^(7/4))[c + k]
    // where ws is read through scratch_. Neighbours c-2..c+2 that fall outside
    // the current 8-channel block are taken from the adjacent blocks, or are
    // zero at the first/last block.
    // NOTE(review): the 3/4 exponent is hard-coded (x^3 followed by two
    // sqrt); B only enters through nalphabeta_. Presumably this path is only
    // selected for beta == 0.75 -- confirm against the dispatcher.

    // Register assignment. Note that ya and xa alias the same register
    // (ymm10/xmm10): xa is only used as a temporary for the prev/next
    // neighbour math before ya is loaded.
    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r10;
    const Xbyak::Xmm &xsrc_prev = this->xmm1;
    const Xbyak::Xmm &xws_prev = this->xmm2;
    const Xbyak::Xmm &xdiffdst_prev = this->xmm3;
    const Xbyak::Ymm &ysrc = this->ymm4;
    const Xbyak::Ymm &yws = this->ymm5;
    const Xbyak::Ymm &ydiffdst = this->ymm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xws_next = this->xmm8;
    const Xbyak::Xmm &xdiffdst_next = this->xmm9;
    const Xbyak::Ymm &ya = this->ymm10;
    const Xbyak::Xmm &xa = this->xmm10;
    const Xbyak::Ymm &yb = this->ymm11;
    const Xbyak::Ymm &yd = this->ymm12;
    const Xbyak::Ymm &ye = this->ymm13;
    const Xbyak::Ymm &ysum = this->ymm14;
    const Xbyak::Ymm &ydiffsrc = this->ymm15;

    this->preamble();

    // Load the runtime pointers from the jit_args_bwd_t argument block.
#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
    this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
    this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
    this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
    this->mov(bwd_intermediate_res_,
            this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
    this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
#undef GET_OFF

    // Reserve a 64-byte (16-float) channel window on the stack:
    //   [t +  0, t + 16): last 4 channels of the previous block
    //   [t + 16, t + 48): the 8 channels of the current block
    //   [t + 48, t + 64): first 4 channels of the next block
    this->sub(t, 64);
    // Broadcast nalphabeta_ (-2*A*B) into every lane of vnalphabeta_.
    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(vnalphabeta_, xnalphabeta_);

    // J.version selects which neighbour blocks exist:
    //   3      -> single block (no prev, no next)
    //   -1     -> first block  (no prev)
    //   +1     -> last block   (no next)
    //   -2     -> both the "first" and "last" flags are set
    //   other  -> interior block (both neighbours are read)
    bool is_single = J.version == 3;
    bool is_first = J.version == -1 || J.version == -2;
    bool is_last = J.version == +1 || J.version == -2;

    // Missing neighbours contribute zeros. These slots are written once here
    // and never overwritten in the loop, because the in-loop stores to
    // t + 0 / t + 48 are guarded by the same flags.
    if (is_first || is_single) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (is_last || is_single) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }
    // Trip count: J.W pixels when use_h_parallelizm_ is set, otherwise J.H*J.W.
    this->mov(hw, this->use_h_parallelizm_ ? J.W : J.H * J.W);
    Label lrn_loop;
    this->L(lrn_loop);
    {
        if (!is_first && !is_single) {
            // Previous neighbour: J.H*J.W*32 bytes back (presumably the same
            // pixel in the preceding 8-channel block); +16 selects its upper
            // 4 floats. Computes diff_dst * src / ws^(7/4) for those lanes.
            this->vmovups(xws_prev, this->ptr[scratch_ - J.H * J.W * 32 + 16]);
            this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
            this->vmovups(
                    xdiffdst_prev, this->ptr[diffdst_ - J.H * J.W * 32 + 16]);
            this->vmulps(xa, xws_prev, xws_prev);
            this->vmulps(xa, xa, xws_prev);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_prev);
            this->vdivps(xsrc_prev, xsrc_prev, xa);
            this->vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
        }

        // Current block: ya = ws^(3/4) via (ws^3)^(1/4);
        // ydiffsrc = diff_dst / ws^(3/4) (first term of the result);
        // ysum = diff_dst * src / ws^(7/4) (center term of the window sum).
        this->vmovups(ysrc, this->ptr[src_]);
        this->vmovups(yws, this->ptr[scratch_]);
        this->vmovups(ydiffdst, this->ptr[diffdst_]);
        this->vmulps(ya, yws, yws);
        this->vmulps(ya, ya, yws);
        this->vsqrtps(ya, ya);
        this->vsqrtps(ya, ya);
        this->vdivps(ydiffsrc, ydiffdst, ya);
        this->vdivps(ysum, ydiffsrc, yws);
        this->vmulps(ysum, ysum, ysrc);

        if (!is_last && !is_single) {
            // Next neighbour: J.H*J.W*32 bytes ahead, lower 4 floats.
            this->vmovups(xws_next, this->ptr[scratch_ + J.H * J.W * 32]);
            this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);
            this->vmovups(xdiffdst_next, this->ptr[diffdst_ + J.H * J.W * 32]);
            this->vmulps(xa, xws_next, xws_next);
            this->vmulps(xa, xa, xws_next);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_next);
            this->vdivps(xsrc_next, xsrc_next, xa);
            this->vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
        }

        // Spill the 16-float channel window to the stack.
        if (!is_first && !is_single)
            this->vmovups(this->ptr[t + 0], xdiffdst_prev);
        this->vmovups(this->ptr[t + 16], ysum);
        if (!is_last && !is_single)
            this->vmovups(this->ptr[t + 48], xdiffdst_next);

        // Re-read the window shifted by -2, -1, +1, +2 floats and add those
        // to the center term, giving the 5-channel sum in ysum.
        this->vmovups(ya, this->ptr[t + 16 - 8]);
        this->vmovups(yb, this->ptr[t + 16 - 4]);
        this->vaddps(ysum, ysum, ya);
        this->vmulps(ysrc, ysrc, vnalphabeta_);
        this->vaddps(ysum, ysum, yb);

        this->vmovups(yd, this->ptr[t + 16 + 4]);
        this->vmovups(ye, this->ptr[t + 16 + 8]);
        this->vaddps(ysum, ysum, yd);
        this->vaddps(ysum, ysum, ye);

        // diff_src = diff_dst / ws^(3/4) + (-2*A*B) * src * sum.
        this->vfmadd231ps(ydiffsrc, ysum, ysrc);

        this->vmovups(this->ptr[diffsrc_], ydiffsrc);

        // Advance every stream by one pixel (8 floats).
        this->add(src_, 32);
        this->add(diffsrc_, 32);
        this->add(diffdst_, 32);
        this->add(scratch_, 32);

        // NOTE: the cmp is redundant (dec already sets ZF) but harmless.
        this->dec(hw);
        this->cmp(hw, 0);
        this->jne(lrn_loop, this->T_NEAR);
    }

    // Release the stack window.
    this->add(t, 64);
    this->postamble();
}
1621 
1622 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_kernel_t(const within_config_t & config,float A,float B,void * code_ptr,size_t code_size)1623 jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
1624         const within_config_t &config, float A, float B, void *code_ptr,
1625         size_t code_size)
1626     : Base(config, code_ptr, code_size)
1627     , config_(lrn_config_t::within_config)
1628     , within_config_(config)
1629     , nalphabeta_(-2.0f * A * B) {}
1630 
1631 template <cpu_isa_t isa, data_type_t d_type>
generate(const within_config_t & config)1632 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(
1633         const within_config_t &config) {
1634 
1635     this->preamble();
1636 
1637 #define GET_OFF(field) offsetof(jit_args_bwd_t, field)
1638     this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
1639     this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
1640     this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
1641     this->mov(bwd_intermediate_res_,
1642             this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
1643     this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
1644 #undef GET_OFF
1645     this->load_constant(nalphabeta_, vnalphabeta_, xnalphabeta_);
1646 
1647     static const int max_reg_blocks = isa == avx512_common ? 3 : 1;
1648     this->within_loop(config, max_reg_blocks, prop_kind::backward);
1649 
1650     this->postamble();
1651 }
1652 
1653 template <cpu_isa_t isa, data_type_t d_type>
within_body(int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk,const int reg_block,int pixel_offset)1654 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
1655         int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
1656         int pixel_offset) {
1657 
1658     static const std::array<Vmm, 3> vsum {{Vmm(1), Vmm(9), Vmm(18)}};
1659     static const std::array<std::array<Vmm, 3>, 3> diff_dst {{
1660             {{Vmm(2), Vmm(3), Vmm(6)}},
1661             {{Vmm(10), Vmm(11), Vmm(23)}},
1662             {{Vmm(19), Vmm(20), Vmm(26)}},
1663     }};
1664     static const std::array<std::array<Vmm, 3>, 3> ws1 {{
1665             {{Vmm(4), Vmm(5), Vmm(15)}},
1666             {{Vmm(12), Vmm(13), Vmm(27)}},
1667             {{Vmm(21), Vmm(22), Vmm(28)}},
1668     }};
1669     static const std::array<Vmm, 3> ws0 = !this->emulate_bfloat_
1670             ? std::array<Vmm, 3> {{Vmm(29), Vmm(30), Vmm(31)}}
1671             : std::array<Vmm, 3> {{Vmm(6), Vmm(15), Vmm(23)}};
1672     static const std::array<Vmm, 3> src {{Vmm(7), Vmm(16), Vmm(24)}};
1673     static const std::array<Vmm, 3> a {{Vmm(8), Vmm(17), Vmm(25)}};
1674 
1675     static const std::size_t used_tmp_regs
1676             = this->emulate_bfloat_ ? ws1[0].size() - 1 : ws1[0].size();
1677 
1678     IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
1679     for (int i = hoff; i <= Hoff; ++i) {
1680         for (int j = woff; j <= Woff; ++j) {
1681             const auto idx = this->tempIdx_ % used_tmp_regs;
1682             IRB_LOOP(this->load_data(diff_dst[irb][idx],
1683                     this->ptr[(diffdst_ + pixel_offset + irb_off)
1684                             + (i * stride + j) * this->single_pixel_offset_]));
1685             IRB_LOOP(this->load_data(ws1[irb][idx],
1686                     this->ptr[(bwd_intermediate_res_ + pixel_offset + irb_off)
1687                             + (i * stride + j) * this->single_pixel_offset_]));
1688 
1689             if (i == 0 && j == 0) {
1690                 if (d_type == dnnl::impl::data_type::bf16) {
1691                     IRB_LOOP(this->load_data(ws0[irb],
1692                             this->ptr[(scratch_ + pixel_offset + irb_off)]));
1693                     IRB_LOOP(
1694                             this->vdivps(a[irb], diff_dst[irb][idx], ws0[irb]));
1695                 } else {
1696                     IRB_LOOP(this->vdivps(a[irb], diff_dst[irb][idx],
1697                             this->ptr[(scratch_ + pixel_offset + irb_off)]));
1698                 }
1699             }
1700 
1701             IRB_LOOP(this->vfmadd231ps(
1702                     vsum[irb], ws1[irb][idx], diff_dst[irb][idx]));
1703             ++(this->tempIdx_);
1704         }
1705     }
1706 
1707     this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
1708 
1709     if (d_type == dnnl::impl::data_type::bf16) {
1710         IRB_LOOP(this->load_data(
1711                 src[irb], this->ptr[(src_ + pixel_offset + irb_off)]));
1712         IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_, src[irb]));
1713     } else {
1714         IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_,
1715                 this->ptr[(src_ + pixel_offset + irb_off)]));
1716     }
1717 
1718     IRB_LOOP(this->vfmadd231ps(a[irb], src[irb], vsum[irb]));
1719 
1720     IRB_LOOP(this->store_data(
1721             this->ptr[diffsrc_ + pixel_offset + irb_off], a[irb]));
1722 
1723     if (isa == avx512_common)
1724         this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
1725 }
1726 
1727 template <cpu_isa_t isa, data_type_t d_type>
move_data_pointers(int pixel_count,prop_kind_t pk)1728 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::move_data_pointers(
1729         int pixel_count, prop_kind_t pk) {
1730     const int pixel_offset = this->single_pixel_offset_ * pixel_count;
1731     this->add(src_, pixel_offset);
1732     this->add(diffsrc_, pixel_offset);
1733     this->add(diffdst_, pixel_offset);
1734     this->add(scratch_, pixel_offset);
1735     this->add(bwd_intermediate_res_, pixel_offset);
1736 }
1737 
1738 template class jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>;
1739 template class jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>;
1740 template class jit_uni_lrn_fwd_kernel_t<avx512_common,
1741         dnnl::impl::data_type::f32>;
1742 template class jit_uni_lrn_fwd_kernel_t<avx512_common,
1743         dnnl::impl::data_type::bf16>;
1744 
1745 template class jit_uni_lrn_kernel_t<
1746         jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>>;
1747 template class jit_uni_lrn_kernel_t<
1748         jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
1749 template class jit_uni_lrn_kernel_t<
1750         jit_uni_lrn_fwd_kernel_t<avx512_common, dnnl::impl::data_type::f32>>;
1751 template class jit_uni_lrn_kernel_t<
1752         jit_uni_lrn_fwd_kernel_t<avx512_common, dnnl::impl::data_type::bf16>>;
1753 
1754 template class jit_uni_lrn_bwd_kernel_t<avx512_common,
1755         dnnl::impl::data_type::f32>;
1756 template class jit_uni_lrn_bwd_kernel_t<avx512_common,
1757         dnnl::impl::data_type::bf16>;
1758 template class jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>;
1759 
1760 template class jit_uni_lrn_kernel_t<
1761         jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
1762 template class jit_uni_lrn_kernel_t<
1763         jit_uni_lrn_bwd_kernel_t<avx512_common, dnnl::impl::data_type::f32>>;
1764 template class jit_uni_lrn_kernel_t<
1765         jit_uni_lrn_bwd_kernel_t<avx512_common, dnnl::impl::data_type::bf16>>;
1766 
1767 } // namespace x64
1768 } // namespace cpu
1769 } // namespace impl
1770 } // namespace dnnl
1771 
1772 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1773