/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_binary_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

#define PARAM_OFF(x) offsetof(jit_binary_call_s, x)

static bcast_set_t get_supported_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::per_oc_spatial};
}

binary_kernel_t::binary_kernel_t(const size_t vlen, const binary_pd_t *pd,
        const jit_binary_conf_t conf, bool tail_kernel)
    : vlen_(vlen)
    , simd_w_(vlen / sizeof(float))
    , pd_(pd)
    , conf_(conf)
    , is_tail_kernel_(tail_kernel)
    , is_src1_outer_dims_tail_(
              conf_.is_src_different_layouts && conf_.outer_dims % simd_w_)
    , tail_size_(get_tail_size())
    , padding_tail_size_(
              pd->src_md(0)->padded_dims[1] - pd->src_md(0)->dims[1]) {}

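// Returns the number of leftover elements (modulo simd_w_) that the tail path
// has to process; which dimension supplies that count depends on the layout,
// the broadcast strategy, and whether this is the dedicated tail kernel.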
size_t binary_kernel_t::get_tail_size() const {
    memory_desc_wrapper src0_d(pd_->src_md(0));
    const auto &dims = src0_d.dims();
    const auto &ndims = src0_d.ndims();

    dim_t nelems = 0;

    if (ndims == 1)
        nelems = dims[0];
    else if (is_src1_outer_dims_tail_)
        nelems = conf_.outer_dims;
    else if (!conf_.is_i8 && conf_.op_type == op_t::c_blocked
            && (is_tail_kernel_ || conf_.bcast_type == bcast_t::per_w))
        nelems = dims[1];
    else if (conf_.bcast_type == bcast_t::none
            && !conf_.postops_per_oc_broadcast_exists)
        nelems = src0_d.nelems(true);
    else {
        if (conf_.op_type == op_t::n_spatial_c)
            nelems = dims[1];
        else if (conf_.op_type == op_t::n_c_spatial && ndims >= 3)
            nelems = conf_.bcast_type == bcast_t::per_w
                    ? dims[ndims - 1]
                    : utils::array_product(dims + 2, ndims - 2);
    }
    // the tail is counted in f32 lanes: even for bf16 the kernel still loads
    // simd_w_ elements per vector, not twice as many.
    return nelems % simd_w_;
}

template <cpu_isa_t isa>
jit_uni_binary_kernel_t<isa>::jit_uni_binary_kernel_t(
        const binary_pd_t *pd, const jit_binary_conf_t conf, bool tail_kernel)
    : binary_kernel_t(cpu_isa_traits<isa>::vlen, pd, conf, tail_kernel)
    , offt_src0_(vlen_ / (conf_.is_bf16 ? 2 : 1))
    , offt_src1_(conf_.use_stride_src1 ? offt_src0_ : 0)
    , io_(this, isa, {conf_.src0_type, conf_.src1_type, conf_.dst_type},
              {false},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      vmm_tail_vmask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vreg_bf16_emu_1_, vreg_bf16_emu_2_,
                      vreg_bf16_emu_3_, reg_tmp_, vreg_bf16_emu_4_},
              create_saturation_vmm_map(),
              io::io_gather_conf_t {simd_w_, full_mask_,
                      vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_,
                      vmm_tmp_gather_.getIdx()}) {
    init();
}

template <cpu_isa_t isa>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_binary_kernel_t<isa>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

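    // only int8 destinations need saturation before the store; f32/bf16
    // results are written back without clamping.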
    if (conf_.is_i8)
        saturation_map.emplace(conf_.dst_type,
                io::io_saturation_conf_t {vreg_zero_.getIdx(),
                        vreg_saturation_ubound_.getIdx(), reg_tmp_});

    return saturation_map;
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::init() {
    if (conf_.with_postops) init_post_ops_injector();
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::init_post_ops_injector() {
    const memory_desc_wrapper src0_d(pd_->src_md(0));
    const auto &po = pd_->attr()->post_ops_;

    const eltwise_injector::static_params_t esp(true /*save_state*/,
            reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/,
            false /*use_dst*/);
    const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_,
            reg_elt_inj_table_, true /*preserve gpr*/, true /*preserve vmm*/,
            PARAM_OFF(post_ops_binary_rhs_arg_vec), src0_d, tail_size_,
            tail_opmask_, false /*use_exact_tail_scalar_bcast*/};
    const binary_injector::static_params_t bsp(
            this->param1, get_supported_bcast_strategies(), rhs_arg_bsp);

    postops_injector_ = utils::make_unique<
            injector::jit_uni_postops_injector_t<inject_isa>>(
            this, po, bsp, esp);
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::apply_postops(int unroll, bool tail) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
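    // for every destination vmm that holds a partial result, tell the binary
    // post-op injector where the current channel offset lives so that per_oc
    // and per_oc_spatial broadcasts pick the matching rhs elements.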
    for (int vmm_idx = 1; vmm_idx < unroll + vmm_start_idx_; vmm_idx++) {
        if (utils::one_of(conf_.op_type, op_t::c_blocked, op_t::n_c_spatial)) {
            rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
                    vmm_idx, ptr[param1 + PARAM_OFF(oc_l_off)]);
        } else if (conf_.op_type == op_t::n_spatial_c) {
            rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
                    vmm_idx, reg_off_rhs_postops_);
            rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(vmm_idx,
                    (vmm_idx - vmm_start_idx_) * static_cast<int>(simd_w_));
        }
        if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
    }
    postops_injector_->compute_vector_range(
            1, unroll + vmm_start_idx_, rhs_arg_params);
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::load_kernel_params() {
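    // unpack the jit_binary_call_s argument (see PARAM_OFF) into the
    // dedicated registers used throughout the kernel.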
    mov(reg_tmp_, float2int(conf_.sum_scale));
    uni_vmovq(xreg_sum_scale_, reg_tmp_);
    uni_vbroadcastss(vreg_sum_scale_, xreg_sum_scale_);
    if (is_src1_outer_dims_tail_)
        mov(reg_outer_dims_range_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    else
        mov(reg_reverse_spat_offt_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    mov(reg_src0_, ptr[reg_param_ + PARAM_OFF(src0)]);
    mov(reg_src1_, ptr[reg_param_ + PARAM_OFF(src1)]);
    mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]);
    if (conf_.is_src_different_layouts) {
        mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(indices)]);
        uni_vmovdqu(vmm_indices_, ptr[reg_tmp_]);

        mov(reg_src1_stride_range_,
                ptr[reg_param_ + PARAM_OFF(src1_stride_range)]);
        mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
    }
    if (conf_.do_scale_src0)
        mov(reg_scales_src0_, ptr[reg_param_ + PARAM_OFF(scales_src0)]);
    if (conf_.do_scale_src1)
        mov(reg_scales_src1_, ptr[reg_param_ + PARAM_OFF(scales_src1)]);
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::src0_ptr(size_t offt) {
    return vmmword[reg_src0_ + reg_offt_src0_ + offt];
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::src1_ptr(size_t offt) {
    return vmmword[reg_src1_ + reg_offt_src1_ + offt];
}

template <cpu_isa_t isa>
Address jit_uni_binary_kernel_t<isa>::dst_ptr(size_t offt) {
    const Reg64 &reg_offt_dst = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_;
    return vmmword[reg_dst_ + reg_offt_dst + offt];
}

template <cpu_isa_t isa>
unsigned int jit_uni_binary_kernel_t<isa>::cmp_predicate(alg_kind_t alg) {
    using namespace alg_kind;
    switch (alg) {
        case binary_ge: return _cmp_nlt_us;
        case binary_gt: return _cmp_nle_us;
        case binary_le: return _cmp_le_os;
        case binary_lt: return _cmp_lt_os;
        case binary_eq: return _cmp_eq_oq;
        case binary_ne: return _cmp_neq_uq;
        default: assert(!"not supported operation!"); return -1;
    }
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::perform_op(
        const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1) {
    using namespace alg_kind;
    const auto alg = pd_->desc()->alg_kind;
    const bool cmp_op = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    if (conf_.do_scale_src0) uni_vmulps(v0, v0, s_src0);
    if (conf_.do_scale_src1
            && (conf_.is_i8
                    || (offt_src1_ != 0 && !conf_.broadcast_src1_value)))
        uni_vmulps(v1, v1, s_src1);

    if (alg == binary_add)
        uni_vaddps(v0, v0, v1);
    else if (alg == binary_mul)
        uni_vmulps(v0, v0, v1);
    else if (alg == binary_max)
        uni_vmaxps(v0, v0, v1);
    else if (alg == binary_min)
        uni_vminps(v0, v0, v1);
    else if (alg == binary_div)
        uni_vdivps(v0, v0, v1);
    else if (alg == binary_sub)
        uni_vsubps(v0, v0, v1);
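    // comparison ops produce 1.0f in true lanes and 0.0f in false lanes:
    // on AVX-512 the compare writes a k-mask that selects 1.0f via a
    // zero-masked move; otherwise the compare leaves an all-ones (NaN)
    // pattern in true lanes, which min with 1.0f turns into 1.0f.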
    else if (cmp_op) {
        const unsigned int predicate = cmp_predicate(alg);
        if (is_avx512) {
            vcmpps(cmp_mask, v0, v1, predicate);
            vmovups(v0 | cmp_mask | T_z, vreg_one_);
        } else {
            uni_vcmpps(v0, v0, v1, predicate);
            uni_vminps(v0, v0, vreg_one_);
        }
    } else
        assert(!"not supported operation!");
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::prepare_isa_kernel() {
    if (conf_.is_bf16) io_.init_bf16();
    if (tail_size_ > 0) io_.prepare_tail_mask();
    if (conf_.is_src_different_layouts && is_superset(isa, avx2)) {
        io_.init_full_mask();
        io_.prepare_full_mask();
    }
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::compute_bcast(bool tail) {
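    // when src1 contributes a single value (or one vector reused across the
    // whole spatial loop), broadcast/load it once here instead of reloading
    // it in every compute_dst() step.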
    if (conf_.broadcast_src1_value) {
        if (conf_.is_i8)
            uni_vpxor(xreg_bcast_src1_, xreg_bcast_src1_, xreg_bcast_src1_);
        io_.at(conf_.src1_type)->broadcast(src1_ptr(), vreg_bcast_src1_);
    } else if (!conf_.is_i8 && offt_src1_ == 0) {
        io_.at(conf_.src1_type)->load(src1_ptr(), vreg_bcast_src1_, tail);
    }
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::load_src1(
        const Vmm &vreg_src1, const int offt, bool tail) {
    if (conf_.is_src_different_layouts) {
        // when src0 and src1 have different layouts, gather src1 elements
        // using the precomputed stride indices; once the stride range is
        // exhausted, the base pointer is restored and advanced (see below)
        io_.at(conf_.src1_type)
                ->gather(reg_src1_, vmm_indices_, vreg_src1, tail);
        // gather reads the address from a register, not a memory operand,
        // so reg_src1_ is advanced directly instead of keeping the offset
        // in a second register
        add(reg_src1_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);
        sub(reg_reverse_src1_stride_range_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);

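        // once the whole stride range has been consumed, reload the base
        // pointer saved on the stack, advance it by one element, save it
        // back, and reset the remaining-range counter.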
        Label src1_stride_range_not_exceed, src1_C_tail_end;

        cmp(reg_reverse_src1_stride_range_, 0);
        jg(src1_stride_range_not_exceed, T_NEAR);
        {
            pop(reg_src1_);
            add(reg_src1_, types::data_type_size(conf_.src1_type));
            push(reg_src1_);
            mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
        }
        L(src1_stride_range_not_exceed);
    } else
        io_.at(conf_.src1_type)
                ->load(src1_ptr(offt * types::data_type_size(conf_.src1_type)),
                        vreg_src1, tail);
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::compute_dst(int unroll, bool tail) {
    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const Vmm vreg_tmp = conf_.is_src_different_layouts
                ? vmm_gathered_src_
                : Vmm(unroll + i + vmm_start_idx_);
        const Vmm vreg_tmp_src1 = offt_src1_ ? vreg_tmp : vreg_bcast_src1_;
        const int offt = simd_w_ * i;
        io_.at(conf_.src0_type)
                ->load(src0_ptr(offt * types::data_type_size(conf_.src0_type)),
                        vreg_tmp_src0, tail);
        if (offt_src1_) load_src1(vreg_tmp_src1, offt, tail);

        // copy src1 into a scratch vreg so a broadcasted value is not
        // multiplied by the input scale more than once; not needed when
        // layouts differ, since the gathered data already lives in the
        // scratch vreg
        if (!conf_.is_src_different_layouts)
            uni_vmovups(vreg_tmp, vreg_tmp_src1);
        perform_op(
                vreg_tmp_src0, vreg_tmp, vreg_scales_src0_, vreg_scales_src1_);
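        // sum post-op: accumulate the previous dst value scaled by sum_scale
        // into the result before it is stored back.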
        if (conf_.do_sum) {
            io_.at(conf_.dst_type)
                    ->load(dst_ptr(offt
                                   * types::data_type_size(conf_.dst_type)),
                            vreg_tmp, tail);
            uni_vfmadd231ps(vreg_tmp_src0, vreg_tmp, vreg_sum_scale_);
        }
    }

    if (postops_injector_) apply_postops(unroll, tail);

    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const int offt = simd_w_ * i;
        const auto dt_size = types::data_type_size(conf_.dst_type);

        if (is_tail_kernel_ && padding_tail_size_) {
            // apply zero-padding
            Label end;
            auto off_base = 0;
            auto zero_pad_left = padding_tail_size_;

            // inplace data is assumed to be zero-padded
            cmp(reg_src0_, reg_dst_);
            je(end, T_NEAR);

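            // if the padding covers at least the rest of this vector, blend
            // zeros into the lanes past the tail and store a full vector;
            // otherwise store only the tail and zero the remaining padded
            // bytes with rep stosb below.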
            if (zero_pad_left >= simd_w_ - tail_size_) {
                vxorps(vreg_zero_, vreg_zero_, vreg_zero_);
                if (is_avx512)
                    uni_vmovups(vreg_zero_ | tail_opmask_, vreg_tmp_src0);
                else
                    uni_vblendvps(vreg_zero_, vreg_zero_, vreg_tmp_src0,
                            vmm_tail_vmask_);
                io_.at(conf_.dst_type)
                        ->store(vreg_zero_, dst_ptr(offt * dt_size), false);
                off_base = simd_w_ * dt_size;
                zero_pad_left -= simd_w_ - tail_size_;
            } else {
                io_.at(conf_.dst_type)
                        ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), true);
                off_base = tail_size_ * dt_size;
            }

            if (zero_pad_left) {
                push(abi_param1);
                const Reg32 &reg_zero = eax;
                const Reg64 &reg_ptr = rdi;
                const Reg64 &reg_counter = rcx;
                const auto off_start = off_base;
                const auto off_end = off_start + zero_pad_left * dt_size;
                xor_(reg_zero, reg_zero);
                lea(reg_ptr,
                        ptr[dst_ptr(offt * dt_size).getRegExp()
                                + RegExp(off_start)]);
                mov(reg_counter, off_end - off_start);
                rep();
                stosb();
                pop(abi_param1);
            }
            L(end);
        } else
            io_.at(conf_.dst_type)
                    ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), tail);
    }
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::forward() {
    Label unroll_loop, unroll_loop_tail, nelems_tail, end;

    const auto src0_type_size = types::data_type_size(conf_.src0_type);
    const auto src1_type_size = types::data_type_size(conf_.src1_type);
    const auto dst_type_size = types::data_type_size(conf_.dst_type);

    if (conf_.is_src_different_layouts) push(reg_src1_);

    // if src1 outer dims have a tail, offset registers are initialized once
    // in forward_over_outer_dims(), outside this per-chunk loop
    if (!is_src1_outer_dims_tail_) {
        if (conf_.is_i8) {
            uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
            io_.init_saturate_f32({conf_.dst_type});
            xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
        }

        xor_(reg_offt_src0_,
                reg_offt_src0_); // offt_src0 to get addr of src0/dst
        if (!conf_.is_src_different_layouts)
            xor_(reg_offt_src1_,
                    reg_offt_src1_); // offt_src1 to get addr of src1
        if (conf_.use_stride_rhs_postops && !conf_.is_i8)
            xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);
    }
    const auto alg = pd_->desc()->alg_kind;

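    // comparison ops write 1.0f into true lanes (see perform_op), so keep a
    // broadcasted 1.0f in vreg_one_ for the whole kernel.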
    if (utils::one_of(alg, alg_kind::binary_ge, alg_kind::binary_gt,
                alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq,
                alg_kind::binary_ne)) {
        Xmm xreg_one = Xmm(vreg_one_.getIdx());
        mov(reg_tmp_, float2int(1));
        uni_vmovq(xreg_one, reg_tmp_);
        uni_vbroadcastss(vreg_one_, xreg_one);
    }

    compute_bcast(false); // bcast/load the src1 vreg just once per kernel call

    // used in the c_blocked strategy for the last block when a tail exists
    const bool treat_each_compute_step_as_tail
            = !conf_.is_i8 && is_tail_kernel_ && tail_size_;

    if (conf_.do_scale_src0)
        uni_vbroadcastss(vreg_scales_src0_, dword[reg_scales_src0_]);
    if (conf_.do_scale_src1) {
        uni_vbroadcastss(vreg_scales_src1_, dword[reg_scales_src1_]);
        if (!conf_.is_i8 && (conf_.broadcast_src1_value || offt_src1_ == 0))
            uni_vmulps(vreg_bcast_src1_, vreg_bcast_src1_, vreg_scales_src1_);
    }

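    // three-stage loop: unroll_loop processes unroll_regs_ vectors per
    // iteration, unroll_loop_tail handles the remaining full vectors one at
    // a time, and nelems_tail finishes the last partial vector with masking.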
    L(unroll_loop);
    {
        const size_t offt = unroll_regs_ * simd_w_;
        cmp(reg_reverse_spat_offt_, offt * dst_type_size);
        jl(unroll_loop_tail, T_NEAR);

        compute_dst(unroll_regs_, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, offt * dst_type_size);
        add(reg_offt_src0_, offt * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            add(reg_offt_dst_, offt);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            if (conf_.use_stride_rhs_postops) add(reg_off_rhs_postops_, offt);
        }
        jmp(unroll_loop);
    }

    L(unroll_loop_tail);
    {
        cmp(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        jl(nelems_tail, T_NEAR);

        compute_dst(1, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        add(reg_offt_src0_, simd_w_ * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            add(reg_offt_dst_, simd_w_);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            if (conf_.use_stride_rhs_postops)
                add(reg_off_rhs_postops_, simd_w_);
        }

        jmp(unroll_loop_tail);
    }

    L(nelems_tail);
    {
        cmp(reg_reverse_spat_offt_, 1);
        jl(end, T_NEAR);

        compute_dst(1, true);
        // advance offsets so the next forward() call over the outer dims
        // starts at the right position
        if (is_src1_outer_dims_tail_) {
            add(reg_offt_src0_, tail_size_ * src0_type_size);
            if (conf_.is_i8)
                add(reg_offt_dst_, tail_size_);
            else {
                if (conf_.use_stride_rhs_postops)
                    add(reg_off_rhs_postops_, tail_size_);
            }
        }
    }

    L(end);
    if (conf_.is_src_different_layouts) pop(reg_src1_);
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::forward_over_outer_dims() {
    const auto outer_dims_size
            = conf_.outer_dims * types::data_type_size(conf_.dst_type);

    if (conf_.is_i8) {
        uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
        io_.init_saturate_f32({conf_.dst_type});
        xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
    }

    xor_(reg_offt_src0_,
            reg_offt_src0_); // offt_src0 to get addr of src0/dst
    if (conf_.use_stride_rhs_postops && !conf_.is_i8)
        xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);

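    // run the whole spatial kernel once per chunk of outer_dims elements
    // until the outer-dims range passed in the call params is exhausted.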
    Label c_loop;
    L(c_loop);
    {
        mov(reg_reverse_spat_offt_, outer_dims_size);
        forward();
        sub(reg_outer_dims_range_, outer_dims_size);
        cmp(reg_outer_dims_range_, 0);
        jg(c_loop);
    }
}

template <cpu_isa_t isa>
void jit_uni_binary_kernel_t<isa>::generate() {
    preamble();
    load_kernel_params();
    prepare_isa_kernel();
    // if the outer dims are not aligned to simd_w, iterate over them to
    // avoid modifying the gather indices
    if (is_src1_outer_dims_tail_)
        forward_over_outer_dims();
    else
        forward();
    postamble();

    if ((conf_.with_eltwise || conf_.is_i8) && postops_injector_)
        postops_injector_->prepare_table();
}

#undef PARAM_OFF

template struct jit_uni_binary_kernel_t<avx512_core_bf16>;
template struct jit_uni_binary_kernel_t<avx512_core>;
template struct jit_uni_binary_kernel_t<avx512_common>;
template struct jit_uni_binary_kernel_t<avx2>;
template struct jit_uni_binary_kernel_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl