1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <array>
18 #include <cmath>
19 #include "common/c_types_map.hpp"
20 #include "common/nstl.hpp"
21 #include "common/utils.hpp"
22 #include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
23 #include "cpu/x64/lrn/jit_uni_lrn_kernel.hpp"
24
25 namespace dnnl {
26 namespace impl {
27 namespace cpu {
28 namespace x64 {
29
30 using namespace dnnl::impl::format_tag;
31
// IRB_LOOP(statement) emits `statement` once per register block:
//  - reg_block == 1: a single block. Its register-set index `irb` is taken
//    from this->reg_block_idx_ modulo vsum.size(), and its byte offset
//    `irb_off` is 0.
//  - otherwise: `irb` runs over [0, reg_block) and `irb_off` is the byte
//    offset of pixel `irb`, i.e. irb * this->single_pixel_offset_.
// The expansion expects `reg_block`, `vsum`, this->reg_block_idx_ and
// this->single_pixel_offset_ to be in scope at the use site. It is wrapped in
// do { } while (0) so that `IRB_LOOP(...);` is exactly one statement, which
// also makes it safe inside an unbraced if/else.
#define IRB_LOOP(statement) \
    do { \
        if (1 == reg_block) { \
            const int irb_off = 0; \
            const int irb = this->reg_block_idx_ % vsum.size(); \
            statement; \
            MAYBE_UNUSED(irb_off); \
        } else { \
            for (int irb = 0; irb < reg_block; irb++) { \
                const int irb_off = irb * this->single_pixel_offset_; \
                statement; \
                MAYBE_UNUSED(irb_off); \
            } \
        } \
    } while (0)
45
46 using namespace Xbyak;
47
48 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
49 cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t(void * code_ptr,size_t code_size)50 jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
51 void *code_ptr, size_t code_size)
52 : jit_generator(code_ptr, code_size, true, isa)
53 , emulate_bfloat_(isa == avx512_common
54 && d_type == dnnl::impl::data_type::bf16
55 && !mayiuse(avx512_core_bf16))
56 , bf16_emu_(
57 emulate_bfloat_ ? utils::make_unique<bf16_emulation_t>(this,
58 bf16_emu_reserv_1_, bf16_emu_reserv_2_,
59 bf16_emu_reserv_3_, bf16_emu_scratch_, bf16_emu_reserv_4_)
60 : nullptr) {
61
62 if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16();
63 }
64
65 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
66 cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t(const within_config_t & config,void * code_ptr,size_t code_size)67 jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
68 const within_config_t &config, void *code_ptr, size_t code_size)
69 : jit_uni_lrn_kernel_t(code_ptr, code_size) {
70 if (config.dat_tag == nhwc)
71 single_pixel_offset_
72 = config.C * sizeof(typename prec_traits<d_type>::type);
73 }
74
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here rather than in the header so that
// destroying the bf16_emu_ smart pointer happens where bf16_emulation_t is a
// complete type (its header is included at the top of this file). Confirm
// against jit_uni_lrn_kernel.hpp.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::~jit_uni_lrn_kernel_t() = default;
78
79 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
80 cpu_isa_t isa, data_type_t d_type>
within_loop(const within_config_t & config,int max_reg_blocks,prop_kind_t pk)81 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_loop(
82 const within_config_t &config, int max_reg_blocks, prop_kind_t pk) {
83 const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
84
85 const int lower_bound = (config.size - 1) / 2,
86 upper_bound = config.size - lower_bound - 1;
87
88 int pixel_count = 0;
89
90 for (int i = 0; i < lower_bound; ++i) {
91 pixel_count = 0;
92 for (int j = 0; j < lower_bound; ++j)
93 derived_ptr->within_body(-i, upper_bound, -j, upper_bound, config.W,
94 pk, 1, pixel_count++ * this->single_pixel_offset_);
95 derived_ptr->move_data_pointers(pixel_count, pk);
96
97 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, -i,
98 upper_bound, -lower_bound, upper_bound, config.W, pk);
99
100 pixel_count = 0;
101 for (int j = config.W - upper_bound; j < config.W; ++j)
102 derived_ptr->within_body(-i, upper_bound, -lower_bound,
103 config.W - 1 - j, config.W, pk, 1,
104 pixel_count++ * this->single_pixel_offset_);
105 derived_ptr->move_data_pointers(pixel_count, pk);
106 }
107
108 this->mov(h_, config.H - config.size + 1);
109 Label lrn_loop_h;
110 this->L(lrn_loop_h);
111 pixel_count = 0;
112 for (int j = 0; j < lower_bound; ++j)
113 derived_ptr->within_body(-lower_bound, upper_bound, -j, upper_bound,
114 config.W, pk, 1, pixel_count++ * this->single_pixel_offset_);
115 derived_ptr->move_data_pointers(pixel_count, pk);
116
117 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
118 -lower_bound, upper_bound, -lower_bound, upper_bound, config.W, pk);
119
120 pixel_count = 0;
121 for (int j = config.W - upper_bound; j < config.W; ++j)
122 derived_ptr->within_body(-lower_bound, upper_bound, -lower_bound,
123 config.W - 1 - j, config.W, pk, 1,
124 pixel_count++ * this->single_pixel_offset_);
125 derived_ptr->move_data_pointers(pixel_count, pk);
126
127 this->dec(h_);
128 this->cmp(h_, 0);
129 this->jne(lrn_loop_h, this->T_NEAR);
130
131 for (int i = config.H - upper_bound; i < config.H; ++i) {
132 pixel_count = 0;
133 for (int j = 0; j < lower_bound; ++j)
134 derived_ptr->within_body(-lower_bound, config.H - 1 - i, -j,
135 upper_bound, config.W, pk, 1,
136 pixel_count++ * this->single_pixel_offset_);
137 derived_ptr->move_data_pointers(pixel_count, pk);
138
139 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
140 -lower_bound, config.H - 1 - i, -lower_bound, upper_bound,
141 config.W, pk);
142
143 pixel_count = 0;
144 for (int j = config.W - upper_bound; j < config.W; ++j)
145 derived_ptr->within_body(-lower_bound, config.H - 1 - i,
146 -lower_bound, config.W - 1 - j, config.W, pk, 1,
147 pixel_count++ * this->single_pixel_offset_);
148 derived_ptr->move_data_pointers(pixel_count, pk);
149 }
150 }
151
152 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
153 cpu_isa_t isa, data_type_t d_type>
within_body_reg_blocked(int loop_count,int max_reg_blocks,int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk)154 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_body_reg_blocked(
155 int loop_count, int max_reg_blocks, int hoff, int Hoff, int woff,
156 int Woff, int stride, prop_kind_t pk) {
157
158 const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
159 Label reg_block_compute_loop;
160
161 const auto res = std::div(loop_count, max_reg_blocks);
162 if (res.quot) {
163 this->mov(this->w_, res.quot);
164 this->L(reg_block_compute_loop);
165 derived_ptr->within_body(
166 hoff, Hoff, woff, Woff, stride, pk, max_reg_blocks, 0);
167 derived_ptr->move_data_pointers(max_reg_blocks, pk);
168 this->dec(this->w_);
169 this->cmp(this->w_, 0);
170 this->jne(reg_block_compute_loop, this->T_NEAR);
171 }
172 if (res.rem) {
173 derived_ptr->within_body(
174 hoff, Hoff, woff, Woff, stride, pk, res.rem, 0);
175 derived_ptr->move_data_pointers(res.rem, pk);
176 }
177 }
178
179 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
180 cpu_isa_t isa, data_type_t d_type>
load_data(const Vmm & reg,const Xbyak::Address & p)181 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_data(
182 const Vmm ®, const Xbyak::Address &p) {
183 this->uni_vmovups(reg, p);
184 }
185
// Widens packed bf16 values at `p` into f32 lanes of `reg`.
// A bf16 value is the upper 16 bits of an f32, so each 16-bit word is first
// zero-extended to 32 bits (vpmovzxwd) and then shifted left by 16 (vpslld).
template <typename Gen, typename Reg, typename Addr>
void load_bf16_data(Gen generator, const Reg &reg, const Addr &p) {
    constexpr int bf16_to_f32_shift = 16;
    generator->vpmovzxwd(reg, p);
    generator->vpslld(reg, reg, bf16_to_f32_shift);
}
191
192 template <>
193 void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_common,
load_data(const Vmm & reg,const Xbyak::Address & p)194 dnnl::impl::data_type::bf16>>::load_data(const Vmm ®,
195 const Xbyak::Address &p) {
196 load_bf16_data(this, reg, p);
197 }
198
199 template <>
200 void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_common,
load_data(const Vmm & reg,const Xbyak::Address & p)201 dnnl::impl::data_type::bf16>>::load_data(const Vmm ®,
202 const Xbyak::Address &p) {
203 load_bf16_data(this, reg, p);
204 }
205
206 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
207 cpu_isa_t isa, data_type_t d_type>
store_data(const Xbyak::Address & addr,const Vmm & reg)208 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::store_data(
209 const Xbyak::Address &addr, const Vmm ®) {
210 this->uni_vmovups(addr, reg);
211 }
212
213 template <typename Gen, typename Bf16Emu>
store_bf16_data(Gen generator,Bf16Emu emu,const Xbyak::Address & addr,const Zmm & zr)214 void store_bf16_data(
215 Gen generator, Bf16Emu emu, const Xbyak::Address &addr, const Zmm &zr) {
216 const Ymm yr = Ymm(zr.getIdx());
217 if (mayiuse(avx512_core_bf16))
218 generator->vcvtneps2bf16(yr, zr);
219 else
220 emu->vcvtneps2bf16(yr, zr);
221 generator->vmovdqu16(addr, yr);
222 }
223
224 template <>
225 void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_common,
store_data(const Xbyak::Address & addr,const Zmm & zr)226 dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
227 const Zmm &zr) {
228 store_bf16_data(this, bf16_emu_.get(), addr, zr);
229 }
230
231 template <>
232 void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_common,
store_data(const Xbyak::Address & addr,const Zmm & zr)233 dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
234 const Zmm &zr) {
235 store_bf16_data(this, bf16_emu_.get(), addr, zr);
236 }
237
238 template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
239 cpu_isa_t isa, data_type_t d_type>
load_constant(float constant,const Vmm & v_constant,const Xbyak::Xmm & x_constant)240 void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_constant(
241 float constant, const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
242 this->mov(this->imm_addr64_, float2int(constant));
243 this->uni_vmovq(x_constant, this->imm_addr64_);
244 this->vbroadcastss(v_constant, x_constant);
245 }
246
247 template <>
248 void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<sse41,
load_constant(float constant,const Vmm & v_constant,const Xbyak::Xmm & x_constant)249 dnnl::impl::data_type::f32>>::load_constant(float constant,
250 const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
251 this->mov(this->imm_addr64_, float2int(constant));
252 this->uni_vmovq(x_constant, this->imm_addr64_);
253 this->shufps(x_constant, x_constant, 0);
254 }
255
256 //////////////////////////////////////////////////////////////////////////////
257 // forward kernel
258 template <cpu_isa_t isa, data_type_t d_type>
within_body(int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk,const int reg_block,int pixel_offset)259 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
260 int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
261 int pixel_offset) {
262
263 static const std::array<Vmm, 3> vsum {{Vmm(2), Vmm(11), Vmm(20)}};
264 static const std::array<Vmm, 3> vsum2 {{Vmm(3), Vmm(12), Vmm(21)}};
265 static const std::array<Vmm, 3> vdst {{Vmm(4), Vmm(13), Vmm(22)}};
266 static const std::array<std::array<Vmm, 6u>, 3u> vtmp {
267 {{{Vmm(5), Vmm(6), Vmm(7), Vmm(8), Vmm(9), Vmm(14)}},
268 {{Vmm(18), Vmm(15), Vmm(16), Vmm(17), Vmm(29), Vmm(30)}},
269 {{Vmm(23), Vmm(24), Vmm(25), Vmm(26), Vmm(28), Vmm(31)}}}};
270 static const std::array<Vmm, 3> vscratch = {{Vmm(10), Vmm(19), Vmm(27)}};
271 static const std::size_t used_tmp_regs
272 = this->emulate_bfloat_ ? vtmp[0].size() - 2 : vtmp[0].size();
273
274 IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
275 for (int i = hoff; i <= Hoff; ++i) {
276 for (int j = woff; j <= Woff; ++j) {
277 if (i == 0 && j == 0) {
278 IRB_LOOP(this->load_data(
279 vdst[irb], this->ptr[src_ + pixel_offset + irb_off]));
280 IRB_LOOP(this->vfmadd231ps(vsum[irb], vdst[irb], vdst[irb]));
281 } else {
282 const auto idx = this->tempIdx_ % used_tmp_regs;
283 IRB_LOOP(this->load_data(vtmp[irb][idx],
284 this->ptr[(src_ + pixel_offset + irb_off)
285 + (i * stride + j)
286 * this->single_pixel_offset_]));
287 IRB_LOOP(this->vfmadd231ps(
288 vsum[irb], vtmp[irb][idx], vtmp[irb][idx]));
289 ++(this->tempIdx_);
290 }
291 }
292 }
293
294 this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
295
296 IRB_LOOP(this->vfmadd132ps(
297 vsum[irb], vk_, valpha_)); // ysum <- ysum*valpha_+yk_
298 IRB_LOOP(this->vmovaps(vscratch[irb], vsum[irb]));
299
300 IRB_LOOP(this->vmulps(vsum2[irb], vsum[irb], vsum[irb]));
301 IRB_LOOP(this->vmulps(
302 vsum[irb], vsum[irb], vsum2[irb])); // ysum = (ysum*valpha_+yk_)^3;
303 IRB_LOOP(this->vsqrtps(vsum[irb], vsum[irb]));
304 IRB_LOOP(this->vsqrtps(
305 vsum[irb], vsum[irb])); // ysum = (ysum*valpha_+yk_)^0.75
306 IRB_LOOP(this->vdivps(
307 vdst[irb], vdst[irb], vsum[irb])); // ydst <- ydst / ysum
308
309 if (pk_ != prop_kind::forward_inference) {
310 IRB_LOOP(this->store_data(
311 this->ptr[scratch_ + pixel_offset + irb_off], vsum[irb]));
312 IRB_LOOP(this->vdivps(vscratch[irb], vdst[irb], vscratch[irb]));
313 IRB_LOOP(this->store_data(
314 this->ptr[bwd_intermediate_res_ + pixel_offset + irb_off],
315 vscratch[irb]));
316 }
317
318 IRB_LOOP(this->store_data(
319 this->ptr[dst_ + pixel_offset + irb_off], vdst[irb]));
320
321 if (isa == avx512_common)
322 this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
323 }
324
// SSE4.1 specialization of the within-channel forward body.
// Each call handles one pixel at `pixel_offset` as 8 floats split over two
// xmm halves (lo = floats 0..3 at +0, hi = floats 4..7 at +16 bytes), and
// computes:
//   base = k + alpha * sum_{window} src^2
//   dst  = src / base^0.75
// Unlike the AVX/AVX-512 version:
//  - `reg_block` is ignored;
//  - when pk_ is not forward_inference, only base (not base^0.75) is stored
//    to scratch_, and bwd_intermediate_res_ is not written.
// The `pk` argument is not read; pk_ decides.
// NOTE(review): the AVX path stores base^0.75 to scratch_ while this path
// stores base. Confirm which one the workspace consumer expects.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::within_body(
        int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk,
        int reg_block, int pixel_offset) {

    const Xbyak::Xmm &xtmp_lo = this->xmm2;
    const Xbyak::Xmm &xtmp_hi = this->xmm3;
    const Xbyak::Xmm &xsum_lo = this->xmm4;
    const Xbyak::Xmm &xsum_hi = this->xmm5;
    const Xbyak::Xmm &xdst_lo = this->xmm6;
    const Xbyak::Xmm &xdst_hi = this->xmm7;
    const Xbyak::Xmm &xsum2_lo = this->xmm8;
    const Xbyak::Xmm &xsum2_hi = this->xmm9;

    // sum = 0, then sum += x^2 over the window.
    xorps(xsum_lo, xsum_lo);
    xorps(xsum_hi, xsum_hi);
    for (int i = hoff; i <= Hoff; ++i) {
        for (int j = woff; j <= Woff; ++j) {
            if (i == 0 && j == 0) {
                // Centre pixel. xdst is squared in place here and reloaded
                // from src_ before the final division below.
                movups(xdst_lo, ptr[src_ + pixel_offset]);
                movups(xdst_hi, ptr[src_ + pixel_offset + 4 * sizeof(float)]);
                mulps(xdst_lo, xdst_lo);
                mulps(xdst_hi, xdst_hi);
                addps(xsum_lo, xdst_lo);
                addps(xsum_hi, xdst_hi);
            } else {
                movups(xtmp_lo,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_]);
                movups(xtmp_hi,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_
                                + 4 * sizeof(float)]);
                this->mulps(xtmp_lo, xtmp_lo);
                this->mulps(xtmp_hi, xtmp_hi);
                this->addps(xsum_lo, xtmp_lo);
                this->addps(xsum_hi, xtmp_hi);
            }
        }
    }
    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_
    this->movaps(xtmp_lo, xsum_lo);
    this->movaps(xtmp_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        // Training: scratch_ <- base.
        this->movups(this->ptr[scratch_ + pixel_offset], xtmp_lo);
        this->movups(this->ptr[scratch_ + pixel_offset + 4 * sizeof(float)],
                xtmp_hi);
    }
    this->movaps(xsum2_lo, xsum_lo);
    this->movaps(xsum2_hi, xsum_hi);
    this->mulps(xsum2_lo, xsum_lo);
    this->mulps(xsum2_hi, xsum_hi);
    this->mulps(xsum_lo, xsum2_lo);
    this->mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha_+xk_)^3;

    // Two square roots of base^3 give base^0.75.
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha_+xk_)^0.75

    // Reload src (xdst was squared above) and divide.
    this->movups(xdst_lo, this->ptr[src_ + pixel_offset]);
    this->movups(xdst_hi, this->ptr[src_ + pixel_offset + 4 * sizeof(float)]);
    this->divps(xdst_lo, xsum_lo);
    this->divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum

    this->movups(this->ptr[dst_ + pixel_offset], xdst_lo);
    this->movups(this->ptr[dst_ + pixel_offset + 4 * sizeof(float)], xdst_hi);
}
396
397 template <cpu_isa_t isa, data_type_t d_type>
move_data_pointers(int pixel_count,prop_kind_t pk)398 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::move_data_pointers(
399 int pixel_count, prop_kind_t pk) {
400
401 const int pixel_offset = this->single_pixel_offset_ * pixel_count;
402 this->add(src_, pixel_offset);
403 this->add(dst_, pixel_offset);
404 if (pk_ != prop_kind::forward_inference) {
405 this->add(scratch_, pixel_offset);
406 this->add(bwd_intermediate_res_, pixel_offset);
407 }
408 }
409
// Within-channel forward kernel constructor.
// The config goes to the base class, which for nhwc derives the per-pixel
// byte stride from it. A and K are kept as alpha_ and k_, the multiplier and
// additive term in base = k + alpha * sum(src^2). pk_ selects whether the
// training workspaces are written.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const within_config_t &config, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(config, code_ptr, code_size)
    , config_(lrn_config_t::within_config)
    , within_config_(config)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
420
421 template <cpu_isa_t isa, data_type_t d_type>
generate(const within_config_t & config)422 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(
423 const within_config_t &config) {
424 this->preamble();
425
426 #define GET_OFF(field) offsetof(jit_args_fwd_t, field)
427 this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
428 this->mov(dst_, this->ptr[this->param1 + GET_OFF(dst)]);
429 if (pk_ != prop_kind::forward_inference) {
430 this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
431 this->mov(bwd_intermediate_res_,
432 this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
433 }
434 #undef GET_OFF
435
436 this->load_constant(alpha_, valpha_, xalpha_);
437 this->load_constant(k_, vk_, xk_);
438
439 static const int max_reg_blocks = isa == avx512_common ? 3 : 1;
440 this->within_loop(config, max_reg_blocks, pk_);
441
442 this->postamble();
443 }
444
// Across-channel forward kernel constructor for nChw8c-blocked data.
// J (spatial size and boundary version) is kept for generate(). A and K are
// stored as alpha_ and k_ in base = k + alpha * sum(src^2).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nchw8c_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
455
// Forward LRN across channels for nChw8c-blocked data (ymm-based path).
// A call processes one 8-channel block for all J.H * J.W pixels, advancing
// 32 bytes (8 floats) per pixel. The neighbouring channel blocks of a pixel
// are J.H * J.W * 32 bytes away.
//
// For each pixel the channels are staged in 64 bytes of stack below rsp, so
// the 5-wide window c-2..c+2 can be read with unaligned loads at byte offsets
// -8, -4, 0, +4, +8 from the current block:
//   [t+0,  t+16): last 4 channels of the previous block (zeros if version -1)
//   [t+16, t+48): the 8 channels of the current block
//   [t+48, t+64): first 4 channels of the next block (zeros if version +1)
// J.version == -1 means there is no previous block; +1 means there is no next
// block. For those versions the zeros are written once before the loop and
// never reloaded.
// For training, base = k + alpha * sum is stored to scratch_.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {
    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_prev = this->xmm2;
    const Xbyak::Ymm &ysrc = this->ymm3;
    const Xbyak::Ymm &yc = this->ymm3; // same register as ysrc (centre term)
    const Xbyak::Xmm &xsrc_next = this->xmm4;
    const Xbyak::Ymm &ya = this->ymm5;
    const Xbyak::Ymm &yb = this->ymm6;
    const Xbyak::Ymm &yd = this->ymm7;
    const Xbyak::Ymm &ye = this->ymm8;
    const Xbyak::Ymm &ysum = this->ymm9;
    const Xbyak::Ymm &ysum2 = this->ymm10;
    const Xbyak::Ymm &ydst = this->ymm11;
    const Xbyak::Ymm &ybase = this->ymm12;

    this->preamble();

    // Arguments: src at +0, dst at +8, scratch at +16 (training only).
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    this->sub(t, 64); // stack staging area, released before postamble
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(xalpha_, this->imm_addr64_);
    this->vbroadcastss(valpha_, xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(xk_, this->imm_addr64_);
    this->vbroadcastss(yk_, xk_);

    // Missing neighbour blocks contribute zeros for the whole loop.
    if (J.version == -1) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);

    Label lrn_loop;
    this->L(lrn_loop);

    // Stage prev-tail / current / next-head channels for this pixel.
    if (J.version != -1)
        this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->vmovups(ysrc, this->ptr[src_]);
    if (J.version != +1)
        this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->vmovups(this->ptr[t + 0], xsrc_prev);
    this->vmovups(this->ptr[t + 16], ysrc);
    if (J.version != +1) this->vmovups(this->ptr[t + 48], xsrc_next);

    // Channels shifted by -2, -1, +1, +2.
    this->vmovups(ya, this->ptr[t + 16 - 8]);
    this->vmovups(yb, this->ptr[t + 16 - 4]);
    this->vmovups(yd, this->ptr[t + 16 + 4]);
    this->vmovups(ye, this->ptr[t + 16 + 8]);
    this->vmulps(ysum, yc, yc);
    this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
    this->vfmadd231ps(ysum, yb, yb);
    this->vfmadd231ps(ysum, yd, yd);
    this->vfmadd231ps(ysum, ye, ye);
    this->vfmadd132ps(ysum, yk_, valpha_); // ysum <- ysum*valpha_+yk_

    this->vmovaps(ybase, ysum);
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase); // workspace <- base
    this->vmulps(ysum2, ysum, ysum);
    this->vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
    this->vsqrtps(ysum, ysum);
    this->vsqrtps(ysum, ysum); // ysum = ybase^0.75
    this->vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
    this->vmovups(this->ptr[dst_], ydst);

    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64);
    this->postamble();
}
543
// SSE4.1 specialization of the nChw8c across-channel constructor. It has the
// same member initialisation as the generic version; a specialization is
// needed because the sse41/f32 class is specialized in this file.
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const struct nchw8c_across_t &J, float A,
                float K, prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
554
// SSE4.1 version of the nChw8c across-channel forward kernel.
// It follows the same algorithm and stack staging layout as the ymm version:
//   [t+0,16)   previous block tail
//   [t+16,48)  current 8 channels
//   [t+48,64)  next block head
// The 8 channels of each pixel are handled as lo (floats 0..3) and hi
// (floats 4..7) xmm halves.
// Register aliases: xsum_* is xc_* (the sum starts as c*c in place), and
// xbase_* reuses xe_* once the e terms have been consumed.
// For training, base = k + alpha * sum is stored to scratch_.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
        const nchw8c_across_t &J) {

    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_lo = this->xmm2;
    const Xbyak::Xmm &xsrc_hi = this->xmm3;
    const Xbyak::Xmm &xc_lo = this->xmm4;
    const Xbyak::Xmm &xc_hi = this->xmm5;
    const Xbyak::Xmm &xsum_lo = xc_lo;
    const Xbyak::Xmm &xsum_hi = xc_hi;
    const Xbyak::Xmm &xsrc_prev = this->xmm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xa_lo = this->xmm8;
    const Xbyak::Xmm &xa_hi = this->xmm9;
    const Xbyak::Xmm &xb_lo = this->xmm10;
    const Xbyak::Xmm &xb_hi = this->xmm11;
    const Xbyak::Xmm &xd_lo = this->xmm12;
    const Xbyak::Xmm &xd_hi = this->xmm13;
    const Xbyak::Xmm &xe_lo = this->xmm14;
    const Xbyak::Xmm &xe_hi = this->xmm15;
    const Xbyak::Xmm &xbase_lo = this->xmm14;
    const Xbyak::Xmm &xbase_hi = this->xmm15;

    this->preamble();

    // Arguments: src at +0, dst at +8, scratch at +16 (training only).
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    this->sub(t, 64); // stack staging area, released before postamble
    // Broadcast alpha and k to all four lanes with shufps (imm 0).
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->movq(xalpha_, this->imm_addr64_);
    this->shufps(xalpha_, xalpha_, 0);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->movq(xk_, this->imm_addr64_);
    this->shufps(xk_, xk_, 0);

    // Missing neighbour blocks contribute zeros for the whole loop.
    if (J.version == -1) {
        this->xorps(xsrc_prev, xsrc_prev);
        this->movups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->xorps(xsrc_next, xsrc_next);
        this->movups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);
    Label lrn_loop;
    L(lrn_loop);

    // Stage prev-tail / current / next-head channels for this pixel.
    if (J.version != -1)
        this->movups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->movups(xsrc_lo, this->ptr[src_]);
    this->movups(xsrc_hi, this->ptr[src_ + 4 * sizeof(float)]);
    if (J.version != +1)
        this->movups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->movups(this->ptr[t + 0], xsrc_prev);
    this->movups(this->ptr[t + 16], xsrc_lo);
    this->movups(this->ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
    if (J.version != +1) this->movups(this->ptr[t + 48], xsrc_next);

    // Channels shifted by -2 (a), -1 (b), +1 (d), +2 (e).
    this->movups(xa_lo, this->ptr[t + 16 - 8]);
    this->movups(xa_hi, this->ptr[t + 16 - 8 + 4 * sizeof(float)]);
    this->movups(xb_lo, this->ptr[t + 16 - 4]);
    this->movups(xb_hi, this->ptr[t + 16 - 4 + 4 * sizeof(float)]);
    this->movups(xd_lo, this->ptr[t + 16 + 4]);
    this->movups(xd_hi, this->ptr[t + 16 + 4 + 4 * sizeof(float)]);
    this->movups(xe_lo, this->ptr[t + 16 + 8]);
    this->movups(xe_hi, this->ptr[t + 16 + 8 + 4 * sizeof(float)]);
    this->movaps(xc_lo, xsrc_lo);
    this->movaps(xc_hi, xsrc_hi);
    this->mulps(xsum_lo, xc_lo); // xsum is xc: xsum = c*c
    this->mulps(xsum_hi, xc_hi);
    this->mulps(xa_lo, xa_lo);
    this->mulps(xa_hi, xa_hi);
    this->addps(xsum_lo, xa_lo);
    this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
    this->mulps(xb_lo, xb_lo);
    this->mulps(xb_hi, xb_hi);
    this->addps(xsum_lo, xb_lo);
    this->addps(xsum_hi, xb_hi);
    this->mulps(xd_lo, xd_lo);
    this->mulps(xd_hi, xd_hi);
    this->addps(xsum_lo, xd_lo);
    this->addps(xsum_hi, xd_hi);
    this->mulps(xe_lo, xe_lo);
    this->mulps(xe_hi, xe_hi);
    this->addps(xsum_lo, xe_lo);
    this->addps(xsum_hi, xe_hi);

    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_

    // xbase reuses the xe registers, which are no longer needed.
    this->movaps(xbase_lo, xsum_lo);
    this->movaps(xbase_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_], xbase_lo);
        this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
    }
    this->mulps(xsum_lo, xsum_lo);
    this->mulps(xsum_hi, xsum_hi);
    this->mulps(xsum_lo, xbase_lo);
    this->mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
    this->divps(xsrc_lo, xsum_lo);
    this->divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
    this->movups(this->ptr[dst_], xsrc_lo);
    this->movups(this->ptr[dst_ + 4 * sizeof(float)], xsrc_hi);

    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64);
    this->postamble();
}
683
// Across-channel forward kernel constructor for NHWC data.
// J (the channel count C) is kept for generate(). A and K are stored as
// alpha_ and k_ in base = k + alpha * sum(src^2).
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nhwc_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nhwc_across)
    , nhwc_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
694
695 template <cpu_isa_t isa, data_type_t d_type>
generate(const nhwc_across_t & J)696 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nhwc_across_t &J) {
697 static const uint32_t mask[] = {0, 0, 0x80000000, 0x80000000, 0x80000000,
698 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0, 0};
699
700 const Xbyak::Reg64 &c = this->r9;
701 const Xbyak::Ymm &ya = this->ymm2;
702 const Xbyak::Ymm &yb = this->ymm3;
703 const Xbyak::Ymm &yc = this->ymm4;
704 const Xbyak::Ymm &yd = this->ymm5;
705 const Xbyak::Ymm &ye = this->ymm6;
706 const Xbyak::Ymm &ysum = this->ymm7;
707 const Xbyak::Ymm &ydst = this->ymm8;
708 const Xbyak::Ymm &ybase = this->ymm9;
709 const Xbyak::Ymm &ymask = this->ymm10;
710
711 this->preamble();
712
713 this->mov(src_, this->ptr[this->param1 + 0]);
714 this->mov(dst_, this->ptr[this->param1 + 8]);
715 if (pk_ != prop_kind::forward_inference)
716 this->mov(scratch_, this->ptr[this->param1 + 16]);
717 this->mov(this->imm_addr64_, float2int(this->alpha_));
718 this->movq(xalpha_, this->imm_addr64_);
719 this->vbroadcastss(valpha_, xalpha_);
720
721 this->mov(this->imm_addr64_, float2int(this->k_));
722 this->movq(xk_, this->imm_addr64_);
723 this->vbroadcastss(yk_, xk_);
724
725 this->vxorps(ysum, ysum, ysum);
726
727 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[0]));
728 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
729 this->vmaskmovps(ya, ymask, this->ptr[src_ - 8]);
730 this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
731
732 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[1]));
733 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
734 this->vmaskmovps(yb, ymask, this->ptr[src_ - 4]);
735 this->vfmadd231ps(ysum, yb, yb);
736
737 this->mov(c, J.C / 8 - 1);
738 Label lrn_loop;
739 this->L(lrn_loop);
740
741 this->vmovups(yc, this->ptr[src_]);
742 this->vmovups(yd, this->ptr[src_ + 4]);
743 this->vmovups(ye, this->ptr[src_ + 8]);
744 this->vfmadd231ps(ysum, yc, yc);
745 this->vfmadd231ps(ysum, yd, yd);
746 this->vfmadd231ps(ysum, ye, ye);
747
748 this->vmovups(ydst, ysum);
749 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
750
751 this->vmovaps(ybase, ydst);
752 if (pk_ != prop_kind::forward_inference)
753 this->vmovups(this->ptr[scratch_], ybase);
754 this->vmulps(ydst, ydst, ydst);
755 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
756 this->vsqrtps(ydst, ydst);
757 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
758
759 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
760 this->vmovups(this->ptr[dst_], ydst);
761
762 this->vxorps(ysum, ysum, ysum);
763
764 this->add(src_, 32);
765 this->add(dst_, 32);
766 if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
767
768 this->vmovups(ya, this->ptr[src_ - 8]);
769 this->vfmadd231ps(ysum, ya, ya);
770 this->vmovups(yb, this->ptr[src_ - 4]);
771 this->vfmadd231ps(ysum, yb, yb);
772
773 this->dec(c);
774 this->cmp(c, 0);
775 this->jne(lrn_loop, this->T_NEAR);
776
777 this->vmovups(yc, this->ptr[src_]);
778 this->vfmadd231ps(ysum, yc, yc);
779
780 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[2]));
781 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
782 this->vmaskmovps(yd, ymask, this->ptr[src_ + 4]);
783 this->vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
784
785 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[3]));
786 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
787 this->vmaskmovps(ye, ymask, this->ptr[src_ + 8]);
788 this->vfmadd231ps(ysum, ye, ye);
789
790 this->vmovups(ydst, ysum);
791 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
792
793 this->vmovaps(ybase, ydst);
794 if (pk_ != prop_kind::forward_inference)
795 this->vmovups(this->ptr[scratch_], ybase);
796 this->vmulps(ydst, ydst, ydst);
797 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
798 this->vsqrtps(ydst, ydst);
799 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
800 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
801
802 this->vmovups(this->ptr[dst_], ydst);
803
804 this->postamble();
805 }
806
// SSE4.1 specialization of the NHWC across-channel constructor. It has the
// same member initialisation as the generic version; a specialization is
// needed because the sse41/f32 class is specialized in this file.
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const struct nhwc_across_t &J, float A,
                float K, prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nhwc_across)
    , nhwc_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
817
818 template <>
generate(const nhwc_across_t & J)819 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
820 const nhwc_across_t &J) {
821 static uint32_t store[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
822 const Xbyak::Reg64 c = this->r9;
823
824 const Xbyak::Xmm &xdst_lo = this->xmm0;
825 const Xbyak::Xmm &xdst_hi = this->xmm1;
826 const Xbyak::Xmm &xa_lo = this->xmm2;
827 const Xbyak::Xmm &xa_hi = this->xmm3;
828 const Xbyak::Xmm &xb_lo = this->xmm2;
829 const Xbyak::Xmm &xb_hi = this->xmm3;
830 const Xbyak::Xmm &xc_lo = this->xmm4;
831 const Xbyak::Xmm &xc_hi = this->xmm5;
832 const Xbyak::Xmm &xd_lo = this->xmm6;
833 const Xbyak::Xmm &xd_hi = this->xmm7;
834 const Xbyak::Xmm &xe_lo = this->xmm8;
835 const Xbyak::Xmm &xe_hi = this->xmm9;
836 const Xbyak::Xmm &xsum_lo = this->xmm10;
837 const Xbyak::Xmm &xsum_hi = this->xmm11;
838 // unused: xmm12, xmm13;
839 const Xbyak::Xmm &xbase_lo = this->xmm14;
840 const Xbyak::Xmm &xbase_hi = this->xmm15;
841
842 this->preamble();
843
844 this->mov(src_, this->ptr[this->param1 + 0]);
845 this->mov(dst_, this->ptr[this->param1 + 8]);
846 if (pk_ != prop_kind::forward_inference)
847 mov(scratch_, this->ptr[this->param1 + 16]);
848 this->mov(this->imm_addr64_, float2int(this->alpha_));
849 this->movq(xalpha_, this->imm_addr64_);
850 this->shufps(xalpha_, xalpha_, 0);
851
852 this->mov(this->imm_addr64_, float2int(this->k_));
853 this->movq(xk_, this->imm_addr64_);
854 this->shufps(xk_, xk_, 0);
855
856 this->mov(store_addr_, reinterpret_cast<size_t>(&store[0]));
857 this->and_(store_addr_, -15);
858 this->movups(this->ptr[store_addr_], xalpha_);
859 this->movups(this->ptr[store_addr_ + 4 * sizeof(float)], xk_);
860
861 this->xorps(xsum_lo, xsum_lo);
862 this->xorps(xsum_hi, xsum_hi);
863
864 /* load the 2 first blocks of channels
865 * block: | -- low -- | -- hi -- |
866 * C: [c1,c2,c3,c4,c5,c6,c7,c8]
867 * xa_lo << 2 [0,0,c1,c2]
868 * xa_hi [c3,c4,c5,c6]
869 * xb_lo << 1 [0,c1,c2,c3]
870 * xb_hi [c4,c5,c6,c7]
871 * | -- data -- (...)
872 * ^ memory boundary
873 */
874 this->movups(xa_lo, this->ptr[src_]);
875 this->movups(xa_hi, this->ptr[src_ + 2 * sizeof(float)]);
876 this->pslldq(xa_lo, 2 * sizeof(float));
877 this->mulps(xa_lo, xa_lo);
878 this->mulps(xa_hi, xa_hi);
879 this->addps(xsum_lo, xa_lo);
880 this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
881
882 this->movups(xb_lo, this->ptr[src_]);
883 this->movups(xb_hi, this->ptr[src_ + 3 * sizeof(float)]);
884 this->pslldq(xb_lo, 1 * sizeof(float));
885 this->mulps(xb_lo, xb_lo);
886 this->mulps(xb_hi, xb_hi);
887 this->addps(xsum_lo, xb_lo);
888 this->addps(xsum_hi, xb_hi);
889
890 this->mov(c, J.C / 8 - 1);
891 Label lrn_loop;
892 this->L(lrn_loop);
893
894 this->movups(xc_lo, this->ptr[src_]);
895 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
896 this->movups(xd_lo, this->ptr[src_ + 4]);
897 this->movups(xd_hi, this->ptr[src_ + 4 + 4 * sizeof(float)]);
898 this->movups(xe_lo, this->ptr[src_ + 8]);
899 this->movups(xe_hi, this->ptr[src_ + 8 + 4 * sizeof(float)]);
900 this->mulps(xc_lo, xc_lo);
901 this->mulps(xc_hi, xc_hi);
902 this->addps(xsum_lo, xc_lo);
903 this->addps(xsum_hi, xc_hi);
904 this->mulps(xd_lo, xd_lo);
905 this->mulps(xd_hi, xd_hi);
906 this->addps(xsum_lo, xd_lo);
907 this->addps(xsum_hi, xd_hi);
908 this->mulps(xe_lo, xe_lo);
909 this->mulps(xe_hi, xe_hi);
910 this->addps(xsum_lo, xe_lo);
911 this->addps(xsum_hi, xe_hi);
912
913 this->movaps(xdst_lo, xsum_lo);
914 this->movaps(xdst_hi, xsum_hi);
915 // xdst <- xsum*xalpha_+xk_
916 this->mulps(xdst_lo, this->ptr[store_addr_]);
917 this->mulps(xdst_hi, this->ptr[store_addr_]);
918 this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
919 this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
920
921 this->movaps(xbase_lo, xdst_lo);
922 this->movaps(xbase_hi, xdst_hi);
923 if (pk_ != prop_kind::forward_inference) {
924 this->movups(this->ptr[scratch_], xbase_lo);
925 this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
926 }
927 this->mulps(xdst_lo, xdst_lo);
928 this->mulps(xdst_hi, xdst_hi);
929 this->mulps(xdst_lo, xbase_lo);
930 this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
931 this->sqrtps(xdst_lo, xdst_lo);
932 this->sqrtps(xdst_hi, xdst_hi);
933 this->sqrtps(xdst_lo, xdst_lo);
934 this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
935
936 this->movups(xc_lo, this->ptr[src_]);
937 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
938 this->divps(xc_lo, xdst_lo);
939 this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
940 this->movups(this->ptr[dst_], xc_lo);
941 this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
942
943 this->xorps(xsum_lo, xsum_lo);
944 this->xorps(xsum_hi, xsum_hi);
945
946 this->add(src_, 32);
947 this->add(dst_, 32);
948 if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
949
950 this->movups(xa_lo, this->ptr[src_ - 8]);
951 this->movups(xa_hi, this->ptr[src_ - 8 + 4 * sizeof(float)]);
952 this->mulps(xa_lo, xa_lo);
953 this->mulps(xa_hi, xa_hi);
954 this->addps(xsum_lo, xa_lo);
955 this->addps(xsum_hi, xa_hi);
956 this->movups(xb_lo, this->ptr[src_ - 4]);
957 this->movups(xb_hi, this->ptr[src_ - 4 + 4 * sizeof(float)]);
958 this->mulps(xb_lo, xb_lo);
959 this->mulps(xb_hi, xb_hi);
960 this->addps(xsum_lo, xb_lo);
961 this->addps(xsum_hi, xb_hi);
962
963 this->dec(c);
964 this->cmp(c, 0);
965 this->jne(lrn_loop, this->T_NEAR);
966
967 /* compute last 3 blocks of channels:
968 * block: | -- low -- | -- hi -- |
969 * C: [c1,c2,c3,c4,c5,c6,c7,c8]
970 * xc_lo|xc_hi [c1,c2,c3,c4|c5,c6,c7,c8]
971 * xd_lo [c2,c3,c4,c5]
972 * xd_hi >> 1 [c6,c7,c8, 0]
973 * xe_lo [c3,c4,c5,c6]
974 * xe_hi >> 2 [c7,c8, 0, 0]
975 * (...) -- data -- | -- illegal reading -- (...)
976 * ^ memory boundary
977 */
978 this->movups(xc_lo, this->ptr[src_]);
979 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
980 this->mulps(xc_lo, xc_lo);
981 this->mulps(xc_hi, xc_hi);
982 this->addps(xsum_lo, xc_lo);
983 this->addps(xsum_hi, xc_hi);
984
985 this->movups(xd_lo, this->ptr[src_ + 1 * sizeof(float)]);
986 this->movups(xd_hi, this->ptr[src_ + 4 * sizeof(float)]);
987 this->psrldq(xd_hi, 1 * sizeof(float));
988 this->mulps(xd_lo, xd_lo);
989 this->mulps(xd_hi, xd_hi);
990 this->addps(xsum_lo, xd_lo);
991 this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
992
993 this->movups(xe_lo, this->ptr[src_ + 2 * sizeof(float)]);
994 this->movups(xe_hi, this->ptr[src_ + 4 * sizeof(float)]);
995 this->psrldq(xe_hi, 2 * sizeof(float));
996 this->mulps(xe_lo, xe_lo);
997 this->mulps(xe_hi, xe_hi);
998 this->addps(xsum_lo, xe_lo);
999 this->addps(xsum_hi, xe_hi);
1000
1001 this->movups(xdst_lo, xsum_lo);
1002 this->movups(xdst_hi, xsum_hi);
1003 // xdst <- xsum*xalpha_+xk_
1004 this->mulps(xdst_lo, this->ptr[store_addr_]);
1005 this->mulps(xdst_hi, this->ptr[store_addr_]);
1006 this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
1007 this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
1008
1009 this->movaps(xbase_lo, xdst_lo);
1010 this->movaps(xbase_hi, xdst_hi);
1011 if (pk_ != prop_kind::forward_inference) {
1012 this->movups(this->ptr[scratch_], xbase_lo);
1013 this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
1014 }
1015 this->mulps(xdst_lo, xdst_lo);
1016 this->mulps(xdst_hi, xdst_hi);
1017 this->mulps(xdst_lo, xbase_lo);
1018 this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
1019 this->sqrtps(xdst_lo, xdst_lo);
1020 this->sqrtps(xdst_hi, xdst_hi);
1021 this->sqrtps(xdst_lo, xdst_lo);
1022 this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
1023 this->movups(xc_lo, this->ptr[src_]);
1024 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
1025 this->divps(xc_lo, xdst_lo);
1026 this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
1027
1028 this->movups(this->ptr[dst_], xc_lo);
1029 this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
1030
1031 this->postamble();
1032 }
1033
1034 template <>
nchw_body(int tail,int HW,prop_kind_t pk,Xbyak::Ymm ymask,Xbyak::Ymm ya,Xbyak::Ymm yb,Xbyak::Ymm yc,Xbyak::Ymm yd,Xbyak::Ymm ye,Xbyak::Ymm ysum)1035 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::nchw_body(
1036 int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
1037 Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
1038 Xbyak::Ymm ysum) {}
1039
1040 template <cpu_isa_t isa, data_type_t d_type>
nchw_body(int tail,int HW,prop_kind_t pk,Xbyak::Ymm ymask,Xbyak::Ymm ya,Xbyak::Ymm yb,Xbyak::Ymm yc,Xbyak::Ymm yd,Xbyak::Ymm ye,Xbyak::Ymm ysum)1041 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body(int tail, int HW,
1042 prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya, Xbyak::Ymm yb,
1043 Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye, Xbyak::Ymm ysum) {
1044 const Xbyak::Ymm &ydst = this->ymm14;
1045 const Xbyak::Ymm &ybase = this->ymm15;
1046
1047 this->vfmadd231ps(ysum, ye, ye);
1048
1049 this->vmovups(ydst, ysum);
1050 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
1051
1052 this->vmovaps(ybase, ydst);
1053 if (pk_ != prop_kind::forward_inference) {
1054 if (tail != 0)
1055 this->vmaskmovps(this->ptr[scratch_], ymask, ybase);
1056 else
1057 this->vmovups(this->ptr[scratch_], ybase);
1058 }
1059 this->vmulps(ydst, ydst, ydst);
1060 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
1061 this->vsqrtps(ydst, ydst);
1062 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
1063 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
1064
1065 if (tail != 0)
1066 this->vmaskmovps(this->ptr[dst_], ymask, ydst);
1067 else
1068 this->vmovups(this->ptr[dst_], ydst);
1069
1070 this->vfnmadd231ps(ysum, ya, ya);
1071 this->vmovups(ya, yb);
1072 this->vmovups(yb, yc);
1073 this->vmovups(yc, yd);
1074 this->vmovups(yd, ye);
1075 }
1076
1077 template <cpu_isa_t isa, data_type_t d_type>
nchw_tail_sse41(int tail,Xbyak::Reg64 reg_dst,Xbyak::Xmm xtail_lo,Xbyak::Xmm xtail_hi)1078 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_tail_sse41(int tail,
1079 Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {}
1080
1081 template <>
1082 void jit_uni_lrn_fwd_kernel_t<sse41,
nchw_tail_sse41(int tail,Xbyak::Reg64 reg_dst,Xbyak::Xmm xtail_lo,Xbyak::Xmm xtail_hi)1083 dnnl::impl::data_type::f32>::nchw_tail_sse41(int tail,
1084 Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {
1085 Xbyak::Xmm xmm_tmp = xmm10;
1086 this->movaps(xmm_tmp, xtail_hi);
1087
1088 if (tail > 3) {
1089 /* Store upper-half directly */
1090 this->movups(this->ptr[reg_dst + (tail - 4) * sizeof(float)], xtail_hi);
1091 this->movaps(xmm_tmp, xtail_lo);
1092 tail -= 4;
1093 }
1094 if (tail > 0) {
1095 /* Store on a single-element basis when 'tail' overlaps
1096 * with 'src_' */
1097 this->psrldq(xmm_tmp, (4 - tail) * sizeof(float));
1098 this->movss(this->ptr[reg_dst], xmm_tmp);
1099
1100 for (int i = 1; i < tail; i++) {
1101 this->psrldq(xmm_tmp, sizeof(float));
1102 this->movss(this->ptr[reg_dst + i * sizeof(float)], xmm_tmp);
1103 }
1104 }
1105 }
1106
1107 template <>
1108 void jit_uni_lrn_fwd_kernel_t<sse41,
nchw_body_sse41(int tail,int HW,prop_kind_t pk,Xbyak::Xmm xe_lo,Xbyak::Xmm xe_hi,Xbyak::Xmm xsum_lo,Xbyak::Xmm xsum_hi)1109 dnnl::impl::data_type::f32>::nchw_body_sse41(int tail, int HW,
1110 prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
1111 Xbyak::Xmm xsum_hi) {
1112 const Xbyak::Xmm &xdst_lo = this->xmm0;
1113 const Xbyak::Xmm &xdst_hi = this->xmm1;
1114 const Xbyak::Xmm &xbase_lo = this->xmm6;
1115 const Xbyak::Xmm &xbase_hi = this->xmm7;
1116 const Xbyak::Xmm &xtmp_lo = this->xmm8;
1117 const Xbyak::Xmm &xtmp_hi = this->xmm9;
1118 const Xbyak::Xmm &xa_lo = this->xmm6;
1119 const Xbyak::Xmm &xa_hi = this->xmm7;
1120 const Xbyak::Xmm &xb_lo = this->xmm8;
1121 const Xbyak::Xmm &xb_hi = this->xmm9;
1122 const Xbyak::Xmm &xc_lo = this->xmm10;
1123 const Xbyak::Xmm &xc_hi = this->xmm11;
1124 const Xbyak::Xmm &xd_lo = this->xmm12;
1125 const Xbyak::Xmm &xd_hi = this->xmm13;
1126
1127 // store xe
1128 this->movaps(this->ptr[store_addr_ + 10 * 4 * sizeof(float)], xe_lo);
1129 this->movaps(this->ptr[store_addr_ + 11 * 4 * sizeof(float)], xe_hi);
1130
1131 this->mulps(xe_lo, xe_lo);
1132 this->mulps(xe_hi, xe_hi);
1133 this->addps(xsum_lo, xe_lo);
1134 this->addps(xsum_hi, xe_hi);
1135
1136 // xdst <- xsum*xalpha_+xk_
1137 this->movaps(xdst_lo, xsum_lo);
1138 this->movaps(xdst_hi, xsum_hi);
1139 this->mulps(xdst_lo, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
1140 this->mulps(xdst_hi, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
1141 this->addps(xdst_lo, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);
1142 this->addps(xdst_hi, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);
1143
1144 this->movaps(xbase_lo, xdst_lo);
1145 this->movaps(xbase_hi, xdst_hi);
1146 if (pk_ != prop_kind::forward_inference) {
1147 if (tail != 0) {
1148 nchw_tail_sse41(tail, scratch_, xbase_lo, xbase_hi);
1149 } else {
1150 this->movups(this->ptr[scratch_], xbase_lo);
1151 this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
1152 }
1153 }
1154 this->mulps(xdst_lo, xdst_lo);
1155 this->mulps(xdst_hi, xdst_hi);
1156 this->mulps(xdst_lo, xbase_lo);
1157 this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
1158 this->sqrtps(xdst_lo, xdst_lo);
1159 this->sqrtps(xdst_hi, xdst_hi);
1160 this->sqrtps(xdst_lo, xdst_lo);
1161 this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
1162 this->movaps(xtmp_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
1163 this->movaps(xtmp_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
1164 this->divps(xtmp_lo, xdst_lo);
1165 this->divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
1166 this->movaps(xdst_lo, xtmp_lo);
1167 this->movaps(xdst_hi, xtmp_hi);
1168
1169 if (tail != 0) {
1170 nchw_tail_sse41(tail, dst_, xdst_lo, xdst_hi);
1171 } else {
1172 this->movups(this->ptr[dst_], xdst_lo);
1173 this->movups(this->ptr[dst_ + 4 * sizeof(float)], xdst_hi);
1174 }
1175
1176 this->movaps(xa_lo, this->ptr[store_addr_ + 2 * 4 * sizeof(float)]);
1177 this->movaps(xa_hi, this->ptr[store_addr_ + 3 * 4 * sizeof(float)]);
1178 this->mulps(xa_lo, xa_lo);
1179 this->mulps(xa_hi, xa_hi);
1180 this->subps(xsum_lo, xa_lo);
1181 this->subps(xsum_hi, xa_hi);
1182
1183 // xa <- xb
1184 this->movaps(xb_lo, this->ptr[store_addr_ + 4 * 4 * sizeof(float)]);
1185 this->movaps(xb_hi, this->ptr[store_addr_ + 5 * 4 * sizeof(float)]);
1186 this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xb_lo);
1187 this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xb_hi);
1188
1189 // xb <- xc
1190 this->movaps(xc_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
1191 this->movaps(xc_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
1192 this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xc_lo);
1193 this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xc_hi);
1194
1195 // xc <- xd
1196 this->movaps(xd_lo, this->ptr[store_addr_ + 8 * 4 * sizeof(float)]);
1197 this->movaps(xd_hi, this->ptr[store_addr_ + 9 * 4 * sizeof(float)]);
1198 this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xd_lo);
1199 this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xd_hi);
1200
1201 // xd <- xe
1202 this->movaps(xe_lo, this->ptr[store_addr_ + 10 * 4 * sizeof(float)]);
1203 this->movaps(xe_hi, this->ptr[store_addr_ + 11 * 4 * sizeof(float)]);
1204 this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xe_lo);
1205 this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xe_hi);
1206 }
1207
1208 template <cpu_isa_t isa, data_type_t d_type>
nchw_body_sse41(int tail,int HW,prop_kind_t pk,Xbyak::Xmm xe_lo,Xbyak::Xmm xe_hi,Xbyak::Xmm xsum_lo,Xbyak::Xmm xsum_hi)1209 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body_sse41(int tail, int HW,
1210 prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
1211 Xbyak::Xmm xsum_hi) {}
1212
// Constructor for the nchw across-channel forward kernel (generic ISA). It
// records the configuration (J), alpha (A), k (K) and the propagation kind;
// generate(const nchw_across_t &) reads alpha_, k_ and pk_ when emitting code.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const nchw_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw_across)
    , nchw_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
1223
1224 template <cpu_isa_t isa, data_type_t d_type>
generate(const nchw_across_t & J)1225 void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw_across_t &J) {
1226 static const uint32_t mask[]
1227 = {0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
1228 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0};
1229 const Xbyak::Reg64 &c = this->r10;
1230 const Xbyak::Ymm &ymask = this->ymm2;
1231 const Xbyak::Ymm &ye = this->ymm3;
1232 const Xbyak::Ymm &ya = this->ymm4;
1233 const Xbyak::Ymm &yb = this->ymm5;
1234 const Xbyak::Ymm &yc = this->ymm6;
1235 const Xbyak::Ymm &yd = this->ymm7;
1236 const Xbyak::Ymm &ysum = this->ymm8;
1237
1238 this->preamble();
1239
1240 if (J.tail != 0) {
1241 this->mov(
1242 this->imm_addr64_, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1243 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
1244 }
1245 this->mov(this->imm_addr64_, float2int(this->alpha_));
1246 this->vmovq(xalpha_, this->imm_addr64_);
1247 this->vbroadcastss(valpha_, xalpha_);
1248
1249 this->mov(this->imm_addr64_, float2int(this->k_));
1250 this->vmovq(xk_, this->imm_addr64_);
1251 this->vbroadcastss(yk_, xk_);
1252
1253 this->mov(src_, this->ptr[this->param1 + 0]);
1254 this->mov(dst_, this->ptr[this->param1 + 8]);
1255 if (pk_ != prop_kind::forward_inference)
1256 this->mov(scratch_, this->ptr[this->param1 + 16]);
1257
1258 this->vxorps(ya, ya, ya);
1259 this->vxorps(yb, yb, yb);
1260 if (J.tail != 0)
1261 this->vmaskmovps(yc, ymask, this->ptr[src_ + J.HW * 0]);
1262 else
1263 this->vmovups(yc, this->ptr[src_ + J.HW * 0]);
1264 if (J.tail != 0)
1265 this->vmaskmovps(yd, ymask, this->ptr[src_ + J.HW * 4]);
1266 else
1267 this->vmovups(yd, this->ptr[src_ + J.HW * 4]);
1268
1269 this->vxorps(ysum, ysum, ysum);
1270 this->vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
1271 this->vfmadd231ps(ysum, yd, yd);
1272
1273 this->mov(c, J.C - 2);
1274 Label lrn_loop;
1275 this->L(lrn_loop);
1276
1277 if (J.tail != 0)
1278 this->vmaskmovps(ye, ymask, this->ptr[src_ + J.HW * 8]);
1279 else
1280 this->vmovups(ye, this->ptr[src_ + J.HW * 8]);
1281
1282 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1283
1284 this->add(src_, J.HW * 4);
1285 this->add(dst_, J.HW * 4);
1286 if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1287 this->dec(c);
1288 this->cmp(c, 0);
1289 this->jne(lrn_loop, this->T_NEAR);
1290
1291 this->vxorps(ye, ye, ye);
1292
1293 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1294 this->add(src_, J.HW * 4);
1295 this->add(dst_, J.HW * 4);
1296 if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1297
1298 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1299
1300 this->postamble();
1301 }
1302
// Out-of-line defaulted destructor: destruction is left entirely to the
// members and the Base class.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::~jit_uni_lrn_fwd_kernel_t() = default;
1305
// SSE4.1 f32 specialization of the nchw across-channel forward-kernel
// constructor. It records the configuration (J), alpha (A), k (K) and the
// propagation kind; the sse41 generate(const nchw_across_t &) reads alpha_,
// k_ and pk_.
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
                prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size)
    , config_(lrn_config_t::nchw_across)
    , nchw_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
1316
1317 template <>
generate(const nchw_across_t & J)1318 void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
1319 const nchw_across_t &J) {
1320
1321 /* Load from within the memory boundary of 'src_' and apply a zero-mask to
1322 * the 'x_hi' register:
1323 * block: src_ |tail = 3
1324 * src_: [x,x,x,x|a,b,c]
1325 * x_hi: [x,a,b,c]
1326 * mask: [0,1,1,1]
1327 * (...) -- data -- | -- illegal reading -- (...)
1328 * ^ memory boundary
1329 *
1330 * 'x_lo' is loaded with the elements between 'src_' and 'x_hi' when
1331 * tail.size is between [5:7]. The register is then left-shifted to
1332 * clear the overlapping elements with 'x_hi'.
1333 * block: - src_ - | tail = 7
1334 * src_: (...) [x,|a,b,c,d,e,f,g]
1335 * x_hi [d,e,f,g]
1336 * x_lo [a,b,c,d]
1337 * x_lo >> 1: [0,a,b,c]
1338 * (...) -- data -- | -- illegal reading -- (...)
1339 * ^ memory boundary
1340 *
1341 * - seg-fault happens if read occurs anywhere outside the
1342 * memory boundary.
1343 * */
1344 static const uint32_t mask[]
1345 = {0, 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
1346 assert(J.HW > 3);
1347
1348 const Xbyak::Reg64 &c = r10;
1349
1350 // unused: xmm2
1351 const Xbyak::Xmm &xmask_hi = this->xmm3;
1352 const Xbyak::Xmm &xsum_lo = this->xmm4;
1353 const Xbyak::Xmm &xsum_hi = this->xmm5;
1354 const Xbyak::Xmm &xa_lo = this->xmm6;
1355 const Xbyak::Xmm &xa_hi = this->xmm7;
1356 const Xbyak::Xmm &xb_lo = this->xmm8;
1357 const Xbyak::Xmm &xb_hi = this->xmm9;
1358 const Xbyak::Xmm &xc_lo = this->xmm10;
1359 const Xbyak::Xmm &xc_hi = this->xmm11;
1360 const Xbyak::Xmm &xd_lo = this->xmm12;
1361 const Xbyak::Xmm &xd_hi = this->xmm13;
1362 const Xbyak::Xmm &xe_lo = this->xmm14;
1363 const Xbyak::Xmm &xe_hi = this->xmm15;
1364
1365 const int vlen = cpu_isa_traits<sse41>::vlen / sizeof(float);
1366
1367 bool compute_tail = J.tail != 0;
1368 bool load_lo = J.tail == 0 || J.tail > 4;
1369
1370 size_t h_offset = vlen;
1371 size_t l_shift = 0;
1372
1373 this->preamble();
1374
1375 this->mov(src_, this->ptr[this->param1 + 0]);
1376 this->mov(dst_, this->ptr[this->param1 + 8]);
1377 if (pk_ != prop_kind::forward_inference)
1378 this->mov(scratch_, this->ptr[this->param1 + 16]);
1379
1380 this->sub(rsp, stack_space_needed_);
1381 this->mov(store_addr_, rsp);
1382 this->and_(store_addr_, -15);
1383
1384 this->mov(this->imm_addr64_, float2int(this->alpha_));
1385 this->movq(xalpha_, this->imm_addr64_);
1386 this->shufps(xalpha_, xalpha_, 0);
1387
1388 this->mov(this->imm_addr64_, float2int(this->k_));
1389 this->movq(xk_, this->imm_addr64_);
1390 this->shufps(xk_, xk_, 0);
1391
1392 // put alpha_ and k_ into store (free up regs)
1393 this->movaps(this->ptr[store_addr_ + 0 * 4 * sizeof(float)], xalpha_);
1394 this->movaps(this->ptr[store_addr_ + 1 * 4 * sizeof(float)], xk_);
1395
1396 if (compute_tail) {
1397 assert(J.tail > 0 && J.tail < 2 * vlen);
1398 h_offset = J.tail - vlen;
1399 l_shift = nstl::min(2 * vlen - J.tail, vlen);
1400
1401 /* if 'tail' is between [1:3], need to zero-mask for underflow */
1402 size_t m_off = nstl::min(J.tail - 1, 3);
1403 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[m_off]));
1404 this->movups(xmask_hi, this->ptr[this->imm_addr64_]);
1405 }
1406 // init xa, xb
1407 this->xorps(xa_lo, xa_lo);
1408 this->xorps(xa_hi, xa_hi);
1409 this->xorps(xb_lo, xb_lo);
1410 this->xorps(xb_hi, xb_hi);
1411
1412 // read xc, xd
1413 if (load_lo) this->movups(xc_lo, this->ptr[src_ + J.HW * 0]);
1414 this->movups(xc_hi, this->ptr[src_ + J.HW * 0 + h_offset * sizeof(float)]);
1415 if (compute_tail) {
1416 this->pslldq(xc_lo, l_shift * sizeof(float));
1417 this->andps(xc_hi, xmask_hi);
1418 }
1419
1420 if (load_lo) this->movups(xd_lo, this->ptr[src_ + J.HW * 4]);
1421 this->movups(xd_hi, this->ptr[src_ + J.HW * 4 + h_offset * sizeof(float)]);
1422 if (compute_tail) {
1423 this->pslldq(xd_lo, l_shift * sizeof(float));
1424 this->andps(xd_hi, xmask_hi);
1425 }
1426
1427 // put xa, xb, xc, xd into store to free-up regs
1428 this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xa_lo);
1429 this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xa_hi);
1430 this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xb_lo);
1431 this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xb_hi);
1432 this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xc_lo);
1433 this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xc_hi);
1434 this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xd_lo);
1435 this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xd_hi);
1436
1437 this->xorps(xsum_lo, xsum_lo);
1438 this->xorps(xsum_hi, xsum_hi);
1439 this->mulps(xc_lo, xc_lo);
1440 this->mulps(xc_hi, xc_hi);
1441 this->addps(xsum_lo, xc_lo);
1442 this->addps(xsum_hi, xc_hi);
1443 this->mulps(xd_lo, xd_lo);
1444 this->mulps(xd_hi, xd_hi);
1445 this->addps(xsum_lo, xd_lo);
1446 this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
1447
1448 this->mov(c, J.C - 2);
1449 Label lrn_loop;
1450 this->L(lrn_loop);
1451
1452 if (load_lo) this->movups(xe_lo, this->ptr[src_ + J.HW * 8]);
1453 this->movups(xe_hi, this->ptr[src_ + J.HW * 8 + h_offset * sizeof(float)]);
1454 if (compute_tail) {
1455 this->pslldq(xe_lo, l_shift * sizeof(float));
1456 this->andps(xe_hi, xmask_hi);
1457 }
1458
1459 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1460
1461 this->add(src_, J.HW * 4);
1462 this->add(dst_, J.HW * 4);
1463 if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1464 this->dec(c);
1465 this->cmp(c, 0);
1466 this->jne(lrn_loop, this->T_NEAR);
1467
1468 this->xorps(xe_lo, xe_lo);
1469 this->xorps(xe_hi, xe_hi);
1470
1471 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1472 this->add(src_, J.HW * 4);
1473 this->add(dst_, J.HW * 4);
1474 if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1475
1476 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1477
1478 this->add(rsp, stack_space_needed_);
1479
1480 this->postamble();
1481 }
1482
1483 //////////////////////////////////////////////////////////////////////////////
1484 // backward kernel
1485 template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t & J,float A,float B,int use_h_parallel,void * code_ptr,size_t code_size)1486 jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
1487 const nchw8c_across_t &J, float A, float B, int use_h_parallel,
1488 void *code_ptr, size_t code_size)
1489 : Base(code_ptr, code_size)
1490 , config_(lrn_config_t::nchw8c_across)
1491 , nchw8c_across_(J)
1492 , nalphabeta_(-2 * A * B)
1493 , use_h_parallelizm_(use_h_parallel) {}
1494
// Emits the backward across-channel LRN kernel for the nchw8c_across
// configuration. Each loop iteration handles one pixel of an 8-channel block
// (32 bytes) and computes
//     diff_src = diff_dst / ws^0.75
//              + nalphabeta_ * src * sum_{c-2..c+2} (diff_dst * src / ws^1.75)
// where ws is read from scratch_ and ws^0.75 = sqrt(sqrt(ws^3)). The window
// needs the two nearest channels of the neighbouring 8-channel blocks, which
// are J.H * J.W * 32 bytes away. These are staged next to the current block
// in a 64-byte stack buffer.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {

    // Stack buffer layout (bytes from t):
    //   [0, 16)  : last 4 channels of the previous block (or zeros)
    //   [16, 48) : the 8 channels of the current block
    //   [48, 64) : first 4 channels of the next block (or zeros)
    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r10;
    const Xbyak::Xmm &xsrc_prev = this->xmm1;
    const Xbyak::Xmm &xws_prev = this->xmm2;
    const Xbyak::Xmm &xdiffdst_prev = this->xmm3;
    const Xbyak::Ymm &ysrc = this->ymm4;
    const Xbyak::Ymm &yws = this->ymm5;
    const Xbyak::Ymm &ydiffdst = this->ymm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xws_next = this->xmm8;
    const Xbyak::Xmm &xdiffdst_next = this->xmm9;
    const Xbyak::Ymm &ya = this->ymm10;
    const Xbyak::Xmm &xa = this->xmm10;
    const Xbyak::Ymm &yb = this->ymm11;
    const Xbyak::Ymm &yd = this->ymm12;
    const Xbyak::Ymm &ye = this->ymm13;
    const Xbyak::Ymm &ysum = this->ymm14;
    const Xbyak::Ymm &ydiffsrc = this->ymm15;

    this->preamble();

#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
    this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
    this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
    this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
    this->mov(bwd_intermediate_res_,
            this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
    this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
#undef GET_OFF

    // Reserve the 64-byte staging buffer described above.
    this->sub(t, 64);
    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(vnalphabeta_, xnalphabeta_);

    // J.version encodes which neighbouring blocks exist: 3 = neither,
    // -1 = no previous block, +1 = no next block, -2 = neither (as both
    // first and last). A missing neighbour contributes zeros, written once
    // here and never overwritten inside the loop.
    bool is_single = J.version == 3;
    bool is_first = J.version == -1 || J.version == -2;
    bool is_last = J.version == +1 || J.version == -2;

    if (is_first || is_single) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (is_last || is_single) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }
    // Pixels per call: one row (J.W) when parallelizing over H, else H*W.
    this->mov(hw, this->use_h_parallelizm_ ? J.W : J.H * J.W);
    Label lrn_loop;
    this->L(lrn_loop);
    {
        if (!is_first && !is_single) {
            // Upper 4 channels of the previous block:
            // diff_dst * src / ws^1.75.
            this->vmovups(xws_prev, this->ptr[scratch_ - J.H * J.W * 32 + 16]);
            this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
            this->vmovups(
                    xdiffdst_prev, this->ptr[diffdst_ - J.H * J.W * 32 + 16]);
            this->vmulps(xa, xws_prev, xws_prev);
            this->vmulps(xa, xa, xws_prev);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_prev);
            this->vdivps(xsrc_prev, xsrc_prev, xa);
            this->vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
        }

        // Current block: ydiffsrc = diff_dst / ws^0.75 (first term), and
        // ysum = diff_dst * src / ws^1.75 (this block's window contribution).
        this->vmovups(ysrc, this->ptr[src_]);
        this->vmovups(yws, this->ptr[scratch_]);
        this->vmovups(ydiffdst, this->ptr[diffdst_]);
        this->vmulps(ya, yws, yws);
        this->vmulps(ya, ya, yws);
        this->vsqrtps(ya, ya);
        this->vsqrtps(ya, ya);
        this->vdivps(ydiffsrc, ydiffdst, ya);
        this->vdivps(ysum, ydiffsrc, yws);
        this->vmulps(ysum, ysum, ysrc);

        if (!is_last && !is_single) {
            // Lower 4 channels of the next block, same formula as above.
            this->vmovups(xws_next, this->ptr[scratch_ + J.H * J.W * 32]);
            this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);
            this->vmovups(xdiffdst_next, this->ptr[diffdst_ + J.H * J.W * 32]);
            this->vmulps(xa, xws_next, xws_next);
            this->vmulps(xa, xa, xws_next);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_next);
            this->vdivps(xsrc_next, xsrc_next, xa);
            this->vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
        }

        // Stage prev | current | next contiguously on the stack.
        if (!is_first && !is_single)
            this->vmovups(this->ptr[t + 0], xdiffdst_prev);
        this->vmovups(this->ptr[t + 16], ysum);
        if (!is_last && !is_single)
            this->vmovups(this->ptr[t + 48], xdiffdst_next);

        // Window sum: unaligned reloads shifted by -2, -1, +1, +2 channels
        // (4 bytes per channel) added to the centre value already in ysum.
        this->vmovups(ya, this->ptr[t + 16 - 8]);
        this->vmovups(yb, this->ptr[t + 16 - 4]);
        this->vaddps(ysum, ysum, ya);
        this->vmulps(ysrc, ysrc, vnalphabeta_);
        this->vaddps(ysum, ysum, yb);

        this->vmovups(yd, this->ptr[t + 16 + 4]);
        this->vmovups(ye, this->ptr[t + 16 + 8]);
        this->vaddps(ysum, ysum, yd);
        this->vaddps(ysum, ysum, ye);

        // diff_src += window_sum * (src * nalphabeta_)
        this->vfmadd231ps(ydiffsrc, ysum, ysrc);

        this->vmovups(this->ptr[diffsrc_], ydiffsrc);

        this->add(src_, 32);
        this->add(diffsrc_, 32);
        this->add(diffdst_, 32);
        this->add(scratch_, 32);

        this->dec(hw);
        this->cmp(hw, 0);
        this->jne(lrn_loop, this->T_NEAR);
    }

    // Release the staging buffer.
    this->add(t, 64);
    this->postamble();
}
1621
// Constructor for the within-channel backward kernel. It forwards the
// configuration to Base and stores nalphabeta_ = -2 * alpha * beta, which
// generate(const within_config_t &) broadcasts via load_constant().
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
        const within_config_t &config, float A, float B, void *code_ptr,
        size_t code_size)
    : Base(config, code_ptr, code_size)
    , config_(lrn_config_t::within_config)
    , within_config_(config)
    , nalphabeta_(-2.0f * A * B) {}
1630
1631 template <cpu_isa_t isa, data_type_t d_type>
generate(const within_config_t & config)1632 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(
1633 const within_config_t &config) {
1634
1635 this->preamble();
1636
1637 #define GET_OFF(field) offsetof(jit_args_bwd_t, field)
1638 this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
1639 this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
1640 this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
1641 this->mov(bwd_intermediate_res_,
1642 this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
1643 this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
1644 #undef GET_OFF
1645 this->load_constant(nalphabeta_, vnalphabeta_, xnalphabeta_);
1646
1647 static const int max_reg_blocks = isa == avx512_common ? 3 : 1;
1648 this->within_loop(config, max_reg_blocks, prop_kind::backward);
1649
1650 this->postamble();
1651 }
1652
1653 template <cpu_isa_t isa, data_type_t d_type>
within_body(int hoff,int Hoff,int woff,int Woff,int stride,prop_kind_t pk,const int reg_block,int pixel_offset)1654 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
1655 int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
1656 int pixel_offset) {
1657
1658 static const std::array<Vmm, 3> vsum {{Vmm(1), Vmm(9), Vmm(18)}};
1659 static const std::array<std::array<Vmm, 3>, 3> diff_dst {{
1660 {{Vmm(2), Vmm(3), Vmm(6)}},
1661 {{Vmm(10), Vmm(11), Vmm(23)}},
1662 {{Vmm(19), Vmm(20), Vmm(26)}},
1663 }};
1664 static const std::array<std::array<Vmm, 3>, 3> ws1 {{
1665 {{Vmm(4), Vmm(5), Vmm(15)}},
1666 {{Vmm(12), Vmm(13), Vmm(27)}},
1667 {{Vmm(21), Vmm(22), Vmm(28)}},
1668 }};
1669 static const std::array<Vmm, 3> ws0 = !this->emulate_bfloat_
1670 ? std::array<Vmm, 3> {{Vmm(29), Vmm(30), Vmm(31)}}
1671 : std::array<Vmm, 3> {{Vmm(6), Vmm(15), Vmm(23)}};
1672 static const std::array<Vmm, 3> src {{Vmm(7), Vmm(16), Vmm(24)}};
1673 static const std::array<Vmm, 3> a {{Vmm(8), Vmm(17), Vmm(25)}};
1674
1675 static const std::size_t used_tmp_regs
1676 = this->emulate_bfloat_ ? ws1[0].size() - 1 : ws1[0].size();
1677
1678 IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
1679 for (int i = hoff; i <= Hoff; ++i) {
1680 for (int j = woff; j <= Woff; ++j) {
1681 const auto idx = this->tempIdx_ % used_tmp_regs;
1682 IRB_LOOP(this->load_data(diff_dst[irb][idx],
1683 this->ptr[(diffdst_ + pixel_offset + irb_off)
1684 + (i * stride + j) * this->single_pixel_offset_]));
1685 IRB_LOOP(this->load_data(ws1[irb][idx],
1686 this->ptr[(bwd_intermediate_res_ + pixel_offset + irb_off)
1687 + (i * stride + j) * this->single_pixel_offset_]));
1688
1689 if (i == 0 && j == 0) {
1690 if (d_type == dnnl::impl::data_type::bf16) {
1691 IRB_LOOP(this->load_data(ws0[irb],
1692 this->ptr[(scratch_ + pixel_offset + irb_off)]));
1693 IRB_LOOP(
1694 this->vdivps(a[irb], diff_dst[irb][idx], ws0[irb]));
1695 } else {
1696 IRB_LOOP(this->vdivps(a[irb], diff_dst[irb][idx],
1697 this->ptr[(scratch_ + pixel_offset + irb_off)]));
1698 }
1699 }
1700
1701 IRB_LOOP(this->vfmadd231ps(
1702 vsum[irb], ws1[irb][idx], diff_dst[irb][idx]));
1703 ++(this->tempIdx_);
1704 }
1705 }
1706
1707 this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
1708
1709 if (d_type == dnnl::impl::data_type::bf16) {
1710 IRB_LOOP(this->load_data(
1711 src[irb], this->ptr[(src_ + pixel_offset + irb_off)]));
1712 IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_, src[irb]));
1713 } else {
1714 IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_,
1715 this->ptr[(src_ + pixel_offset + irb_off)]));
1716 }
1717
1718 IRB_LOOP(this->vfmadd231ps(a[irb], src[irb], vsum[irb]));
1719
1720 IRB_LOOP(this->store_data(
1721 this->ptr[diffsrc_ + pixel_offset + irb_off], a[irb]));
1722
1723 if (isa == avx512_common)
1724 this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
1725 }
1726
1727 template <cpu_isa_t isa, data_type_t d_type>
move_data_pointers(int pixel_count,prop_kind_t pk)1728 void jit_uni_lrn_bwd_kernel_t<isa, d_type>::move_data_pointers(
1729 int pixel_count, prop_kind_t pk) {
1730 const int pixel_offset = this->single_pixel_offset_ * pixel_count;
1731 this->add(src_, pixel_offset);
1732 this->add(diffsrc_, pixel_offset);
1733 this->add(diffdst_, pixel_offset);
1734 this->add(scratch_, pixel_offset);
1735 this->add(bwd_intermediate_res_, pixel_offset);
1736 }
1737
// Explicit instantiations of every LRN JIT kernel variant built by this
// translation unit.

// Forward kernels: f32 for sse41, avx2 and avx512_common; bf16 for
// avx512_common only.
template class jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>;
template class jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>;
template class jit_uni_lrn_fwd_kernel_t<avx512_common,
        dnnl::impl::data_type::f32>;
template class jit_uni_lrn_fwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>;

// The shared base template jit_uni_lrn_kernel_t, which is parameterized by
// the concrete kernel type, instantiated once per forward kernel above.
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>>;
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_fwd_kernel_t<avx512_common, dnnl::impl::data_type::f32>>;
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_fwd_kernel_t<avx512_common, dnnl::impl::data_type::bf16>>;

// Backward kernels: f32/bf16 for avx512_common and f32 for avx2 (no sse41
// backward kernel is instantiated here).
template class jit_uni_lrn_bwd_kernel_t<avx512_common,
        dnnl::impl::data_type::f32>;
template class jit_uni_lrn_bwd_kernel_t<avx512_common,
        dnnl::impl::data_type::bf16>;
template class jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>;

// Base template instantiations matching each backward kernel above.
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_bwd_kernel_t<avx512_common, dnnl::impl::data_type::f32>>;
template class jit_uni_lrn_kernel_t<
        jit_uni_lrn_bwd_kernel_t<avx512_common, dnnl::impl::data_type::bf16>>;
1766
1767 } // namespace x64
1768 } // namespace cpu
1769 } // namespace impl
1770 } // namespace dnnl
1771
1772 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1773