1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP
18 #define GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP
19 
20 #include <cstdint>
21 
22 #include "common/c_types_map.hpp"
23 #include "common/type_helpers.hpp"
24 #include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"
25 #include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp"
26 #include "gpu/jit/jit_generator.hpp"
27 #include "gpu/jit/jit_post_op_injector.hpp"
28 
29 // clang-format off
30 // This header must be loaded after ngen.
31 #include "gpu/jit/gemm/emulation.hpp"
32 // clang-format on
33 
34 namespace dnnl {
35 namespace impl {
36 namespace gpu {
37 namespace jit {
38 
39 template <gpu_gen_t hw>
40 class xehp_systolic_gemm_kernel_t : public jit_generator<hw> {
41 public:
42     NGEN_FORWARD_OPENCL(hw);
43     enum class bias_t { none, fixed, row, column, runtime };
44 
45     struct config_t {
46         ngen::DataType a_type, b_type, c_type, acc_type;
47         ngen::DataType co_type = ngen::DataType::invalid;
48         ngen::DataType scale_type = ngen::DataType::f;
49         bool alpha1, beta0, beta1;
50 
51         bool a_bias = false;
52         bool b_bias = false;
53         bias_t c_bias = bias_t::none;
54         bool early_c_bias = false;
55         bool c_packed = false;
56         bool batch = false;
57         bool emulate64 = (hw == ngen::HW::XeHPG);
58 
59         int tile_m = 32;
60         int tile_n = 48;
61         bool walk_n_first = false;
62         bool alt_barriers = false;
63         bool use_slm_fence = true;
64         bool c_remainder = true;
65         bool c_align16_check = true;
66         bool pad_a = true;
67         bool global_3x_buf = false;
68         bool fulsim = true;
69 
70         post_ops_t post_ops;
71         bool post_op_is_fwd = true;
72         float eltwise_alpha, eltwise_beta, eltwise_scale;
73 
have_post_opdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t74         bool have_post_op() const { return (post_ops.len() > 0); }
75 
validdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t76         bool valid() const {
77             using ngen::DataType;
78 
79             bool ok = true;
80             if (c_type == DataType::d || c_type == DataType::ud) {
81                 ok &= (a_type == DataType::b || a_type == DataType::ub);
82                 ok &= (b_type == DataType::b || b_type == DataType::ub);
83                 ok &= (acc_type == c_type);
84             } else {
85                 ok &= (a_type == b_type);
86                 ok &= (a_type == DataType::bf || a_type == DataType::hf);
87                 ok &= (c_type == DataType::f || c_type == a_type);
88                 ok &= (acc_type == DataType::f);
89                 ok &= !a_bias && !b_bias;
90             }
91             ok &= (alt_barriers || use_slm_fence);
92             ok &= (c_bias == bias_t::none
93                     || (early_c_bias
94                             && (co_type == acc_type || co_type == a_type))
95                     || co_type == c_type);
96             ok &= (tile_m == 32);
97             ok &= (tile_n == 32 || tile_n == 48);
98             ok &= !(tile_n > 32 && global_3x_buf);
99 
100             return ok;
101         }
102 
103         template <typename T>
castdnnl::impl::gpu::jit::xehp_systolic_gemm_kernel_t::config_t104         T &cast() {
105             return *reinterpret_cast<T *>(this);
106         }
107     };
108 
109     static constexpr size_t unroll_k_bytes = 32;
110     static constexpr size_t thread_group_m = 4;
111     static constexpr size_t thread_group_n = 4;
112     static constexpr size_t nominal_subgroup_size = 8;
113 
unroll_k(data_type_t dt)114     static size_t unroll_k(data_type_t dt) {
115         return unroll_k_bytes / types::data_type_size(dt);
116     }
117 
this_unroll_k() const118     int this_unroll_k() const { return unroll_k_bytes / getBytes(cfg.a_type); }
119 
min_block_k(data_type_t dt)120     static int min_block_k(data_type_t dt) {
121         return 8192 / int(types::data_type_size(dt));
122     }
123 
driver_info(int eu_count) const124     CommonDriverInfo driver_info(int eu_count) const {
125         CommonDriverInfo info;
126         info.subgroupSize = nominal_subgroup_size;
127         info.fusedEUs = true;
128         info.loopOrder[0] = LoopM;
129         info.loopOrder[1] = LoopN;
130         info.loopOrder[2] = LoopK;
131         info.unroll[LoopM] = cfg.tile_m;
132         info.unroll[LoopN] = cfg.tile_n;
133         info.unroll[LoopK] = this_unroll_k();
134         info.wg[LoopM] = thread_group_m;
135         info.wg[LoopN] = thread_group_n;
136         info.blocking[LoopM] = 1024;
137         info.blocking[LoopN] = eu_count * 6;
138         info.blocking[LoopK] = 8192 / getBytes(cfg.a_type);
139         info.fixedWG = true;
140         return info;
141     }
142 
143 private:
144     config_t cfg;
145 
146     using injector_t = jit_post_op_injector<hw>;
147     std::unique_ptr<injector_t> post_op_injector;
148 
149     // Surface assignments
150     int ap_surface, bp_surface, co_surface;
151 
152     // Register assignments (main loop)
153     ngen::GRFRange a_copy0 = r40 - r47;
154     ngen::GRFRange b_copy0 = r2 - r13;
155     ngen::GRFRange a_regs = r48 - r63;
156     ngen::GRFRange b_regs = r14 - r37;
157     ngen::GRFRange c_regs = r64 - r255;
158     ngen::GRFRange a_copy1 = r96 - r103;
159     ngen::GRFRange b_copy1 = r104 - r111;
160     ngen::GRFRange a_copy2 = r144 - r151;
161     ngen::GRFRange b_copy2 = r152 - r159;
162     ngen::GRFRange a_copy[3] = {a_copy0, a_copy1, a_copy2};
163     ngen::GRFRange b_copy[3] = {b_copy0, b_copy1, b_copy2};
164     ngen::GRF addr0 = r1;
165     ngen::GRF addr1 = r38;
166     ngen::GRF addr2 = r39;
167     ngen::GRF addr3 = r0;
168     ngen::Subregister a_ptr_mem = addr1.uq(3);
169     ngen::Subregister b_ptr_mem = addr2.uq(3);
170     ngen::Subregister c_ptr_mem = addr2.uq(2);
171     ngen::Subregister slm_a_offset_load = addr1.uw(8); // offsets in OWords
172     ngen::Subregister slm_b_offset_load = addr1.uw(9);
173     ngen::Subregister slm_a_offset_store = addr1.uw(10);
174     ngen::Subregister slm_b_offset_store = addr1.uw(11);
175     ngen::Subregister slm_a_offset_load_init = addr1.uw(6);
176     ngen::Subregister slm_b_offset_load_init = addr1.uw(7);
177     ngen::Subregister slm_a_offset_store_init = addr2.uw(6);
178     ngen::Subregister slm_b_offset_store_init = addr2.uw(7);
179     ngen::Register base_save = acc0.ud();
180     ngen::Subregister k_counter = acc0.d(0);
181     ngen::Subregister ldc_save = acc0.ud(1);
182     ngen::Subregister off_co_save = acc0.ud(2);
183     ngen::Subregister k_save = acc0.ud(3);
184     ngen::Subregister mrem_save = acc0.uw(8);
185     ngen::Subregister nrem_save = acc0.uw(9);
186     ngen::Subregister abo_save = acc0.ud(5);
187     ngen::Subregister ao_save = acc0.w(10);
188     ngen::Subregister bo_save = acc0.w(11);
189     ngen::Subregister alpha_save = acc0.ud(6);
190     ngen::Subregister beta_save = acc0.ud(7);
191     ngen::AccumulatorRegister r0_save = acc2;
192     ngen::Subregister off_asum_save = a0.ud(0);
193     ngen::Subregister off_bsum_save = a0.ud(1);
194     ngen::Subregister flags_save = a0.ud(2);
195 
196     ngen::InstructionModifier dep_addr0 {}, dep_addr1 {}, dep_addr2 {},
197             dep_addr3 {}; // Dependencies for addr registers.
198 
199     // Register assignments (C update)
200     ngen::GRFRange utemp = r32 - r63;
201     ngen::GRFRange uheaders = r0 - r15;
202     ngen::GRFRange upost_op_scratch = r16 - r21;
203     ngen::GRFRange uoffset = r22 - r27;
204     ngen::GRFRange uemulate_temp = r20 - r21;
205 
206     ngen::GRF ubase = r28.ud();
207     ngen::Subregister uldc = r28.ud(1);
208     ngen::Subregister uoff_co = r28.ud(2);
209     ngen::Subregister uk = r28.ud(3);
210     ngen::Subregister um_rem = r28.uw(8);
211     ngen::Subregister un_rem = r28.uw(9);
212     ngen::Subregister uao = r28.w(10);
213     ngen::Subregister ubo = r28.w(11);
214     ngen::Subregister ualpha_regs[2] = {r28.f(6), r30.f(6)};
215     ngen::Subregister ubeta_regs[2] = {r28.f(7), r30.f(7)};
216 
217     ngen::Subregister uc_base = r29.uq(0);
218     ngen::Subregister uoff_co2 = r29.ud(2);
219 
220     ngen::Subregister uao_bo_k = r30.ud(0);
221     ngen::Subregister uflags = r30.ud(1);
222 
223     ngen::Subregister uldc_x2 = r31.ud(1);
224     ngen::Subregister uldc_x4 = r31.ud(2);
225     ngen::Subregister uldc_x8 = r31.ud(3);
226 
227     // 64-bit emulation registers
228     EmulationStrategy emu_strategy;
229     EmulationState emu_state;
230 
231     // Scoreboard usage:
232     //   $0-2   B SLM loads
233     //   $3-4   A SLM loads
234     //   $5     Last DPASW in chain
235     //   $6-7   Load local IDs/kernel arguments
236     //   $8     A copy to SLM
237     //   $9-10  B copy to SLM
238     //   $11    Initial A copy to SLM using C register space
239     //   $12-13 Initial B copy to SLM using C register space
240     //   $14    EOT
241     //   $15    Barriers/SLM fences
242 
243     static constexpr int acc_stride = 48;
244 
245     friend struct EmulationImplementation;
246     template <typename DT = void>
emov(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0)247     void emov(const ngen::InstructionModifier &mod, ngen::RegData dst,
248             ngen::RegData src0) {
249         EmulationImplementation::emov<DT>(*this, mod, dst, src0, emu_strategy);
250     }
251     template <typename DT = void>
emov(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::Immediate src0)252     void emov(const ngen::InstructionModifier &mod, ngen::RegData dst,
253             ngen::Immediate src0) {
254         EmulationImplementation::emov<DT>(*this, mod, dst, src0, emu_strategy);
255     }
256     template <typename DT = void>
eadd(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,const ngen::RegData & src1)257     void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
258             const ngen::RegData &src0, const ngen::RegData &src1) {
259         EmulationImplementation::eadd<DT>(
260                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
261     }
262     template <typename DT = void>
eadd(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,ngen::Immediate src1)263     void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
264             const ngen::RegData &src0, ngen::Immediate src1) {
265         EmulationImplementation::eadd<DT>(
266                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
267     }
268     template <typename DT = void>
emul(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,const ngen::RegData & src1)269     void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
270             const ngen::RegData &src0, const ngen::RegData &src1) {
271         EmulationImplementation::emul<DT>(
272                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
273     }
274     template <typename DT = void>
emul(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,ngen::Immediate src1)275     void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
276             const ngen::RegData &src0, ngen::Immediate src1) {
277         EmulationImplementation::emul<DT>(
278                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
279     }
280     template <typename DT = void>
eshl(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0,uint16_t src1)281     void eshl(const ngen::InstructionModifier &mod, ngen::RegData dst,
282             ngen::RegData src0, uint16_t src1) {
283         EmulationImplementation::eshl<DT>(
284                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
285     }
286     template <typename DT = void>
eshr(const ngen::InstructionModifier & mod,ngen::RegData dst,ngen::RegData src0,uint16_t src1)287     void eshr(const ngen::InstructionModifier &mod, ngen::RegData dst,
288             ngen::RegData src0, uint16_t src1) {
289         EmulationImplementation::eshr<DT>(
290                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
291     }
292     template <typename DT = void>
emulConstant(const ngen::InstructionModifier & mod,const ngen::RegData & dst,const ngen::RegData & src0,int32_t src1)293     void emulConstant(const ngen::InstructionModifier &mod,
294             const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1) {
295         EmulationImplementation::emulConstant<DT>(
296                 *this, mod, dst, src0, src1, emu_strategy, emu_state);
297     }
298 
slm_buf_size() const299     int slm_buf_size() const {
300         return cfg.pad_a
301                 ? 10752 // 4.5k A (128x32 + 4*128 padding) + 6k B (192x32)
302                 : 10240; //   4k A (128x32)                 + 6k B (192x32)
303     }
304 
packed_ldc() const305     int packed_ldc() const { return 32 / getBytes(cfg.b_type); }
306 
307     void barrier_prep(
308             const ngen::InstructionModifier &swsb, const ngen::GRF &header);
309     void mul_constant(const ngen::InstructionModifier &mod,
310             const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1);
311     void zero_c();
312 
313     void scattered_setup_c(int stride, bool load);
314     void block_setup_c(bool remainder, bool load);
315 
316     int interleave(int j);
317 
318     void load_c_bias_scattered(bias_t c_bias, bool remainder);
319     void load_c_bias_block(bias_t c_bias);
320     void load_c_bias(bias_t c_bias, bool remainder);
321     void load_c_bias(bool remainder);
322     void convert_c_bias(bias_t c_bias, ngen::DataType dst_type);
323     void add_c_bias(bias_t c_bias);
324     void add_c_bias();
325     bool merge_abc_bias();
326     void add_ab_bias();
327 
328     void load_update_c_internal(bool remainder, bool c_align16);
329     void load_update_c(bool remainder, bool c_align16);
330     void store_c(bool remainder, bool c_align16);
331     void update_c(bool remainder);
332     void update_c();
333 
334     void dpasw_typed(const ngen::InstructionModifier &mod, uint8_t sdepth,
335             uint8_t rcount, const ngen::GRF &c_reg, const ngen::GRF &a_reg,
336             const ngen::GRF &b_reg);
337 
338     void multiply_chunk(int ao, int i0, bool waitb,
339             const ngen::InstructionModifier &swsb0
340             = ngen::InstructionModifier(),
341             const ngen::InstructionModifier &swsb_end
342             = ngen::InstructionModifier());
343     void multiply(int buffer, bool last_multiply = false);
344 
345     void copy_load(int store_buffer, bool use_c = false);
346     void copy_store(int store_buffer, bool first = false);
347     void store_signal(bool force_fence = false);
348 
349     void body();
350 
351 public:
352     xehp_systolic_gemm_kernel_t(config_t cfg_);
353 };
354 
355 } // namespace jit
356 } // namespace gpu
357 } // namespace impl
358 } // namespace dnnl
359 
360 #endif // GPU_JIT_XE_HP_SYSTOLIC_GEMM_KERNEL_HPP
361