1 /******************************************************************************* 2 * Copyright 2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_UTILS_JIT_IO_HELPER_HPP 18 #define CPU_X64_UTILS_JIT_IO_HELPER_HPP 19 20 #include <map> 21 #include <memory> 22 #include <unordered_set> 23 24 #include "common/optional.hpp" 25 26 #include "cpu/x64/cpu_isa_traits.hpp" 27 #include "cpu/x64/jit_generator.hpp" 28 29 namespace dnnl { 30 namespace impl { 31 namespace cpu { 32 namespace x64 { 33 34 struct bf16_emulation_t; 35 36 namespace io { 37 38 class io_conf_t { 39 public: 40 io_conf_t() = default; 41 io_conf_t(const bool nt_stores_enabled); 42 io_conf_t(const io_conf_t &other) = default; 43 44 io_conf_t &operator=(const io_conf_t &other) = default; 45 46 bool nt_stores_enabled_ = false; 47 }; 48 49 class io_tail_conf_t { 50 public: 51 io_tail_conf_t(const std::size_t simd_w, const std::size_t tail_size, 52 const Xbyak::Opmask &tail_opmask, const int tail_vmm_mask_idx, 53 const Xbyak::Reg64 ®_tmp); 54 io_tail_conf_t(const io_tail_conf_t &other) = default; 55 56 io_tail_conf_t &operator=(const io_tail_conf_t &other) = default; 57 58 std::size_t simd_w_ = 0; 59 std::size_t tail_size_ = 0; 60 Xbyak::Opmask tail_opmask_ = Xbyak::Opmask(); 61 int tail_vmm_mask_idx_ = 0; 62 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); 63 }; 64 65 class io_emu_bf16_conf_t { 66 public: 67 io_emu_bf16_conf_t() = default; 68 io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1, 69 const Xbyak::Zmm &bf16_emu_reserv_2, 70 const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 ®_tmp, 71 const Xbyak::Zmm &bf16_emu_reserv_4); 72 io_emu_bf16_conf_t(const io_emu_bf16_conf_t &other) = default; 73 74 io_emu_bf16_conf_t &operator=(const io_emu_bf16_conf_t &other) = default; 75 76 Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28); 77 Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29); 78 Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30); 79 Xbyak::Reg64 reg_tmp_ = Xbyak::util::rax; 80 Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31); 81 }; 82 83 class io_saturation_conf_t { 84 public: 85 io_saturation_conf_t(const int vreg_zero_saturation_idx, 86 const int vreg_saturation_ubound_idx, const Xbyak::Reg64 ®_tmp); 87 io_saturation_conf_t(const io_saturation_conf_t &other) = default; 88 89 io_saturation_conf_t &operator=(const io_saturation_conf_t &other) 90 = default; 91 92 int vreg_zero_saturation_idx_ = 0; 93 int vreg_saturation_ubound_idx_ = 0; 94 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); 95 }; 96 97 class io_gather_conf_t { 98 public: 99 io_gather_conf_t(const std::size_t simd_w, const Xbyak::Opmask &full_opmask, 100 const int full_vmm_mask_idx, const Xbyak::Reg64 ®_tmp, 101 const Xbyak::Reg64 ®_tmp1, 102 const utils::optional_t<int> &vmm_tmp_idx = utils::nullopt); 103 io_gather_conf_t(const io_gather_conf_t &other) = default; 104 105 io_gather_conf_t &operator=(const io_gather_conf_t &other) = default; 106 107 std::size_t simd_w_ = 0; 108 Xbyak::Opmask full_opmask_ = Xbyak::Opmask(); 109 int full_vmm_mask_idx_ = 0; 110 Xbyak::Reg64 reg_tmp_ = Xbyak::Reg64(); 111 Xbyak::Reg64 reg_tmp1_ = Xbyak::Reg64(); 112 // It is needed, when io_helper use emulation for gather 113 // and it is not needed for sse. 114 utils::optional_t<int> vmm_tmp_idx_ = utils::nullopt; 115 }; 116 117 template <typename Vmm> 118 class jit_io_multi_dt_helper_t; 119 120 template <typename Vmm> 121 class jit_io_helper_t { 122 public: 123 friend class jit_io_multi_dt_helper_t<Vmm>; 124 125 jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa, 126 const data_type_t &data_type, const io_conf_t &io_conf, 127 const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt, 128 const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf 129 = utils::nullopt, 130 const utils::optional_t<io_saturation_conf_t> &saturation_conf 131 = utils::nullopt, 132 const utils::optional_t<io_gather_conf_t> &gather_conf 133 = utils::nullopt); 134 jit_io_helper_t(jit_io_helper_t &&) = default; 135 jit_io_helper_t &operator=(jit_io_helper_t &&) = default; 136 137 ~jit_io_helper_t(); 138 void prepare_tail_mask(); 139 void prepare_full_mask(); 140 /* 141 * Sometimes the values in the register can be nan at the 142 * beginning of the kernel, then using vcmpps(vmm, vmm, vmm) 143 * will not set all bits to 1, instead of that this instruction will 144 * return zero. At the beginning, it is worth to zeroing 145 * full mask vmm to be sure, that vcmpps work properly. 146 */ 147 void init_full_mask(); 148 void init_saturate_f32() const; 149 void init_bf16(); 150 void gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm, 151 const Vmm &dst_vmm, const bool tail); 152 void broadcast(const Xbyak::Address &src_addr, const Vmm &dst_vmm); 153 void load(const Xbyak::Address &src_addr, const Vmm &dst_vmm, 154 const bool tail); 155 void store(const Vmm &src_vmm, const Xbyak::Address &dst_addr, 156 const bool tail); 157 158 private: 159 void prepare_opmask(const std::size_t how_many_bits_to_set, 160 const Xbyak::Reg64 ®_tmp, const Xbyak::Opmask &mask); 161 void prepare_vmm_mask(const std::size_t how_many_bits_to_set, 162 const std::size_t simd_w, const Xbyak::Reg64 ®_tmp, 163 const Vmm &mask); 164 void prepare_i8_data_to_store(const Vmm &i8_vmm); 165 // Emulates the behavior of vgatherdps for architectures 166 // that do not support this instruction. 167 void emu_gather(const Xbyak::Reg64 &src_reg, const Vmm &indices_vmm, 168 const Vmm &dst_vmm, const bool tail); 169 void load_byte_by_byte(const Xbyak::Address &src_addr, const Vmm &dst_vmm, 170 const int load_size); 171 void load_f32(const Xbyak::Address &src_addr, const Vmm &dst_vmm, 172 const bool tail); 173 void load_s32(const Xbyak::Address &src_addr, const Vmm &dst_vmm, 174 const bool tail); 175 void load_bf16(const Xbyak::Address &src_addr, const Vmm &dst_vmm); 176 void load_i8(const Xbyak::Address &src_addr, const Vmm &dst_vmm); 177 void saturate(const Vmm &vmm); 178 void store_byte_by_byte(const Vmm &src_vmm, const Xbyak::Address &dst_addr, 179 const int store_size); 180 void store_f32(const Vmm &src_vmm, const Xbyak::Address &dst_addr, 181 const bool tail); 182 void store_bf16(const Vmm &src_vmm, const Xbyak::Address &dst_addr); 183 void store_i8(const Vmm &src_vmm, const Xbyak::Address &dst_addr); 184 void convert_to_f32(const Vmm &dst_vmm, const Xbyak::Xmm &src_vmm, 185 const data_type_t src_data_type); 186 187 jit_generator *host_; 188 const cpu_isa_t isa_; 189 const data_type_t data_type_; 190 const bool bf16_supported_; 191 std::unique_ptr<bf16_emulation_t> bf16_emu_; 192 const io_conf_t io_conf_; 193 const utils::optional_t<io_tail_conf_t> tail_conf_; 194 const utils::optional_t<io_emu_bf16_conf_t> bf16_conf_; 195 const utils::optional_t<io_saturation_conf_t> saturation_conf_; 196 const utils::optional_t<io_gather_conf_t> gather_conf_; 197 }; 198 199 template <typename Vmm> 200 class jit_io_multi_dt_helper_t { 201 public: 202 using data_types_t = std::unordered_set<data_type_t, std::hash<int>>; 203 204 jit_io_multi_dt_helper_t(jit_generator *host, const cpu_isa_t &isa, 205 const data_types_t &data_types, const io_conf_t &io_conf, 206 const utils::optional_t<io_tail_conf_t> &tail_conf = utils::nullopt, 207 const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf 208 = utils::nullopt, 209 const std::map<data_type_t, io_saturation_conf_t> &saturation_confs 210 = std::map<data_type_t, io_saturation_conf_t> {}, 211 const utils::optional_t<io_gather_conf_t> &gather_conf 212 = utils::nullopt); 213 ~jit_io_multi_dt_helper_t(); 214 void prepare_tail_mask(); 215 void prepare_full_mask(); 216 void init_saturate_f32(const data_types_t &store_data_types); 217 void init_full_mask(); 218 void init_bf16(); 219 220 std::shared_ptr<jit_io_helper_t<Vmm>> at(const data_type_t dt) const; 221 222 private: 223 std::unordered_map<data_type_t, std::shared_ptr<jit_io_helper_t<Vmm>>, 224 std::hash<int>> 225 storage_; 226 }; 227 228 } // namespace io 229 } // namespace x64 230 } // namespace cpu 231 } // namespace impl 232 } // namespace dnnl 233 234 #endif 235