1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP
18 #define CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/memory_tracking.hpp"
22 
23 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24 #include "cpu/x64/jit_generator.hpp"
25 #include "cpu/x64/jit_primitive_conf.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31 
32 struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator {
33     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_1x1_fwd_kernel_t)
34 
35     jit_avx512_core_amx_1x1_fwd_kernel_t(const jit_conv_conf_t &ajcp,
36             const primitive_attr_t &attr, const memory_desc_t &dst_md);
37 
38     static status_t init_conf(jit_conv_conf_t &jcp,
39             const convolution_desc_t &cd, memory_desc_t &src_pd,
40             memory_desc_t &weights_pd, memory_desc_t &dst_pd,
41             memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
42     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
43             const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
44 
45     // Tile-registers decomposition
46     enum { C_BASE = 0, W_BASE = 6, I_BASE = 4 };
47 
48     void tile_configure(char *tcgf_buff);
49 
50     jit_conv_conf_t jcp;
51     const primitive_attr_t &attr_;
52 
53 private:
54     constexpr static int isa_simd_width_
55             = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
56     std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
57             postops_injector_;
58 
59     enum {
60         zmm_idx_limit_bf16 = 29,
61         zmm_idx_limit_int8 = 27,
62     };
63 
64     int row_count_ = 0;
65     int buf_count_ = 0;
66     bool is_store_done_ = false;
67     bool is_buffer_empty_ = true;
68     bool check_last_sb_ = false;
69     bool last_oc_block_flag_ = false;
70 
71     /* data regs */
72     const Xbyak::Reg64 &inp_ptr = r15;
73     const Xbyak::Reg64 &wei_ptr = r14;
74     const Xbyak::Reg64 &out_ptr = r13;
75     const Xbyak::Reg64 &wsp_ptr = r12;
76 
77     const Xbyak::Reg64 &reg_bias = r11;
78     const Xbyak::Reg64 &reg_ptr_scales = r10;
79     const Xbyak::Reg64 &reg_ptr_sum_scale = r9;
80     const Xbyak::Reg64 &reg_ptr_sum_zp = rax;
81     const Xbyak::Reg64 &aux_reg_saturation = reg_ptr_sum_scale;
82     const Xbyak::Reg64 &reg_last_h = r8;
83 
84     const Xbyak::Reg64 &stride_seq = rbx;
85     const Xbyak::Reg64 &stride_nhwc = rsi;
86     const Xbyak::Reg64 &reg_tmp = abi_not_param1;
87 
88     const Xbyak::Reg64 &reg_oc_blocks = rdx;
89     const Xbyak::Reg64 &reg_is_osb = rsi;
90     const Xbyak::Reg64 &reg_postop = abi_not_param1;
91     const Xbyak::Reg64 &reg_scratch = reg_bias;
92     const Xbyak::Reg64 &reg_tilebuff = reg_ptr_scales;
93     /* zero-point */
94     const Xbyak::Reg64 &reg_zp_compensation = reg_last_h;
95     const Xbyak::Reg64 &reg_src_zero_point = reg_oc_blocks;
96     const Xbyak::Reg64 &reg_dst_zero_point = rax;
97 
98     const Xbyak::Zmm &zmm_bias = zmm31;
99     const Xbyak::Zmm &zmm_saturation = zmm_bias;
100     const Xbyak::Zmm &zmm_zero = zmm30;
101     const Xbyak::Zmm &zmm_prev_dst = zmm29;
102     const Xbyak::Zmm &zmm_sum_zp = zmm26;
103     /* zero-point */
104     const Xbyak::Zmm &zmm_zp = zmm29;
105     const Xbyak::Zmm &zmm_src_zp = zmm28;
106     const Xbyak::Zmm &zmm_dst_zp = zmm27;
107 
108     const Xbyak::Reg64 &bin_injector_helper_reg_1 = r14;
109     const Xbyak::Reg64 &bin_injector_helper_reg_2 = r15;
110 
111     const Xbyak::Opmask &ktail_mask = k2;
112 
113     bool is_bf16() const;
114 
115     void init_runtime_counters();
116 
117     int get_out_tensor(int h, int i) const;
118     int get_inp_tensor(int h) const;
119     int get_wei_tensor(int i) const;
120     int get_ic_tail() const;
121 
122     size_t out_h_shift() const;
123     size_t out_w_shift() const;
124     size_t inp_offset(int ih, int iw, int icb) const;
125     size_t out_row_offset(int h, int w, int ocb) const;
126 
127     void prepare_output();
128 
129     void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in,
130             const Xbyak::Operand &op, bool mask_flag);
zmm_outdnnl::impl::cpu::x64::jit_avx512_core_amx_1x1_fwd_kernel_t131     Xbyak::Zmm zmm_out(const int idx) {
132         const int upper_limit
133                 = is_bf16() ? zmm_idx_limit_bf16 : zmm_idx_limit_int8;
134         assert(upper_limit > idx);
135         MAYBE_UNUSED(upper_limit);
136         return Xbyak::Zmm(idx);
137     }
138     Xbyak::Zmm zmm_mask(
139             const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false);
140     Xbyak::Ymm ymm_mask(
141             const Xbyak::Ymm &ymm_in, bool mask_flag, bool store = false);
142 
143     void update_buffer_pointers();
144     void interleave_store();
145     void apply_sum(const Xbyak::Zmm &zmm_out, const float *p_sum_scale,
146             const int32_t *p_sum_zp, const Xbyak::Address &addr,
147             const bool mask_flag);
148     void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale,
149             const int32_t *p_sum_zp, const Xbyak::Address &addr,
150             const size_t off, const bool mask_flag);
151     static bool is_fast_postops(const jit_conv_conf_t &jcp);
152     void store_output_vectors_int8(int ocb, int osb);
153     void store_output_vector_int8(
154             const Xbyak::Zmm &zmm_out, int ocb, int h, int w);
155     void store_output_vectors_bf16(int ocb, int osb);
156     void store_output_vector_bf16(
157             const Xbyak::Zmm &zmm_out, int ocb, int h, int w);
158     void store_output_vectors(int ocb, int osb);
159     void store_output_vector(const Xbyak::Zmm &zmm_out, int ocb, int h, int w);
160     void store_output(bool do_store, bool is_tail);
161     void icb_loop(bool do_store);
162     void osb_loop(int nb_os = 1);
163 
164     void generate() override;
165 };
166 
167 } // namespace x64
168 } // namespace cpu
169 } // namespace impl
170 } // namespace dnnl
171 
172 #endif
173