1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_UNI_BINARY_KERNEL_HPP
18 #define CPU_X64_UNI_BINARY_KERNEL_HPP
19 
20 #include <cassert>
21 
22 #include "common/c_types_map.hpp"
23 #include "common/type_helpers.hpp"
24 #include "common/utils.hpp"
25 
26 #include "cpu/x64/cpu_isa_traits.hpp"
27 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28 #include "cpu/x64/jit_generator.hpp"
29 #include "cpu/x64/jit_primitive_conf.hpp"
30 #include "cpu/x64/utils/jit_io_helper.hpp"
31 
32 #include "cpu/cpu_binary_pd.hpp"
33 
34 namespace dnnl {
35 namespace impl {
36 namespace cpu {
37 namespace x64 {
38 
39 using namespace Xbyak;
40 
41 struct binary_kernel_t : public jit_generator {
42     using op_t = binary_op_t;
43     using bcast_t = binary_bcast_t;
44 
45     binary_kernel_t(const size_t vlen, const binary_pd_t *pd,
46             const jit_binary_conf_t conf, bool tail_kernel = false);
47     ~binary_kernel_t() override = default;
48 
operator ()dnnl::impl::cpu::x64::binary_kernel_t49     void operator()(jit_binary_call_s *p) { jit_generator::operator()(p); }
50 
simd_wdnnl::impl::cpu::x64::binary_kernel_t51     size_t simd_w() const noexcept { return simd_w_; }
vlendnnl::impl::cpu::x64::binary_kernel_t52     size_t vlen() const noexcept { return vlen_; }
53 
54 protected:
55     size_t get_tail_size() const;
56 
57     const size_t vlen_;
58     const size_t simd_w_;
59     constexpr static int vmm_start_idx_ = 1;
60     const binary_pd_t *pd_;
61     const jit_binary_conf_t conf_;
62     const bool is_tail_kernel_;
63     const bool is_src1_outer_dims_tail_;
64     const size_t tail_size_;
65     const size_t padding_tail_size_;
66 };
67 
68 template <cpu_isa_t isa>
69 struct jit_uni_binary_kernel_t : public binary_kernel_t {
70     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_binary_kernel_t)
71 
72     using Vmm = typename cpu_isa_traits<isa>::Vmm;
73     const AddressFrame &vmmword
74             = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
75 
76     static constexpr bool is_avx512
77             = utils::one_of(isa, avx512_common, avx512_core, avx512_core_bf16);
78     static constexpr bool is_avx512_core
79             = utils::one_of(isa, avx512_core, avx512_core_bf16);
80     static constexpr bool is_avx512_common = isa == avx512_common;
81     const bool is_avx512_not_mic
82             = is_avx512_core || (is_avx512_common && !conf_.is_i8);
83 
84     const Reg64 &reg_param_ = abi_param1;
85     const Reg64 &reg_src0_ = r8;
86     const Reg64 &reg_src1_ = r9;
87     const Reg64 &reg_dst_ = r10;
88     const Reg64 &reg_offt_src0_ = r11;
89     const Reg64 &reg_outer_dims_range_ = r12;
90     const Reg64 &reg_offt_src1_ = rax;
91     const Reg64 &reg_src1_stride_range_ = r15;
92     const Reg64 &reg_reverse_src1_stride_range_ = rax;
93     const Reg64 &reg_reverse_spat_offt_ = r13;
94     const Reg64 &reg_tmp_ = r14;
95     const Reg64 &reg_tmp1_ = abi_not_param1;
96     const Reg64 &reg_elt_inj_table_ = r15;
97     const Reg64 &reg_off_rhs_postops_ = rdx;
98     const Reg64 &reg_scales_src0_ = rbx;
99     const Reg64 &reg_scales_src1_ = rbp;
100     const Reg64 &reg_offt_dst_ = rdx;
101     const Opmask &tail_opmask_ = k2;
102     const Opmask &cmp_mask = k3;
103     const Opmask &full_mask_ = k4;
104     const Vmm vmm_tail_vmask_ = Vmm(0);
105     const Vmm vreg_sum_scale_ = Vmm(is_avx512 ? 17 : 9);
106     const Xmm xreg_sum_scale_ = Xmm(9);
107     const Vmm vreg_zero_ = Vmm(is_avx512 ? 18 : 10);
108     const Vmm vreg_one_ = Vmm(is_avx512 ? 19 : 11);
109     const Vmm vreg_saturation_ubound_ = Vmm(is_avx512 ? 20 : 12);
110     const Vmm vreg_bcast_src1_ = Vmm(is_avx512_not_mic ? 21 : 13);
111     const Xmm xreg_bcast_src1_ = Xmm(13);
112     const Vmm vreg_scales_src0_ = Vmm(is_avx512 ? 22 : 14);
113     const Vmm vreg_scales_src1_ = Vmm(is_avx512 ? 23 : 15);
114 
115     const Zmm vreg_bf16_emu_1_ = Zmm(26);
116     const Zmm vreg_bf16_emu_2_ = Zmm(27);
117     const Zmm vreg_bf16_emu_3_ = Zmm(28);
118     const Zmm vreg_bf16_emu_4_ = Zmm(29);
119 
120     const Vmm vmm_full_mask_ = Vmm(is_avx512_not_mic ? 24 : 5);
121     const Vmm vmm_tmp_gather_ = Vmm(is_avx512_not_mic ? 25 : 6);
122     const Vmm vmm_indices_ = Vmm(is_avx512_not_mic ? 30 : 7);
123     const Vmm vmm_gathered_src_ = Vmm(is_avx512_not_mic ? 31 : 8);
124 
125     const size_t unroll_regs_ = is_avx512
126                     && IMPLICATION(
127                             conf_.is_src_different_layouts, is_avx512_not_mic)
128             ? 8
129             : 4;
130     const size_t offt_src0_;
131     const size_t offt_src1_;
132 
133     static constexpr cpu_isa_t inject_isa
134             = isa == avx512_core_bf16 ? avx512_core : isa;
135     io::jit_io_multi_dt_helper_t<Vmm> io_;
136     std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa>>
137             postops_injector_;
138     const Opmask &elt_inj_opmask_ = k1;
139 
140     void init();
141     void init_post_ops_injector();
142     void apply_postops(int unroll, bool tail);
143     void load_kernel_params();
144     Address src0_ptr(size_t offt = 0);
145     Address src1_ptr(size_t offt = 0);
146     Address dst_ptr(size_t offt = 0);
147     unsigned int cmp_predicate(alg_kind_t alg);
148     void perform_op(
149             const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1);
150     void prepare_isa_kernel();
151     void compute_bcast(bool tail);
152     void load_src1(const Vmm &vreg_src1, const int offt, bool tail);
153     void compute_dst(int unroll, bool tail);
154     void forward();
155     void forward_over_outer_dims();
156     void generate() override;
157 
158     jit_uni_binary_kernel_t(const binary_pd_t *pd, const jit_binary_conf_t conf,
159             bool tail_kernel = false);
160     ~jit_uni_binary_kernel_t() override = default;
161 
162     std::map<data_type_t, io::io_saturation_conf_t>
163     create_saturation_vmm_map() const;
164 };
165 
166 } // namespace x64
167 } // namespace cpu
168 } // namespace impl
169 } // namespace dnnl
170 
171 #endif