1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_UTILS_JIT_IO_HELPER_HPP
18 #define CPU_X64_UTILS_JIT_IO_HELPER_HPP
19 
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
23 
24 #include "common/optional.hpp"
25 
26 #include "cpu/x64/cpu_isa_traits.hpp"
27 #include "cpu/x64/jit_generator.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 namespace x64 {
33 
34 struct bf16_emulation_t;
35 
36 namespace io {
37 
// General I/O settings consumed by the io helpers declared in this header.
// Copyable; a default-constructed object has nt stores disabled.
class io_conf_t {
public:
    io_conf_t() = default;
    io_conf_t(const io_conf_t &) = default;
    // Defined out of line; intentionally usable as a converting constructor
    // from bool.
    io_conf_t(bool nt_stores_enabled);

    io_conf_t &operator=(const io_conf_t &) = default;

    // Whether nt stores are enabled ("nt" presumably stands for
    // non-temporal - confirm against the implementation).
    bool nt_stores_enabled_ = false;
};
48 
49 class io_tail_conf_t {
50 public:
51     io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size,
52             const Xbyak::Opmask &tail_opmask, const int tail_vmm_mask_idx,
53             const Xbyak::Reg64 &reg_tmp);
54     io_tail_conf_t(const io_tail_conf_t &other) = default;
55 
56     io_tail_conf_t &operator=(const io_tail_conf_t &other) = default;
57 
58     std::size_t simd_w_ = 0;
59     std::size_t tail_size_ = 0;
60     Xbyak::Opmask tail_opmask_ = Xbyak::Opmask();
61     int tail_vmm_mask_idx_ = 0;
62     Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
63 };
64 
65 class io_emu_bf16_conf_t {
66 public:
67     io_emu_bf16_conf_t() = default;
68     io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1,
69             const Xbyak::Zmm &bf16_emu_reserv_2,
70             const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 &reg_tmp,
71             const Xbyak::Zmm &bf16_emu_reserv_4);
72     io_emu_bf16_conf_t(const io_emu_bf16_conf_t &other) = default;
73 
74     io_emu_bf16_conf_t &operator=(const io_emu_bf16_conf_t &other) = default;
75 
76     Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28);
77     Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29);
78     Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30);
79     Xbyak::Reg64 reg_tmp_ = Xbyak::util::rax;
80     Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31);
81 };
82 
83 class io_saturation_conf_t {
84 public:
85     io_saturation_conf_t(const int vreg_zero_saturation_idx,
86             const int vreg_saturation_ubound_idx, const Xbyak::Reg64 &reg_tmp);
87     io_saturation_conf_t(const io_saturation_conf_t &other) = default;
88 
89     io_saturation_conf_t &operator=(const io_saturation_conf_t &other)
90             = default;
91 
92     int vreg_zero_saturation_idx_ = 0;
93     int vreg_saturation_ubound_idx_ = 0;
94     Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
95 };
96 
97 class io_gather_conf_t {
98 public:
99     io_gather_conf_t(const std::size_t simd_w, const Xbyak::Opmask &full_opmask,
100             const int full_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp,
101             const Xbyak::Reg64 &reg_tmp1,
102             const utils::optional_t<int> &vmm_tmp_idx = utils::nullopt);
103     io_gather_conf_t(const io_gather_conf_t &other) = default;
104 
105     io_gather_conf_t &operator=(const io_gather_conf_t &other) = default;
106 
107     std::size_t simd_w_ = 0;
108     Xbyak::Opmask full_opmask_ = Xbyak::Opmask();
109     int full_vmm_mask_idx_ = 0;
110     Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64();
111     Xbyak::Reg64 reg_tmp1_ = Xbyak::Reg64();
112     // It is needed, when io_helper use emulation for gather
113     // and it is not needed for sse.
114     utils::optional_t<int> vmm_tmp_idx_ = utils::nullopt;
115 };
116 
117 template <typename Vmm>
118 class jit_io_multi_dt_helper_t;
119 
120 template <typename Vmm>
121 class jit_io_helper_t {
122 public:
123     friend class jit_io_multi_dt_helper_t<Vmm>;
124 
125     jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa,
126             const data_type_t &data_type, const io_conf_t &io_conf,
127             const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt,
128             const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf
129             = utils::nullopt,
130             const utils::optional_t<io_saturation_conf_t> &saturation_conf
131             = utils::nullopt,
132             const utils::optional_t<io_gather_conf_t> &gather_conf
133             = utils::nullopt);
134     jit_io_helper_t(jit_io_helper_t &&) = default;
135     jit_io_helper_t &operator=(jit_io_helper_t &&) = default;
136 
137     ~jit_io_helper_t();
138     void prepare_tail_mask();
139     void prepare_full_mask();
140     /*
141      * Sometimes the values in the register can be nan at the
142      * beginning of the kernel, then using vcmpps(vmm, vmm, vmm)
143      * will not set all bits to 1, instead of that this instruction will
144      * return zero. At the beginning, it is worth to zeroing
145      * full mask vmm to be sure, that vcmpps work properly.
146      */
147     void init_full_mask();
148     void init_saturate_f32() const;
149     void init_bf16();
150     void gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm,
151             const Vmm &dst_vmm, const bool tail);
152     void broadcast(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
153     void load(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
154             const bool tail);
155     void store(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
156             const bool tail);
157 
158 private:
159     void prepare_opmask(const std::size_t how_many_bits_to_set,
160             const Xbyak::Reg64 &reg_tmp, const Xbyak::Opmask &mask);
161     void prepare_vmm_mask(const std::size_t how_many_bits_to_set,
162             const std::size_t simd_w, const Xbyak::Reg64 &reg_tmp,
163             const Vmm &mask);
164     void prepare_i8_data_to_store(const Vmm &i8_vmm);
165     // Emulates the behavior of vgatherdps for architectures
166     // that do not support this instruction.
167     void emu_gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm,
168             const Vmm &dst_vmm, const bool tail);
169     void load_byte_by_byte(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
170             const int load_size);
171     void load_f32(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
172             const bool tail);
173     void load_s32(const Xbyak::Address &src_addr, const Vmm &dst_vmm,
174             const bool tail);
175     void load_bf16(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
176     void load_i8(const Xbyak::Address &src_addr, const Vmm &dst_vmm);
177     void saturate(const Vmm &vmm);
178     void store_byte_by_byte(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
179             const int store_size);
180     void store_f32(const Vmm &src_vmm, const Xbyak::Address &dst_addr,
181             const bool tail);
182     void store_bf16(const Vmm &src_vmm, const Xbyak::Address &dst_addr);
183     void store_i8(const Vmm &src_vmm, const Xbyak::Address &dst_addr);
184     void convert_to_f32(const Vmm &dst_vmm, const Xbyak::Xmm &src_vmm,
185             const data_type_t src_data_type);
186 
187     jit_generator *host_;
188     const cpu_isa_t isa_;
189     const data_type_t data_type_;
190     const bool bf16_supported_;
191     std::unique_ptr<bf16_emulation_t> bf16_emu_;
192     const io_conf_t io_conf_;
193     const utils::optional_t<io_tail_conf_t> tail_conf_;
194     const utils::optional_t<io_emu_bf16_conf_t> bf16_conf_;
195     const utils::optional_t<io_saturation_conf_t> saturation_conf_;
196     const utils::optional_t<io_gather_conf_t> gather_conf_;
197 };
198 
199 template <typename Vmm>
200 class jit_io_multi_dt_helper_t {
201 public:
202     using data_types_t = std::unordered_set<data_type_t, std::hash<int>>;
203 
204     jit_io_multi_dt_helper_t(jit_generator *host, const cpu_isa_t &isa,
205             const data_types_t &data_types, const io_conf_t &io_conf,
206             const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt,
207             const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf
208             = utils::nullopt,
209             const std::map<data_type_t, io_saturation_conf_t> &saturation_confs
210             = std::map<data_type_t, io_saturation_conf_t> {},
211             const utils::optional_t<io_gather_conf_t> &gather_conf
212             = utils::nullopt);
213     ~jit_io_multi_dt_helper_t();
214     void prepare_tail_mask();
215     void prepare_full_mask();
216     void init_saturate_f32(const data_types_t &store_data_types);
217     void init_full_mask();
218     void init_bf16();
219 
220     std::shared_ptr<jit_io_helper_t<Vmm>> at(const data_type_t dt) const;
221 
222 private:
223     std::unordered_map<data_type_t, std::shared_ptr<jit_io_helper_t<Vmm>>,
224             std::hash<int>>
225             storage_;
226 };
227 
228 } // namespace io
229 } // namespace x64
230 } // namespace cpu
231 } // namespace impl
232 } // namespace dnnl
233 
234 #endif
235