/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP
#define GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP

#include <cstdint>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp"
#include "gpu/jit/jit_generator.hpp"
#include "gpu/jit/jit_post_op_injector.hpp"

// clang-format off
// This header must be loaded after ngen.
31 #include "gpu/jit/gemm/emulation.hpp" 32 // clang-format on 33 34 namespace dnnl { 35 namespace impl { 36 namespace gpu { 37 namespace jit { 38 39 template <gpu_gen_t hw> 40 class xehp_systolic_gemm_kernel_t : public jit_generator<hw> { 41 public: 42 NGEN_FORWARD_OPENCL(hw); 43 enum class bias_t { none, fixed, row, column, runtime }; 44 45 struct config_t { 46 ngen::DataType a_type, b_type, c_type, acc_type; 47 ngen::DataType co_type = ngen::DataType::invalid; 48 ngen::DataType scale_type = ngen::DataType::f; 49 bool alpha1, beta0, beta1; 50 51 bool a_bias = false; 52 bool b_bias = false; 53 bias_t c_bias = bias_t::none; 54 bool early_c_bias = false; 55 bool c_packed = false; 56 bool batch = false; 57 bool emulate64 = (hw == ngen::HW::XeHPG); 58 59 int tile_m = 32; 60 int tile_n = 48; 61 bool walk_n_first = false; 62 bool alt_barriers = false; 63 bool use_slm_fence = true; 64 bool c_remainder = true; 65 bool c_align16_check = true; 66 bool pad_a = true; 67 bool global_3x_buf = false; 68 bool fulsim = true; 69 70 post_ops_t post_ops; 71 bool post_op_is_fwd = true; 72 float eltwise_alpha, eltwise_beta, eltwise_scale; 73 have_post_opdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t74 bool have_post_op() const { return (post_ops.len() > 0); } 75 validdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t76 bool valid() const { 77 using ngen::DataType; 78 79 bool ok = true; 80 if (c_type == DataType::d || c_type == DataType::ud) { 81 ok &= (a_type == DataType::b || a_type == DataType::ub); 82 ok &= (b_type == DataType::b || b_type == DataType::ub); 83 ok &= (acc_type == c_type); 84 } else { 85 ok &= (a_type == b_type); 86 ok &= (a_type == DataType::bf || a_type == DataType::hf); 87 ok &= (c_type == DataType::f || c_type == a_type); 88 ok &= (acc_type == DataType::f); 89 ok &= !a_bias && !b_bias; 90 } 91 ok &= (alt_barriers || use_slm_fence); 92 ok &= (c_bias == bias_t::none 93 || (early_c_bias 94 && (co_type == acc_type || co_type == a_type)) 95 || 
co_type == c_type); 96 ok &= (tile_m == 32); 97 ok &= (tile_n == 32 || tile_n == 48); 98 ok &= !(tile_n > 32 && global_3x_buf); 99 100 return ok; 101 } 102 103 template <typename T> castdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t104 T &cast() { 105 return *reinterpret_cast<T *>(this); 106 } 107 }; 108 109 static constexpr size_t unroll_k_bytes = 32; 110 static constexpr size_t thread_group_m = 4; 111 static constexpr size_t thread_group_n = 4; 112 static constexpr size_t nominal_subgroup_size = 8; 113 unroll_k(data_type_t dt)114 static size_t unroll_k(data_type_t dt) { 115 return unroll_k_bytes / types::data_type_size(dt); 116 } 117 this_unroll_k() const118 int this_unroll_k() const { return unroll_k_bytes / getBytes(cfg.a_type); } 119 min_block_k(data_type_t dt)120 static int min_block_k(data_type_t dt) { 121 return 8192 / int(types::data_type_size(dt)); 122 } 123 driver_info(int eu_count) const124 CommonDriverInfo driver_info(int eu_count) const { 125 CommonDriverInfo info; 126 info.subgroupSize = nominal_subgroup_size; 127 info.fusedEUs = true; 128 info.loopOrder[0] = LoopM; 129 info.loopOrder[1] = LoopN; 130 info.loopOrder[2] = LoopK; 131 info.unroll[LoopM] = cfg.tile_m; 132 info.unroll[LoopN] = cfg.tile_n; 133 info.unroll[LoopK] = this_unroll_k(); 134 info.wg[LoopM] = thread_group_m; 135 info.wg[LoopN] = thread_group_n; 136 info.blocking[LoopM] = 1024; 137 info.blocking[LoopN] = eu_count * 6; 138 info.blocking[LoopK] = 8192 / getBytes(cfg.a_type); 139 info.fixedWG = true; 140 return info; 141 } 142 143 private: 144 config_t cfg; 145 146 using injector_t = jit_post_op_injector<hw>; 147 std::unique_ptr<injector_t> post_op_injector; 148 149 // Surface assignments 150 int ap_surface, bp_surface, co_surface; 151 152 // Register assignments (main loop) 153 ngen::GRFRange a_copy0 = r40 - r47; 154 ngen::GRFRange b_copy0 = r2 - r13; 155 ngen::GRFRange a_regs = r48 - r63; 156 ngen::GRFRange b_regs = r14 - r37; 157 ngen::GRFRange c_regs = r64 - r255; 158 
ngen::GRFRange a_copy1 = r96 - r103; 159 ngen::GRFRange b_copy1 = r104 - r111; 160 ngen::GRFRange a_copy2 = r144 - r151; 161 ngen::GRFRange b_copy2 = r152 - r159; 162 ngen::GRFRange a_copy[3] = {a_copy0, a_copy1, a_copy2}; 163 ngen::GRFRange b_copy[3] = {b_copy0, b_copy1, b_copy2}; 164 ngen::GRF addr0 = r1; 165 ngen::GRF addr1 = r38; 166 ngen::GRF addr2 = r39; 167 ngen::GRF addr3 = r0; 168 ngen::Subregister a_ptr_mem = addr1.uq(3); 169 ngen::Subregister b_ptr_mem = addr2.uq(3); 170 ngen::Subregister c_ptr_mem = addr2.uq(2); 171 ngen::Subregister slm_a_offset_load = addr1.uw(8); // offsets in OWords 172 ngen::Subregister slm_b_offset_load = addr1.uw(9); 173 ngen::Subregister slm_a_offset_store = addr1.uw(10); 174 ngen::Subregister slm_b_offset_store = addr1.uw(11); 175 ngen::Subregister slm_a_offset_load_init = addr1.uw(6); 176 ngen::Subregister slm_b_offset_load_init = addr1.uw(7); 177 ngen::Subregister slm_a_offset_store_init = addr2.uw(6); 178 ngen::Subregister slm_b_offset_store_init = addr2.uw(7); 179 ngen::Register base_save = acc0.ud(); 180 ngen::Subregister k_counter = acc0.d(0); 181 ngen::Subregister ldc_save = acc0.ud(1); 182 ngen::Subregister off_co_save = acc0.ud(2); 183 ngen::Subregister k_save = acc0.ud(3); 184 ngen::Subregister mrem_save = acc0.uw(8); 185 ngen::Subregister nrem_save = acc0.uw(9); 186 ngen::Subregister abo_save = acc0.ud(5); 187 ngen::Subregister ao_save = acc0.w(10); 188 ngen::Subregister bo_save = acc0.w(11); 189 ngen::Subregister alpha_save = acc0.ud(6); 190 ngen::Subregister beta_save = acc0.ud(7); 191 ngen::AccumulatorRegister r0_save = acc2; 192 ngen::Subregister off_asum_save = a0.ud(0); 193 ngen::Subregister off_bsum_save = a0.ud(1); 194 ngen::Subregister flags_save = a0.ud(2); 195 196 ngen::InstructionModifier dep_addr0 {}, dep_addr1 {}, dep_addr2 {}, 197 dep_addr3 {}; // Dependencies for addr registers. 
198 199 // Register assignments (C update) 200 ngen::GRFRange utemp = r32 - r63; 201 ngen::GRFRange uheaders = r0 - r15; 202 ngen::GRFRange upost_op_scratch = r16 - r21; 203 ngen::GRFRange uoffset = r22 - r27; 204 ngen::GRFRange uemulate_temp = r20 - r21; 205 206 ngen::GRF ubase = r28.ud(); 207 ngen::Subregister uldc = r28.ud(1); 208 ngen::Subregister uoff_co = r28.ud(2); 209 ngen::Subregister uk = r28.ud(3); 210 ngen::Subregister um_rem = r28.uw(8); 211 ngen::Subregister un_rem = r28.uw(9); 212 ngen::Subregister uao = r28.w(10); 213 ngen::Subregister ubo = r28.w(11); 214 ngen::Subregister ualpha_regs[2] = {r28.f(6), r30.f(6)}; 215 ngen::Subregister ubeta_regs[2] = {r28.f(7), r30.f(7)}; 216 217 ngen::Subregister uc_base = r29.uq(0); 218 ngen::Subregister uoff_co2 = r29.ud(2); 219 220 ngen::Subregister uao_bo_k = r30.ud(0); 221 ngen::Subregister uflags = r30.ud(1); 222 223 ngen::Subregister uldc_x2 = r31.ud(1); 224 ngen::Subregister uldc_x4 = r31.ud(2); 225 ngen::Subregister uldc_x8 = r31.ud(3); 226 227 // 64-bit emulation registers 228 EmulationStrategy emu_strategy; 229 EmulationState emu_state; 230 231 // Scoreboard usage: 232 // $0-2 B SLM loads 233 // $3-4 A SLM loads 234 // $5 Last DPASW in chain 235 // $6-7 Load local IDs/kernel arguments 236 // $8 A copy to SLM 237 // $9-10 B copy to SLM 238 // $11 Initial A copy to SLM using C register space 239 // $12-13 Initial B copy to SLM using C register space 240 // $14 EOT 241 // $15 Barriers/SLM fences 242 243 static constexpr int acc_stride = 48; 244 245 friend struct EmulationImplementation; 246 template <typename DT = void> emov(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0)247 void emov(const ngen::InstructionModifier &mod, ngen::RegData dst, 248 ngen::RegData src0) { 249 EmulationImplementation::emov<DT>(*this, mod, dst, src0, emu_strategy); 250 } 251 template <typename DT = void> emov(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::Immediate src0)252 void emov(const 
ngen::InstructionModifier &mod, ngen::RegData dst, 253 ngen::Immediate src0) { 254 EmulationImplementation::emov<DT>(*this, mod, dst, src0, emu_strategy); 255 } 256 template <typename DT = void> eadd(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,const ngen::RegData & src1)257 void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst, 258 const ngen::RegData &src0, const ngen::RegData &src1) { 259 EmulationImplementation::eadd<DT>( 260 *this, mod, dst, src0, src1, emu_strategy, emu_state); 261 } 262 template <typename DT = void> eadd(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,ngen::Immediate src1)263 void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst, 264 const ngen::RegData &src0, ngen::Immediate src1) { 265 EmulationImplementation::eadd<DT>( 266 *this, mod, dst, src0, src1, emu_strategy, emu_state); 267 } 268 template <typename DT = void> emul(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,const ngen::RegData & src1)269 void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, 270 const ngen::RegData &src0, const ngen::RegData &src1) { 271 EmulationImplementation::emul<DT>( 272 *this, mod, dst, src0, src1, emu_strategy, emu_state); 273 } 274 template <typename DT = void> emul(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,ngen::Immediate src1)275 void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst, 276 const ngen::RegData &src0, ngen::Immediate src1) { 277 EmulationImplementation::emul<DT>( 278 *this, mod, dst, src0, src1, emu_strategy, emu_state); 279 } 280 template <typename DT = void> eshl(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0,uint16_t src1)281 void eshl(const ngen::InstructionModifier &mod, ngen::RegData dst, 282 ngen::RegData src0, uint16_t src1) { 283 
EmulationImplementation::eshl<DT>( 284 *this, mod, dst, src0, src1, emu_strategy, emu_state); 285 } 286 template <typename DT = void> eshr(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0,uint16_t src1)287 void eshr(const ngen::InstructionModifier &mod, ngen::RegData dst, 288 ngen::RegData src0, uint16_t src1) { 289 EmulationImplementation::eshr<DT>( 290 *this, mod, dst, src0, src1, emu_strategy, emu_state); 291 } 292 template <typename DT = void> emulConstant(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,int32_t src1)293 void emulConstant(const ngen::InstructionModifier &mod, 294 const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1) { 295 EmulationImplementation::emulConstant<DT>( 296 *this, mod, dst, src0, src1, emu_strategy, emu_state); 297 } 298 slm_buf_size() const299 int slm_buf_size() const { 300 return cfg.pad_a 301 ? 10752 // 4.5k A (128x32 + 4*128 padding) + 6k B (192x32) 302 : 10240; // 4k A (128x32) + 6k B (192x32) 303 } 304 packed_ldc() const305 int packed_ldc() const { return 32 / getBytes(cfg.b_type); } 306 307 void barrier_prep( 308 const ngen::InstructionModifier &swsb, const ngen::GRF &header); 309 void mul_constant(const ngen::InstructionModifier &mod, 310 const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1); 311 void zero_c(); 312 313 void scattered_setup_c(int stride, bool load); 314 void block_setup_c(bool remainder, bool load); 315 316 int interleave(int j); 317 318 void load_c_bias_scattered(bias_t c_bias, bool remainder); 319 void load_c_bias_block(bias_t c_bias); 320 void load_c_bias(bias_t c_bias, bool remainder); 321 void load_c_bias(bool remainder); 322 void convert_c_bias(bias_t c_bias, ngen::DataType dst_type); 323 void add_c_bias(bias_t c_bias); 324 void add_c_bias(); 325 bool merge_abc_bias(); 326 void add_ab_bias(); 327 328 void load_update_c_internal(bool remainder, bool c_align16); 329 void load_update_c(bool remainder, bool 
c_align16); 330 void store_c(bool remainder, bool c_align16); 331 void update_c(bool remainder); 332 void update_c(); 333 334 void dpasw_typed(const ngen::InstructionModifier &mod, uint8_t sdepth, 335 uint8_t rcount, const ngen::GRF &c_reg, const ngen::GRF &a_reg, 336 const ngen::GRF &b_reg); 337 338 void multiply_chunk(int ao, int i0, bool waitb, 339 const ngen::InstructionModifier &swsb0 340 = ngen::InstructionModifier(), 341 const ngen::InstructionModifier &swsb_end 342 = ngen::InstructionModifier()); 343 void multiply(int buffer, bool last_multiply = false); 344 345 void copy_load(int store_buffer, bool use_c = false); 346 void copy_store(int store_buffer, bool first = false); 347 void store_signal(bool force_fence = false); 348 349 void body(); 350 351 public: 352 xehp_systolic_gemm_kernel_t(config_t cfg_); 353 }; 354 355 } // namespace jit 356 } // namespace gpu 357 } // namespace impl 358 } // namespace dnnl 359 360 #endif // GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP 361