/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_binary_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

#define PARAM_OFF(x) offsetof(jit_binary_call_s, x)

static bcast_set_t get_supported_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::per_oc_spatial};
}

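// Base kernel configuration. simd_w_ is expressed in float elements;
// is_src1_outer_dims_tail_ is set when src0 and src1 use different layouts
// and the outer dims are not a multiple of simd_w_, which forces iterating
// over the outer dims explicitly (see generate()).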
binary_kernel_t::binary_kernel_t(const size_t vlen, const binary_pd_t *pd,
        const jit_binary_conf_t conf, bool tail_kernel)
    : vlen_(vlen)
    , simd_w_(vlen / sizeof(float))
    , pd_(pd)
    , conf_(conf)
    , is_tail_kernel_(tail_kernel)
    , is_src1_outer_dims_tail_(
              conf_.is_src_different_layouts && conf_.outer_dims % simd_w_)
    , tail_size_(get_tail_size())
    , padding_tail_size_(
              pd->src_md(0)->padded_dims[1] - pd->src_md(0)->dims[1]) {}

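// Computes the number of leftover elements (nelems % simd_w_) the tail path
// has to handle. Which dimension contributes the tail depends on the layout
// (op_t) and the broadcast type selected for this kernel.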
size_t binary_kernel_t::get_tail_size() const {
    memory_desc_wrapper src0_d(pd_->src_md(0));
    const auto &dims = src0_d.dims();
    const auto &ndims = src0_d.ndims();

    dim_t nelems = 0;

    if (ndims == 1)
        nelems = dims[0];
    else if (is_src1_outer_dims_tail_)
        nelems = conf_.outer_dims;
    else if (!conf_.is_i8 && conf_.op_type == op_t::c_blocked
            && (is_tail_kernel_ || conf_.bcast_type == bcast_t::per_w))
        nelems = dims[1];
    else if (conf_.bcast_type == bcast_t::none
            && !conf_.postops_per_oc_broadcast_exists)
        nelems = src0_d.nelems(true);
    else {
        if (conf_.op_type == op_t::n_spatial_c)
            nelems = dims[1];
        else if (conf_.op_type == op_t::n_c_spatial && ndims >= 3)
            nelems = conf_.bcast_type == bcast_t::per_w
                    ? dims[ndims - 1]
                    : utils::array_product(dims + 2, ndims - 2);
    }
    // simd_w_ counts float elements: even for bfloat16 we still load 16
    // elements per vector, not 32.
    return nelems % simd_w_;
}

template <cpu_isa_t isa>
jit_uni_binary_kernel_t<isa>::jit_uni_binary_kernel_t(
        const binary_pd_t *pd, const jit_binary_conf_t conf, bool tail_kernel)
    : binary_kernel_t(cpu_isa_traits<isa>::vlen, pd, conf, tail_kernel)
    , offt_src0_(vlen_ / (conf_.is_bf16 ? 2 : 1))
    , offt_src1_(conf_.use_stride_src1 ? offt_src0_ : 0)
    , io_(this, isa, {conf_.src0_type, conf_.src1_type, conf_.dst_type},
              {false},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      vmm_tail_vmask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vreg_bf16_emu_1_, vreg_bf16_emu_2_,
                      vreg_bf16_emu_3_, reg_tmp_, vreg_bf16_emu_4_},
              create_saturation_vmm_map(),
              io::io_gather_conf_t {simd_w_, full_mask_,
                      vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_,
                      vmm_tmp_gather_.getIdx()}) {
    init();
}

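// For int8 destinations the io helper is given a zero register and a
// saturation upper-bound register so f32 intermediates can be clamped to the
// dst data type range before conversion; other dst types need no saturation.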
template <cpu_isa_t isa>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_binary_kernel_t<isa>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (conf_.is_i8)
        saturation_map.emplace(conf_.dst_type,
                io::io_saturation_conf_t {vreg_zero_.getIdx(),
                        vreg_saturation_ubound_.getIdx(), reg_tmp_});

    return saturation_map;
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::init() {
    if (conf_.with_postops) init_post_ops_injector();
}

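// Configures the post-ops injector with the eltwise and binary injector
// static parameters; it is handed the kernel's tail mask and tail size so
// post-ops applied on the tail lanes match the main computation.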
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::init_post_ops_injector() {
    const memory_desc_wrapper src0_d(pd_->src_md(0));
    const auto &po = pd_->attr()->post_ops_;

    const eltwise_injector::static_params_t esp(true /*save_state*/,
            reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/,
            false /*use_dst*/);
    const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_,
            reg_elt_inj_table_, true /*preserve gpr*/, true /*preserve vmm*/,
            PARAM_OFF(post_ops_binary_rhs_arg_vec), src0_d, tail_size_,
            tail_opmask_, false /*use_exact_tail_scalar_bcast*/};
    const binary_injector::static_params_t bsp(
            this->param1, get_supported_bcast_strategies(), rhs_arg_bsp);

    postops_injector_ = utils::make_unique<
            injector::jit_uni_postops_injector_t<inject_isa>>(
            this, po, bsp, esp);
}

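// Applies the attribute post-ops to the `unroll` accumulator vmms. Depending
// on the layout, the per-oc offset for binary post-op rhs operands is passed
// either as an address read from the call params (blocked and NCsp layouts)
// or as a register plus a per-vmm element offset (NspC layout).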
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::apply_postops(int unroll, bool tail) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    for (int vmm_idx = 1; vmm_idx < unroll + vmm_start_idx_; vmm_idx++) {
        if (utils::one_of(conf_.op_type, op_t::c_blocked, op_t::n_c_spatial)) {
            rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
                    vmm_idx, ptr[param1 + PARAM_OFF(oc_l_off)]);
        } else if (conf_.op_type == op_t::n_spatial_c) {
            rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
                    vmm_idx, reg_off_rhs_postops_);
            rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(vmm_idx,
                    (vmm_idx - vmm_start_idx_) * static_cast<int>(simd_w_));
        }
        if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
    }
    postops_injector_->compute_vector_range(
            1, unroll + vmm_start_idx_, rhs_arg_params);
}

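// Loads the jit_binary_call_s fields passed by the driver: the sum scale
// (broadcast into a vmm), the spatial offset counter, src/dst pointers,
// gather indices and stride range for the different-layouts path, and the
// optional src0/src1 scale pointers.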
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::load_kernel_params() {
    mov(reg_tmp_, float2int(conf_.sum_scale));
    uni_vmovq(xreg_sum_scale_, reg_tmp_);
    uni_vbroadcastss(vreg_sum_scale_, xreg_sum_scale_);
    if (is_src1_outer_dims_tail_)
        mov(reg_outer_dims_range_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    else
        mov(reg_reverse_spat_offt_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    mov(reg_src0_, ptr[reg_param_ + PARAM_OFF(src0)]);
    mov(reg_src1_, ptr[reg_param_ + PARAM_OFF(src1)]);
    mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]);
    if (conf_.is_src_different_layouts) {
        mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(indices)]);
        uni_vmovdqu(vmm_indices_, ptr[reg_tmp_]);

        mov(reg_src1_stride_range_,
                ptr[reg_param_ + PARAM_OFF(src1_stride_range)]);
        mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
    }
    if (conf_.do_scale_src0)
        mov(reg_scales_src0_, ptr[reg_param_ + PARAM_OFF(scales_src0)]);
    if (conf_.do_scale_src1)
        mov(reg_scales_src1_, ptr[reg_param_ + PARAM_OFF(scales_src1)]);
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::src0_ptr(size_t offt) {
    return vmmword[reg_src0_ + reg_offt_src0_ + offt];
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::src1_ptr(size_t offt) {
    return vmmword[reg_src1_ + reg_offt_src1_ + offt];
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::dst_ptr(size_t offt) {
    const Reg64 &reg_offt_dst = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_;
    return vmmword[reg_dst_ + reg_offt_dst + offt];
}

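// Maps a binary comparison alg_kind to the immediate predicate encoding
// expected by (v)cmpps.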
template <cpu_isa_t isa>
unsigned int jit_uni_binary_kernel_t<isa>::cmp_predicate(alg_kind_t alg) {
    using namespace alg_kind;
    switch (alg) {
        case binary_ge: return _cmp_nlt_us;
        case binary_gt: return _cmp_nle_us;
        case binary_le: return _cmp_le_os;
        case binary_lt: return _cmp_lt_os;
        case binary_eq: return _cmp_eq_oq;
        case binary_ne: return _cmp_neq_uq;
        default: assert(!"not supported operation!"); return -1;
    }
}

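// Computes v0 = op(v0, v1) in f32, applying the input scales first.
// Comparison ops produce 1.0f/0.0f: on AVX-512 via an opmask and a
// zero-masked move of the all-ones-float vector; otherwise the all-ones
// cmpps bit pattern (a NaN) is clamped to 1.0f with minps, which returns
// its second operand when the first is NaN.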
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::perform_op(
        const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1) {
    using namespace alg_kind;
    const auto alg = pd_->desc()->alg_kind;
    const bool cmp_op = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    if (conf_.do_scale_src0) uni_vmulps(v0, v0, s_src0);
    if (conf_.do_scale_src1
            && (conf_.is_i8
                    || (offt_src1_ != 0 && !conf_.broadcast_src1_value)))
        uni_vmulps(v1, v1, s_src1);

    if (alg == binary_add)
        uni_vaddps(v0, v0, v1);
    else if (alg == binary_mul)
        uni_vmulps(v0, v0, v1);
    else if (alg == binary_max)
        uni_vmaxps(v0, v0, v1);
    else if (alg == binary_min)
        uni_vminps(v0, v0, v1);
    else if (alg == binary_div)
        uni_vdivps(v0, v0, v1);
    else if (alg == binary_sub)
        uni_vsubps(v0, v0, v1);
    else if (cmp_op) {
        const unsigned int predicate = cmp_predicate(alg);
        if (is_avx512) {
            vcmpps(cmp_mask, v0, v1, predicate);
            vmovups(v0 | cmp_mask | T_z, vreg_one_);
        } else {
            uni_vcmpps(v0, v0, v1, predicate);
            uni_vminps(v0, v0, vreg_one_);
        }
    } else
        assert(!"not supported operation!");
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::prepare_isa_kernel() {
    if (conf_.is_bf16) io_.init_bf16();
    if (tail_size_ > 0) io_.prepare_tail_mask();
    if (conf_.is_src_different_layouts && is_superset(isa, avx2)) {
        io_.init_full_mask();
        io_.prepare_full_mask();
    }
}

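// When src1 is a single broadcast value, or one vector reused for every
// compute step, it is loaded into vreg_bcast_src1_ once up front rather than
// reloaded on every iteration.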
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::compute_bcast(bool tail) {
    if (conf_.broadcast_src1_value) {
        if (conf_.is_i8)
            uni_vpxor(xreg_bcast_src1_, xreg_bcast_src1_, xreg_bcast_src1_);
        io_.at(conf_.src1_type)->broadcast(src1_ptr(), vreg_bcast_src1_);
    } else if (!conf_.is_i8 && offt_src1_ == 0) {
        io_.at(conf_.src1_type)->load(src1_ptr(), vreg_bcast_src1_, tail);
    }
}

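// Loads one vector of src1. On the different-layouts path the elements are
// gathered using the stride pattern held in vmm_indices_; once the stride
// range is exhausted, the base pointer (kept on the stack) is advanced by
// one element and the range counter is reset.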
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::load_src1(
        const Vmm &vreg_src1, const int offt, bool tail) {
    if (conf_.is_src_different_layouts) {
        // layouts differ: gather data with strides; after reaching the end
        // of the stride range, the offset is restored and increased
        io_.at(conf_.src1_type)
                ->gather(reg_src1_, vmm_indices_, vreg_src1, tail);
        // gather reads its address from a register rather than an operand,
        // so advance reg_src1_ directly instead of keeping the offset in a
        // second register
        add(reg_src1_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);
        sub(reg_reverse_src1_stride_range_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);

        Label src1_stride_range_not_exceed, src1_C_tail_end;

        cmp(reg_reverse_src1_stride_range_, 0);
        jg(src1_stride_range_not_exceed, T_NEAR);
        {
            pop(reg_src1_);
            add(reg_src1_, types::data_type_size(conf_.src1_type));
            push(reg_src1_);
            mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
        }
        L(src1_stride_range_not_exceed);
    } else
        io_.at(conf_.src1_type)
                ->load(src1_ptr(offt * types::data_type_size(conf_.src1_type)),
                        vreg_src1, tail);
}

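// Main compute step: for each of the `unroll` vectors, load src0 (and src1
// unless it is a pre-broadcast value), apply the binary op plus the optional
// sum, run the post-ops, then store. The tail kernel additionally zero-pads
// the channel padding area of dst when src0 and dst are different buffers.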
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::compute_dst(int unroll, bool tail) {
    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const Vmm vreg_tmp = conf_.is_src_different_layouts
                ? vmm_gathered_src_
                : Vmm(unroll + i + vmm_start_idx_);
        const Vmm vreg_tmp_src1 = offt_src1_ ? vreg_tmp : vreg_bcast_src1_;
        const int offt = simd_w_ * i;
        io_.at(conf_.src0_type)
                ->load(src0_ptr(offt * types::data_type_size(conf_.src0_type)),
                        vreg_tmp_src0, tail);
        if (offt_src1_) load_src1(vreg_tmp_src1, offt, tail);

        // work on a copy of the broadcast vreg to avoid applying the input
        // scale to it multiple times; not needed for different layouts
        if (!conf_.is_src_different_layouts)
            uni_vmovups(vreg_tmp, vreg_tmp_src1);
        perform_op(
                vreg_tmp_src0, vreg_tmp, vreg_scales_src0_, vreg_scales_src1_);
        if (conf_.do_sum) {
            io_.at(conf_.dst_type)
                    ->load(dst_ptr(offt
                                   * types::data_type_size(conf_.dst_type)),
                            vreg_tmp, tail);
            uni_vfmadd231ps(vreg_tmp_src0, vreg_tmp, vreg_sum_scale_);
        }
    }

    if (postops_injector_) apply_postops(unroll, tail);

    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const int offt = simd_w_ * i;
        const auto dt_size = types::data_type_size(conf_.dst_type);

        if (is_tail_kernel_ && padding_tail_size_) {
            // apply zero-padding
            Label end;
            auto off_base = 0;
            auto zero_pad_left = padding_tail_size_;

            // in-place data is assumed to be zero-padded already
            cmp(reg_src0_, reg_dst_);
            je(end, T_NEAR);

            if (zero_pad_left >= simd_w_ - tail_size_) {
                vxorps(vreg_zero_, vreg_zero_, vreg_zero_);
                if (is_avx512)
                    uni_vmovups(vreg_zero_ | tail_opmask_, vreg_tmp_src0);
                else
                    uni_vblendvps(vreg_zero_, vreg_zero_, vreg_tmp_src0,
                            vmm_tail_vmask_);
                io_.at(conf_.dst_type)
                        ->store(vreg_zero_, dst_ptr(offt * dt_size), false);
                off_base = simd_w_ * dt_size;
                zero_pad_left -= simd_w_ - tail_size_;
            } else {
                io_.at(conf_.dst_type)
                        ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), true);
                off_base = tail_size_ * dt_size;
            }

            if (zero_pad_left) {
                push(abi_param1);
                const Reg32 &reg_zero = eax;
                const Reg64 &reg_ptr = rdi;
                const Reg64 &reg_counter = rcx;
                const auto off_start = off_base;
                const auto off_end = off_start + zero_pad_left * dt_size;
                xor_(reg_zero, reg_zero);
                lea(reg_ptr,
                        ptr[dst_ptr(offt * dt_size).getRegExp()
                                + RegExp(off_start)]);
                mov(reg_counter, off_end - off_start);
                rep();
                stosb();
                pop(abi_param1);
            }
            L(end);
        } else
            io_.at(conf_.dst_type)
                    ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), tail);
    }
}

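// Drives the computation over the spatial range in three stages: a loop
// unrolled by unroll_regs_ vectors, a single-vector loop for the remainder,
// and a final masked step for the last partial vector.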
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::forward() {
    Label unroll_loop, unroll_loop_tail, nelems_tail, end;

    const auto src0_type_size = types::data_type_size(conf_.src0_type);
    const auto src1_type_size = types::data_type_size(conf_.src1_type);
    const auto dst_type_size = types::data_type_size(conf_.dst_type);

    if (conf_.is_src_different_layouts) push(reg_src1_);

    // if there is an outer dims tail, handle it outside the outer dims loop
    if (!is_src1_outer_dims_tail_) {
        if (conf_.is_i8) {
            uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
            io_.init_saturate_f32({conf_.dst_type});
            xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
        }

        xor_(reg_offt_src0_,
                reg_offt_src0_); // offt_src0 to get addr of src0/dst
        if (!conf_.is_src_different_layouts)
            xor_(reg_offt_src1_,
                    reg_offt_src1_); // offt_src1 to get addr of src1
        if (conf_.use_stride_rhs_postops && !conf_.is_i8)
            xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);
    }
    const auto alg = pd_->desc()->alg_kind;

    if (utils::one_of(alg, alg_kind::binary_ge, alg_kind::binary_gt,
                alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq,
                alg_kind::binary_ne)) {
        Xmm xreg_one = Xmm(vreg_one_.getIdx());
        mov(reg_tmp_, float2int(1));
        uni_vmovq(xreg_one, reg_tmp_);
        uni_vbroadcastss(vreg_one_, xreg_one);
    }

    compute_bcast(false); // bcast/load the vreg only once per kernel call

    // used in the c_blocked strategy for the last block if a tail exists
    const bool treat_each_compute_step_as_tail
            = !conf_.is_i8 && is_tail_kernel_ && tail_size_;

    if (conf_.do_scale_src0)
        uni_vbroadcastss(vreg_scales_src0_, dword[reg_scales_src0_]);
    if (conf_.do_scale_src1) {
        uni_vbroadcastss(vreg_scales_src1_, dword[reg_scales_src1_]);
        if (!conf_.is_i8 && (conf_.broadcast_src1_value || offt_src1_ == 0))
            uni_vmulps(vreg_bcast_src1_, vreg_bcast_src1_, vreg_scales_src1_);
    }

    L(unroll_loop);
    {
        const size_t offt = unroll_regs_ * simd_w_;
        cmp(reg_reverse_spat_offt_, offt * dst_type_size);
        jl(unroll_loop_tail, T_NEAR);

        compute_dst(unroll_regs_, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, offt * dst_type_size);
        add(reg_offt_src0_, offt * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            add(reg_offt_dst_, offt);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            if (conf_.use_stride_rhs_postops) add(reg_off_rhs_postops_, offt);
        }
        jmp(unroll_loop);
    }

    L(unroll_loop_tail);
    {
        cmp(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        jl(nelems_tail, T_NEAR);

        compute_dst(1, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        add(reg_offt_src0_, simd_w_ * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            add(reg_offt_dst_, simd_w_);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            if (conf_.use_stride_rhs_postops)
                add(reg_off_rhs_postops_, simd_w_);
        }

        jmp(unroll_loop_tail);
    }

    L(nelems_tail);
    {
        cmp(reg_reverse_spat_offt_, 1);
        jl(end, T_NEAR);

        compute_dst(1, true);
        // the offsets need to keep advancing when iterating over outer dims
        if (is_src1_outer_dims_tail_) {
            add(reg_offt_src0_, tail_size_ * src0_type_size);
            if (conf_.is_i8)
                add(reg_offt_dst_, tail_size_);
            else {
                if (conf_.use_stride_rhs_postops)
                    add(reg_off_rhs_postops_, tail_size_);
            }
        }
    }

    L(end);
    if (conf_.is_src_different_layouts) pop(reg_src1_);
}

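// Wraps forward() in a loop so each iteration processes exactly
// conf_.outer_dims elements; used when the outer dims are not a multiple of
// simd_w_ and the gather indices must be left untouched.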
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::forward_over_outer_dims() {
    const auto outer_dims_size
            = conf_.outer_dims * types::data_type_size(conf_.dst_type);

    if (conf_.is_i8) {
        uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
        io_.init_saturate_f32({conf_.dst_type});
        xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
    }

    xor_(reg_offt_src0_,
            reg_offt_src0_); // offt_src0 to get addr of src0/dst
    if (conf_.use_stride_rhs_postops && !conf_.is_i8)
        xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);

    Label c_loop;
    L(c_loop);
    {
        mov(reg_reverse_spat_offt_, outer_dims_size);
        forward();
        sub(reg_outer_dims_range_, outer_dims_size);
        cmp(reg_outer_dims_range_, 0);
        jg(c_loop);
    }
}

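// Kernel entry point: emits the prologue, parameter loading, ISA-specific
// preparation, and the main computation, then the epilogue, followed by the
// post-ops lookup tables when any are required.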
template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::generate() {
    preamble();
    load_kernel_params();
    prepare_isa_kernel();
    // if the outer dims are not aligned to simd_w, iterate over them to
    // avoid modifying the gather indices
    if (is_src1_outer_dims_tail_)
        forward_over_outer_dims();
    else
        forward();
    postamble();

    if ((conf_.with_eltwise || conf_.is_i8) && postops_injector_)
        postops_injector_->prepare_table();
}

#undef PARAM_OFF

template struct jit_uni_binary_kernel_t<avx512_core_bf16>;
template struct jit_uni_binary_kernel_t<avx512_core>;
template struct jit_uni_binary_kernel_t<avx512_common>;
template struct jit_uni_binary_kernel_t<avx2>;
template struct jit_uni_binary_kernel_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl