1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gpu/jit/conv/kernel_builder.hpp"
18
19 #include <algorithm>
20 #include <array>
21 #include <iostream>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 #include <unordered_map>
26
27 #include "gpu/jit/conv/config.hpp"
28 #include "gpu/jit/conv/fma_support.hpp"
29 #include "gpu/jit/conv/gemm_schedule.hpp"
30 #include "gpu/jit/conv/ir.hpp"
31 #include "gpu/jit/conv/message_support.hpp"
32 #include "gpu/jit/conv/post_op_support.hpp"
33 #include "gpu/jit/conv/reduce_support.hpp"
34 #include "gpu/jit/conv/reorder_support.hpp"
35 #include "gpu/jit/conv/tensor.hpp"
36
37 namespace dnnl {
38 namespace impl {
39 namespace gpu {
40 namespace jit {
41
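// Replaces reorder calls with equivalent reorders that carry the given GRF
// permutation.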
42 class permutation_injector_t : public ir_mutator_t {
43 public:
44 permutation_injector_t(const grf_permutator_t &grf_perm)
45 : grf_perm_(new grf_permutator_t(grf_perm)) {}
46
47 object_t _mutate(const func_call_t &obj) override {
48 if (!is_func_call<reorder_t>(&obj)) return ir_mutator_t::_mutate(obj);
49
50 auto &func = obj.func.as<reorder_t>();
51 auto new_func
52 = reorder_t::make(func.src_layout, func.dst_layout, grf_perm_);
53
54 return new_func.call(obj.args);
55 }
56
57 private:
58 std::shared_ptr<grf_permutator_t> grf_perm_;
59 };
60
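// Rewrites eligible dpas calls into dpasw calls, updating the corresponding
// SLM loads (based on tg_idx0) and recording the GRF permutation that must
// be applied to the C store afterwards.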
61 class dpasw_injector_t {
62 public:
63 dpasw_injector_t(ngen::HW hw, const stmt_t &load_mul_stmt,
64 const expr_t &c_buf, const stmt_t &c_store_stmt,
65 alloc_updater_t &alloc_updater, const expr_t &tg_idx0)
66 : hw_(hw)
67 , load_mul_stmt_(load_mul_stmt)
68 , c_buf_(c_buf)
69 , c_store_stmt_(c_store_stmt)
70 , alloc_updater_(alloc_updater)
71 , tg_idx0_(tg_idx0) {}
72
73 const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
74
75 const stmt_t &c_store_stmt() const { return c_store_stmt_; }
76
77 void inject() {
78 expr_t src2_base;
79 extract_dpas_calls(src2_base);
80
81 grf_permutator_t grf_perm(hw_, c_buf_);
82
83 bool was_injected = false;
84 int dpas_count = int(dpas_infos_.size());
85 for (int i = 0; i < dpas_count;) {
86 if (i + 1 < dpas_count) {
87 auto &a = dpas_infos_[i];
88 auto &b = dpas_infos_[i + 1];
89 if (try_convert_to_dpasw(a, b, grf_perm)) {
90 was_injected = true;
91 i += 2;
92 continue;
93 }
94 }
95 if (try_convert_to_dpasw(dpas_infos_[i], grf_perm)) {
96 was_injected = true;
97 }
98 ++i;
99 }
100 // Nothing to update, no dpas -> dpasw transformation.
101 if (!was_injected) return;
102
103 int src2_size = 0;
104 object_map_t<stmt_t, int> send2off;
105 std::function<int(const stmt_t &)> get_src2_off;
106 get_src2_off = [&](const stmt_t &s) {
107 auto &si = find_send_info(s);
108 if (!si.base_call.is_empty()) return get_src2_off(si.base_call);
109 if (!si.prev_send.is_empty()) return get_src2_off(si.prev_send);
110
111 auto it = send2off.find(s);
112 if (it != send2off.end()) return it->second;
113
114 auto ret = send2off.insert({s, src2_size});
115 if (!ret.second) return ret.first->second;
116
117 int new_size = si.new_reg_buf_size();
118 src2_size += new_size;
119 return ret.first->second;
120 };
121 for (auto &si : send_infos_) {
122 if (!si.reg_buf_base().is_equal(src2_base)) continue;
123
124 int src2_off = get_src2_off(si.call);
125 auto src2_sub = src2_base[src2_off];
126 auto new_call = si.new_call;
127 if (!new_call.is_empty()) {
128 new_call = substitute(
129 new_call, send_t::arg_reg_buf(new_call), src2_sub, 1);
130 }
131
132 load_mul_stmt_ = substitute(load_mul_stmt_, si.call, new_call, 1);
133 for (auto &d : si.dpas_consumers) {
134 auto &di = find_dpas_info(d);
135 ir_assert(si.promote_to_dpasw == di.promote_to_dpasw)
136 << "Both send and dpas must be updated.";
137 if (di.update_applied) {
138 ir_error_not_expected() << "Can it happen?";
139 continue;
140 }
141 auto new_call = di.new_call;
142 new_call = substitute(new_call, dpas_t::arg_src2(new_call),
143 src2_sub[di.src2_relative_off], 1);
144 load_mul_stmt_
145 = substitute(load_mul_stmt_, di.call, new_call, 1);
146 di.update_applied = true;
147 }
148 }
149
150 // Apply permutation to C store.
151 c_store_stmt_ = apply_permutation_to_reorder(c_store_stmt_, grf_perm);
152
153 // Update src2 size after applying send updates.
154 alloc_updater_.resize(src2_base, src2_size);
155 }
156
157 private:
158 struct send_info_t {
159 send_info_t() = default;
160
161 send_info_t(const stmt_t &call) : call(call), new_call(call) {}
162
163 const send_t &send() const {
164 return call.as<func_call_t>().func.as<send_t>();
165 }
166
167 const send_t &new_send() const {
168 ir_assert(!new_call.is_same(call));
169 return new_call.as<func_call_t>().func.as<send_t>();
170 }
171
172 const std::vector<expr_t> &args() const {
173 return call.as<func_call_t>().args;
174 }
175
176 const expr_t &reg_buf() const { return send_t::arg_reg_buf(call); }
177
178 const expr_t &reg_buf_base() const {
179 return reg_buf().as<ptr_t>().base;
180 }
181
182 int reg_buf_size() const { return send().register_size(); }
183
184 int new_reg_buf_size() const {
185 if (new_call.is_same(call)) return 0;
186 return new_send().register_size();
187 }
188
189 void set_new_call(const stmt_t &s, const stmt_t &base = stmt_t()) {
190 if (!promote_to_dpasw) {
191 promote_to_dpasw = true;
192 new_call = s;
193 base_call = base;
194 return;
195 }
196 ir_assert(new_call.is_equal(s));
197 ir_assert(base_call.is_equal(base));
198 }
199
200 void set_prev_send(const stmt_t &s) {
201 int prev_size
202 = s.as<func_call_t>().func.as<send_t>().register_size();
203 if (reg_buf_size() != prev_size) return;
204 prev_send = s;
205 }
206
207 stmt_t call;
208 std::vector<stmt_t> dpas_consumers;
209
210 bool promote_to_dpasw = false;
211 stmt_t new_call;
212 stmt_t base_call;
213 stmt_t prev_send;
214 };
215
216 struct dpas_info_t {
217 dpas_info_t() = default;
218
219 dpas_info_t(const stmt_t &call) : call(call), new_call(call) {}
220
221 const dpas_t &dpas() const {
222 return call.as<func_call_t>().func.as<dpas_t>();
223 }
224
225 const std::vector<expr_t> &args() const {
226 return call.as<func_call_t>().args;
227 }
228
229 const expr_t &src1_buf() const { return dpas_t::arg_src1(call); }
230
231 const expr_t &src2_buf() const { return dpas_t::arg_src2(call); }
232
233 int src2_size() const { return dpas().src2_size(); }
234
235 void set_new_call(const stmt_t &s, int src2_relative_off) {
236 if (!promote_to_dpasw) {
237 promote_to_dpasw = true;
238 this->src2_relative_off = src2_relative_off;
239 new_call = s;
240 return;
241 }
242 ir_assert(this->src2_relative_off == src2_relative_off);
243 ir_assert(new_call.is_equal(s));
244 }
245
246 stmt_t call;
247 stmt_t send_producer;
248
249 bool promote_to_dpasw = false;
250 bool update_applied = false;
251 int src2_relative_off = 0;
252 stmt_t new_call;
253 };
254
255 send_info_t &find_send_info(const stmt_t &s) {
256 for (auto &si : send_infos_)
257 if (si.call.is_same(s)) return si;
258 ir_error_not_expected();
259 return send_infos_.front();
260 }
261
262 dpas_info_t &find_dpas_info(const stmt_t &s) {
263 for (auto &si : dpas_infos_)
264 if (si.call.is_same(s)) return si;
265 ir_error_not_expected();
266 return dpas_infos_.front();
267 }
268 static bool is_send(const stmt_t &s, send_info_t &info) {
269 if (!is_func_call<send_t>(s)) return false;
270 info = send_info_t(s);
271 return true;
272 }
273
274 static bool is_dpas(const stmt_t &s, dpas_info_t &info) {
275 if (!is_func_call<dpas_t>(s)) return false;
276 info = dpas_info_t(s);
277 return true;
278 }
279
280 void extract_dpas_calls(expr_t &src2_base) {
281 object_eq_map_t<expr_t, stmt_t> buf2send;
282
283 auto set_src2_base = [&](const expr_t &ptr) {
284 auto &ptr_base = ptr.as<ptr_t>().base;
285 if (src2_base.is_empty()) {
286 src2_base = ptr_base;
287 return;
288 }
289 // This may need a fix in the future.
290 ir_assert(src2_base.is_same(ptr_base));
291 };
292
293 // Iterate through dpas and send calls.
294 auto stmt_vec = flatten_statements(load_mul_stmt_);
295 for (auto &s : stmt_vec) {
296 send_info_t send_info;
297 if (is_send(s, send_info)) {
298 auto &buf = send_info.reg_buf();
299 stmt_t prev_send;
300 auto it = buf2send.find(buf);
301 if (it != buf2send.end()) prev_send = it->second;
302 buf2send[buf] = s;
303 send_infos_.push_back(send_info);
304 if (!prev_send.is_empty()) {
305 send_infos_.back().set_prev_send(prev_send);
306 }
307 continue;
308 }
309 dpas_info_t dpas_info;
310 if (is_dpas(s, dpas_info)) {
311 set_src2_base(dpas_info.src2_buf());
312 auto &buf = dpas_info.src2_buf();
313 auto it = buf2send.find(buf);
314 if (it == buf2send.end()) continue;
315 auto &send_info = find_send_info(it->second);
316 // Ensure read size matches DPAS src2 size.
317 // FIXME: This may not always be the case.
318 ir_assert(send_info.reg_buf_size() == dpas_info.src2_size());
319 dpas_info.send_producer = send_info.call;
320 send_info.dpas_consumers.push_back(s);
321 dpas_infos_.push_back(dpas_info);
322 }
323 }
324 }
325
326 // Checks for the following pattern:
327 // dpas.sxr(a_dst, a_src0, src1, src2)
328 // dpas.sxr(b_dst, b_src0, src1, src2 + s * r * 4)
329 static bool can_convert_to_dpasw(
330 const dpas_info_t &a, const dpas_info_t &b) {
331 if (!a.dpas().is_equal(b.dpas())) return false;
332 if (!a.src1_buf().is_equal(b.src1_buf())) return false;
333
334 auto src2_off0 = to_cpp<int>(a.src2_buf().as<ptr_t>().off);
335 auto src2_off1 = to_cpp<int>(b.src2_buf().as<ptr_t>().off);
336
337 if (src2_off1 - src2_off0 != a.src2_size()) return false;
338
339 return true;
340 }
341
342 bool try_convert_to_dpasw(
343 dpas_info_t &a, dpas_info_t &b, grf_permutator_t &grf_perm) {
344 if (hw_ >= ngen::HW::XeHPC) return false;
345
346 // Check if DPAS -> DPASW transformation is possible.
347 if (!can_convert_to_dpasw(a, b)) return false;
348
349 // Perform the transformation:
350 // Before:
351 // send(slm, a_off, src2[0])
352 // send(slm, b_off, src2[s * r * 4])
353 // dpas.sxr(a_dst, a_src0, src1, src2[0])
354 // dpas.sxr(b_dst, b_src0, src1, src2[s * r * 4])
355 // After:
356 // send(slm, a_off + (tg_idx0 % 2) * (b_off - a_off), src2)
357 // dpasw.sxr(p_a_dst, p_a_src0, src1, src2[0])
358 // dpasw.sxr(p_b_dst, p_b_src0, src1, src2[s * r * 4 / 2])
359 // Where:
360 // p_a_dst[:] = a_dst[0:rcount / 2] + b_dst[0:rcount / 2]
361 // p_b_dst[:] = a_dst[rcount / 2:rcount] + b_dst[rcount / 2:rcount]
362 ir_assert(a.dpas().is_equal(b.dpas()));
363 auto _dpasw = dpas_t::make_dpasw(a.dpas());
364 auto &dpasw = _dpasw.as<dpas_t>();
365
366 auto a_args = a.args();
367 auto b_args = b.args();
368 dpas_t::arg_src2(b_args) -= dpasw.src2_size();
369
370 a.set_new_call(dpasw.call(a.args()), 0);
371 b.set_new_call(dpasw.call(b_args), dpasw.src2_size());
372
373 // Record permutation for registers to apply it for the destination
374 // store later.
375 const auto grf_size = ngen::GRF::bytes(hw_);
376 const auto rcount = a.dpas().rcount;
377 for (int j = 0; j < rcount; j++) {
378 int k = j % (rcount / 2);
379 auto a_old = dpas_t::arg_dst(a_args) + grf_size * j;
380 auto b_old = dpas_t::arg_dst(b_args) + grf_size * j;
381 expr_t grf_new;
382 if (j < rcount / 2) {
383 grf_new = dpas_t::arg_dst(a_args)[grf_size * k];
384 } else {
385 grf_new = dpas_t::arg_dst(b_args)[grf_size * k];
386 }
387 grf_perm.set_permute(a_old, grf_new);
388 grf_perm.set_permute(b_old, grf_new + grf_size * rcount / 2);
389 }
390
391 auto &a_send = find_send_info(a.send_producer);
392 auto &b_send = find_send_info(b.send_producer);
393
394 auto &a_mem_off = send_t::arg_mem_off(a_send.call);
395 auto &b_mem_off = send_t::arg_mem_off(b_send.call);
396 auto ab_addr_diff = simplify(b_mem_off - a_mem_off);
397 ir_assert(is_const(ab_addr_diff));
398
399 auto new_send_args = a_send.args();
400 send_t::arg_mem_off(new_send_args)
401 += (tg_idx0_ % 2) * to_cpp<int64_t>(ab_addr_diff);
402
403 a_send.set_new_call(a_send.send().call(new_send_args));
404 b_send.set_new_call(stmt_t(), a_send.call);
405
406 return true;
407 }
408
409 static bool can_convert_to_dpasw(const dpas_info_t &a_dpas,
410 const send_info_t &a_send, const expr_t &tg_idx0) {
411 if (contains_object(a_send.call, tg_idx0)) return false;
412 return a_dpas.dpas().rcount % 2 == 0;
413 }
414
415 static func_t create_half_send(const send_t &send) {
416 ir_assert(send.data_elems % 2 == 0) << "Can't create half-send.";
417 auto _s = send.with_data_elems(send.data_elems / 2);
418 auto &s = _s.as<send_t>();
419 ir_assert(s.is_supported())
420 << "Can't find send reading half of the original send.";
421 MAYBE_UNUSED(s);
422 return _s;
423 }
424
425 bool try_convert_to_dpasw(dpas_info_t &a, grf_permutator_t &grf_perm) {
426 if (hw_ >= ngen::HW::XeHPC) return false;
427 if (!can_convert_to_dpasw(a, find_send_info(a.send_producer), tg_idx0_))
428 return false;
429
430 // Perform the transformation:
431 // Before:
432 // send(slm, a_off, src2[0])
433 // dpas.sxr(a_dst, a_src0, src1, src2[0])
434 // After:
435 // send(slm, a_off + (tg_idx0 % 2) * (s * r * 4 / 2), src2)
436 // dpasw.sxr(a_dst, a_src0, src1, src2[0])
437
438 auto _dpasw = dpas_t::make_dpasw(a.dpas());
439 auto &dpasw = _dpasw.as<dpas_t>();
440
441 a.set_new_call(dpasw.call(a.args()), 0);
442
443 // Real permutation is not required but it needs to be set anyway.
444 const auto grf_size = ngen::GRF::bytes(hw_);
445 const auto rcount = a.dpas().rcount;
446 for (int j = 0; j < rcount; j++) {
447 auto grf = dpas_t::arg_dst(a.args()) + grf_size * j;
448 grf_perm.set_permute(grf, grf);
449 }
450
451 auto &a_send = find_send_info(a.send_producer);
452 auto new_send_args = a_send.args();
453 send_t::arg_mem_off(new_send_args)
454 += (tg_idx0_ % 2) * to_cpp<int64_t>(a.src2_size() / 2);
455 a_send.set_new_call(
456 create_half_send(a_send.send()).call(new_send_args));
457
458 return true;
459 }
460
461 static stmt_t apply_permutation_to_reorder(
462 const stmt_t &stmt, const grf_permutator_t &grf_perm) {
463 return permutation_injector_t(grf_perm).mutate(stmt);
464 }
465
466 ngen::HW hw_;
467 stmt_t load_mul_stmt_;
468 expr_t c_buf_;
469 stmt_t c_store_stmt_;
470 alloc_updater_t &alloc_updater_;
471 expr_t tg_idx0_;
472
473 std::vector<dpas_info_t> dpas_infos_;
474 std::vector<send_info_t> send_infos_;
475 };
476
477 // Transforms DPAS to DPASW.
478 void inject_dpasw(ngen::HW hw, stmt_t &load_mul_stmt, const expr_t &c_buf,
479 stmt_t &c_store_stmt, alloc_updater_t &alloc_updater,
480 const expr_t &tg_idx0) {
481 dpasw_injector_t injector(
482 hw, load_mul_stmt, c_buf, c_store_stmt, alloc_updater, tg_idx0);
483 injector.inject();
484
485 load_mul_stmt = injector.load_mul_stmt();
486 c_store_stmt = injector.c_store_stmt();
487 }
488
489 // Adds {Atomic} modifier to DPAS/DPASW instructions when applicable.
490 stmt_t inject_atomic(const stmt_t &stmt) {
491 stmt_t ret = stmt;
492 auto stmt_vec = flatten_statements(stmt);
493 for (size_t i = 0; i < stmt_vec.size(); i++) {
494 bool ok = true;
495 ok &= is_func_call<dpas_t>(stmt_vec[i]);
496 ok &= (i + 1 < stmt_vec.size()
497 && is_func_call<dpas_t>(stmt_vec[i + 1]));
498 if (ok) {
499 auto &cur_src1 = dpas_t::arg_src1(stmt_vec[i]);
500 auto &next_src1 = dpas_t::arg_src1(stmt_vec[i + 1]);
501 // Compare src1, apply {Atomic} if they are equal.
502 if (cur_src1.is_equal(next_src1)) {
503 auto &s = stmt_vec[i];
504 auto atomic_attr = instruction_modifier_attr_t::make(
505 ngen_proxy::InstructionModifier().with_atomic());
506 ret = substitute(ret, s, atomic_attr.apply_to(s));
507 }
508 }
509 }
510 return ret;
511 }
512
513 // Trace for debugging purposes.
514 void trace_pass(const char *pass_name, const stmt_t &stmt) {
515 ir_trace() << "=== After " << pass_name << std::endl;
516 ir_trace() << stmt << std::endl;
517 }
518
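// Collects all variables that are used in the statement but not defined
// within it (external variables).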
519 class external_var_visitor_t : public scope_visitor_t {
520 public:
521 void _visit(const var_t &obj) {
522 if (!is_expr_defined(obj)) external_vars.insert(obj);
523 }
524
525 object_eq_set_t<expr_t> external_vars;
526 };
527
528 stmt_t inject_external_var_let(const stmt_t &_stmt) {
529 auto stmt = _stmt;
530 external_var_visitor_t v;
531 v.visit(stmt);
532
533 for (auto &var : v.external_vars)
534 stmt = let_t::make(var, {}, stmt);
535
536 trace_pass("inject_external_var_let", stmt);
537 return stmt;
538 }
539
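// Rewrites SLM allocations so that they become offsets into a single shared
// "slm" buffer (see merge_slm_buffers() below).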
540 class slm_buffer_merger_t : public ir_mutator_t {
541 public:
542 slm_buffer_merger_t() {
543 slm_base_ = make_buffer("slm");
544 slm_off_.push_back(0);
545 }
546
547 const expr_t &slm_base() const { return slm_base_; }
548
549 int slm_size() const { return slm_size_; }
550
551 object_t _mutate(const alloc_t &obj) override {
552 if (obj.kind != alloc_kind_t::slm) return ir_mutator_t::_mutate(obj);
553
554 auto new_buf = push(obj);
555 auto new_obj = ir_mutator_t::_mutate(obj);
556 pop();
557
558 auto &alloc = new_obj.as<alloc_t>();
559 new_obj = substitute(alloc.body, alloc.buf, new_buf);
560
561 return new_obj;
562 }
563
564 private:
565 expr_t push(const alloc_t &obj) {
566 int cur_off = slm_off_.back();
567 expr_t new_buf = slm_base_ + cur_off;
568 slm_off_.push_back(cur_off + obj.size);
569 slm_size_ = std::max(slm_size_, cur_off + obj.size);
570 return new_buf;
571 }
572
573 void pop() { slm_off_.pop_back(); }
574
575 expr_t slm_base_;
576 std::vector<int> slm_off_;
577 int slm_size_ = 0;
578 };
579
580 // Merges all SLM buffers into a single one.
581 stmt_t merge_slm_buffers(const stmt_t &_stmt) {
582 stmt_t stmt = _stmt;
583 slm_buffer_merger_t merger;
584 stmt = merger.mutate(stmt);
585 stmt = alloc_t::make(
586 merger.slm_base(), merger.slm_size(), alloc_kind_t::slm, {}, stmt);
587 trace_pass("merge_slm_buffers", stmt);
588 return stmt;
589 }
590
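// Moves the offset of a send's memory buffer operand into the send's memory
// offset argument.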
591 class buffer_offset_lifter_t : public ir_mutator_t {
592 public:
593 object_t _mutate(const func_call_t &obj) {
594 if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
595
596 auto &mem_buf = send_t::arg_mem_buf(obj);
597 if (!mem_buf.is<ptr_t>()) return ir_mutator_t::_mutate(obj);
598
599 auto &base = mem_buf.as<ptr_t>().base;
600 auto &off = mem_buf.as<ptr_t>().off;
601
602 std::vector<expr_t> new_args = obj.args;
603 send_t::arg_mem_buf(new_args) = base;
604 send_t::arg_mem_off(new_args) += off;
605 return obj.func.call(new_args, obj.attr);
606 }
607 };
608
609 stmt_t lift_buffer_offsets_in_send(const stmt_t &s) {
610 buffer_offset_lifter_t lifter;
611 auto ret = lifter.mutate(s);
612 trace_pass("lift_buffer_offsets_in_send", ret);
613 return ret;
614 }
615
616 stmt_t simplify_pass(const stmt_t &s, const constraint_set_t &cset) {
617 auto ret = simplify(s, cset);
618 trace_pass("simplify_pass", ret);
619 return ret;
620 }
621
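// Expands send calls: allocates a message header buffer, emits the store
// that fills it with the memory offset, and rewrites the call to use the
// header.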
622 class send_injector_t : public ir_mutator_t {
623 public:
624 send_injector_t(ir_context_t &ir_ctx, const constraint_set_t &cset)
625 : ir_ctx_(ir_ctx), cset_(cset) {}
626
627 object_t _mutate(const func_call_t &obj) {
628 auto *send = obj.func.as_ptr<send_t>();
629 if (!send) return ir_mutator_t::_mutate(obj);
630
631 auto &mem_buf = send_t::arg_mem_buf(obj);
632 auto &mem_off = send_t::arg_mem_off(obj);
633 auto &reg_buf = send_t::arg_reg_buf(obj);
634 auto &mask = send_t::arg_mask(obj);
635
636 ir_assert(is_var(mem_buf)) << mem_buf;
637
638 auto header_buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "h");
639 auto off_store = simplify_store(
640 send->create_offset_store(header_buf, mem_buf, mem_off));
641
642 auto new_call = func_call_t::make(
643 obj.func, {mem_buf, header_buf, reg_buf, mask}, obj.attr);
644 auto body = stmt_seq_t::make(off_store, new_call);
645
646 // Allocate header.
647 return alloc_t::make(
648 header_buf, send->header_size(), alloc_kind_t::grf, {}, body);
649 }
650
651 private:
652 stmt_t simplify_store(const stmt_t &_store) const {
653 auto &store = _store.as<store_t>();
654
655 auto value = store.value;
656 value = simplify(value, cset_);
657
658 // Convert to N-ary form and back to expand multiplications. This
659 // helps to find more common subexpressions during the pass.
660 value = nary_op_canonicalize(value);
661 value = nary_op_back_transform(value);
662
663 return store_t::make(store.buf, store.off, value, store.stride);
664 }
665
666 ir_context_t &ir_ctx_;
667 const constraint_set_t &cset_;
668 };
669
670 stmt_t inject_send(
671 const stmt_t &s, ir_context_t &ir_ctx, const constraint_set_t &cset) {
672 auto ret = send_injector_t(ir_ctx, cset).mutate(s);
673 trace_pass("inject_send", ret);
674 return ret;
675 }
676
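// Collects register allocations made inside the compute loop and re-creates
// them right outside of it (optionally keeping header allocations in place
// when headers are reused).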
677 class alloc_lifter_t : public ir_mutator_t {
678 public:
679 alloc_lifter_t(const stmt_t &root, bool reuse_headers)
680 : reuse_headers_(reuse_headers) {
681 if (!reuse_headers_) return;
682 auto calls = find_objects<func_call_t>(root);
683 for (auto &c : calls) {
684 if (!is_func_call<send_t>(c)) continue;
685 auto header_buf = send_t::arg_mem_off(c);
686 ir_assert(is_var(header_buf)) << header_buf;
687 header_bufs_.insert(header_buf);
688 }
689 }
690
691 object_t _mutate(const alloc_t &obj) override {
692 if (!do_lift(obj)) return ir_mutator_t::_mutate(obj);
693 // Remove alloc and insert it before the compute loop.
694 allocs_.push_back(&obj);
695 return obj.body;
696 }
697
698 object_t _mutate(const stmt_group_t &obj) override {
699 bool is_compute_loop = (obj.label == stmt_label_t::compute_loop());
700 if (is_compute_loop) in_compute_loop_ = true;
701 auto new_obj = ir_mutator_t::_mutate(obj);
702 if (is_compute_loop) {
703 in_compute_loop_ = false;
704 // Outermost loop.
705 for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
706 auto &a = it->as<alloc_t>();
707 new_obj = alloc_t::make(a.buf, a.size, a.kind, a.attr, new_obj);
708 }
709 allocs_.resize(0);
710 }
711 return new_obj;
712 }
713
714 private:
715 bool do_lift(const alloc_t &obj) const {
716 if (!in_compute_loop_) return false;
717 if (reuse_headers_) {
718 bool is_header_alloc = (header_bufs_.count(obj.buf) != 0);
719 return !is_header_alloc;
720 }
721 return true;
722 }
723
724 bool reuse_headers_;
725 object_set_t<expr_t> header_bufs_;
726
727 bool in_compute_loop_ = false;
728 std::vector<stmt_t> allocs_;
729 };
730
731 // Lifts alloc statements out of loops.
732 stmt_t lift_alloc(const stmt_t &s, const conv_config_t &cfg) {
733 auto ret = alloc_lifter_t(s, cfg.reuse_headers).mutate(s);
734 trace_pass("lift_alloc", ret);
735 return ret;
736 }
737
738 // Common subexpression elimination support.
739
740 // Represents an expression-candidate to eliminate.
741 class cse_expr_t {
742 public:
743 cse_expr_t(const expr_t &expr, const ir_path_t &path, int refs = 1,
744 const expr_t &cse_var = {})
745 : expr(expr), path(path), refs(refs), cse_var(cse_var) {
746 ir_trace() << "cse_pass: add expression: " << expr << std::endl;
747 }
748
749 void add_usage(const ir_path_t &other_path, bool do_increment = true) {
750 if (do_increment) refs++;
751 path.merge(other_path);
752 ir_trace() << "cse_pass: add usage: " << expr
753 << ", total refs: " << refs << std::endl;
754 }
755
756 // Expression to eliminate via let.
757 expr_t expr;
758 // Path to the innermost IR node where the expression can be defined.
759 ir_path_t path;
760 // Number of references to the expression.
761 int refs;
762 // Variable assigned to the expression (if decided to eliminate).
763 expr_t cse_var;
764 };
765
766 // Stores information about all expressions subject to CSEing.
767 class cse_context_t {
768 public:
769 cse_context_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
770
771 ir_context_t &ir_ctx() { return ir_ctx_; }
772
773 bool has(const expr_t &e) const { return cse_exprs_.count(e) != 0; }
774
775 cse_expr_t &find_cse_expr(const expr_t &e) {
776 ir_assert(has(e)) << e;
777 return cse_exprs_.at(e);
778 }
779
780 const cse_expr_t &find_cse_expr(const expr_t &e) const {
781 ir_assert(has(e)) << e;
782 return cse_exprs_.at(e);
783 }
784
785 bool has_var(const expr_t &e) const {
786 return !find_cse_expr(e).cse_var.is_empty();
787 }
788
789 int get_refs(const expr_t &e) const {
790 if (!has(e)) return 0;
791 return find_cse_expr(e).refs;
792 }
793
794 void register_expr(const expr_t &e, const ir_path_t &path) {
795 if (e.type().is_bool()) return; // Ignore booleans.
796 auto ret = cse_exprs_.insert({e, cse_expr_t(e, path)});
797 ir_assert(ret.second) << e;
798 MAYBE_UNUSED(ret);
799 }
800
801 void register_expr(const cse_expr_t &cse_expr) {
802 auto ret = cse_exprs_.insert({cse_expr.expr, cse_expr});
803 ir_assert(ret.second);
804 MAYBE_UNUSED(ret);
805 }
806
807 expr_t get_or_assign_var(const expr_t &e) {
808 auto &cse_expr = find_cse_expr(e);
809 if (cse_expr.cse_var.is_empty()) {
810 cse_expr.cse_var = ir_ctx_.create_tmp_var(e.type());
811 ir_trace() << "cse_pass: assigning var: " << e << " -> "
812 << cse_expr.cse_var << std::endl;
813 }
814 return cse_expr.cse_var;
815 }
816
817 const expr_t &get_var(const expr_t &e) const {
818 return find_cse_expr(e).cse_var;
819 }
820
821 const ir_path_t &get_path(const expr_t &e) const {
822 return find_cse_expr(e).path;
823 }
824
825 void add_usage(
826 const expr_t &e, const ir_path_t &path, bool do_increment = true) {
827 if (e.type().is_bool()) return; // Ignore booleans.
828 return find_cse_expr(e).add_usage(path, do_increment);
829 }
830
831 void update_expr(const expr_t &old_expr, const expr_t &new_expr) {
832 auto it = cse_exprs_.find(old_expr);
833 ir_assert(it != cse_exprs_.end()) << old_expr;
834 auto &old_cse_expr = it->second;
835 auto new_cse_expr = cse_expr_t(new_expr, old_cse_expr.path,
836 old_cse_expr.refs, old_cse_expr.cse_var);
837 cse_exprs_.erase(it);
838 auto ret = cse_exprs_.insert({new_expr, new_cse_expr});
839 ir_assert(ret.second);
840 MAYBE_UNUSED(ret);
841 }
842
843 template <typename F>
844 void for_each(const F &f) const {
845 for (auto &kv : cse_exprs_)
846 f(kv.first);
847 }
848
849 private:
850 ir_context_t &ir_ctx_;
851 object_eq_map_t<expr_t, cse_expr_t> cse_exprs_;
852 };
853
854 // Collects statistics about expressions for common subexpression elimination.
855 class cse_visitor_t : public ir_visitor_t {
856 public:
857 cse_visitor_t(cse_context_t &ctx) : ctx_(ctx) {}
858
859 void _visit(const binary_op_t &obj) override { visit_expr(obj); }
860 void _visit(const shuffle_t &obj) override {
861 if (is_const_broadcast(obj)) return;
862 visit_expr(obj);
863 }
864 void _visit(const unary_op_t &obj) override { visit_expr(obj); }
865
866 #define HANDLE_IR_OBJECT(type) \
867 void _visit(const type &obj) override { visit_stmt(obj); }
868
869 HANDLE_STMT_IR_OBJECTS()
870
871 #undef HANDLE_IR_OBJECT
872
873 private:
874 template <typename T>
875 void visit_expr(const T &obj) {
876 // Exclude loads as they may have side effects.
877 if (count_objects<load_t>(obj) > 0) {
878 ir_visitor_t::_visit(obj);
879 return;
880 }
881
882 if (propagate_path_) {
883 if (ctx_.has(obj))
884 ctx_.add_usage(obj, root_path_, /*do_increment=*/false);
885 ir_visitor_t::_visit(obj);
886 return;
887 }
888 if (ctx_.has(obj)) {
889 ctx_.add_usage(obj, root_path_);
890 propagate_path_ = true;
891 ir_visitor_t::_visit(obj);
892 propagate_path_ = false;
893 return;
894 }
895 ir_visitor_t::_visit(obj);
896 ctx_.register_expr(obj, root_path_);
897 }
898
899 template <typename T>
900 void visit_stmt(const T &obj) {
901 if (std::is_same<T, for_t>::value) {
902 visit_for((const object_impl_t &)obj);
903 return;
904 }
905 if (std::is_same<T, let_t>::value) {
906 visit_let((const object_impl_t &)obj);
907 return;
908 }
909 root_path_.push(&obj);
910 ir_visitor_t::_visit(obj);
911 root_path_.pop();
912 }
913
914 void visit_for(const object_impl_t &_obj) {
915 auto &obj = (const for_t &)_obj;
916
917 visit(obj.var);
918 visit(obj.init);
919 visit(obj.bound);
920 root_path_.push(&obj);
921 visit(obj.body);
922 root_path_.pop();
923 }
924
925 void visit_let(const object_impl_t &_obj) {
926 auto &obj = (const let_t &)_obj;
927
928 visit(obj.var);
929 visit(obj.value);
930 root_path_.push(&obj);
931 visit(obj.body);
932 root_path_.pop();
933 }
934
935 cse_context_t &ctx_;
936 ir_path_t root_path_;
937
938 bool propagate_path_ = false;
939 };
940
941 // Verifies all IR paths are correct (for debugging purposes).
942 class cse_verifier_t : public scope_visitor_t {
943 public:
944 cse_verifier_t(cse_context_t &ctx) : ctx_(ctx) {}
945
946 ~cse_verifier_t() override { ir_assert(to_check_.empty()); }
947
948 void _visit(const binary_op_t &obj) override { visit_expr(obj); }
949 void _visit(const shuffle_t &obj) override { return visit_expr(obj); }
950 void _visit(const unary_op_t &obj) override { visit_expr(obj); }
951
952 #define HANDLE_IR_OBJECT(type) \
953 void _visit(const type &obj) override { visit_stmt(obj); }
954
955 HANDLE_STMT_IR_OBJECTS()
956
957 #undef HANDLE_IR_OBJECT
958
959 void verify(const stmt_t &s) {
960 // Phase 0: collect IR paths for expressions.
961 phase_ = 0;
962 visit(s);
963
964 // Phase 1: verify all expressions are defined at their path.
965 phase_ = 1;
966 visit(s);
967 }
968
969 private:
970 template <typename T>
971 void visit_expr(const T &obj) {
972 // Expressions are not used during phase 1.
973 if (phase_ == 1) return;
974 if (ctx_.has(obj)) {
975 auto &path = ctx_.get_path(obj);
976 to_check_[path.back()].push_back(obj);
977 }
978 scope_visitor_t::_visit(obj);
979 }
980
981 template <typename T>
982 void visit_stmt(const T &obj) {
983 scope_visitor_t::_visit(obj);
984
985 // Statements are not used during phase 0.
986 if (phase_ == 0) return;
987
988 // Phase 1: check that all attached expressions are defined at this
989 // statement.
990 auto it = to_check_.find(obj);
991 if (it != to_check_.end()) {
992 for (auto &e : it->second) {
993 ir_assert(is_expr_defined(e))
994 << "Expression contains undefined variables: " << e;
995 MAYBE_UNUSED(e);
996 }
997 to_check_.erase(it);
998 }
999 }
1000
1001 cse_context_t &ctx_;
1002
1003 int phase_ = 0;
1004 object_map_t<stmt_t, std::vector<expr_t>> to_check_;
1005 };
1006
1007 // Generates let statements for expressions being eliminated.
1008 class cse_let_generator_t : public ir_visitor_t {
1009 public:
1010 cse_let_generator_t(const cse_context_t &ctx, const stmt_t &stmt)
1011 : ctx_(ctx), stmt_(stmt) {}
1012
1013 void _visit(const binary_op_t &obj) override { visit_expr(obj); }
1014 void _visit(const shuffle_t &obj) override { visit_expr(obj); }
1015 void _visit(const unary_op_t &obj) override { visit_expr(obj); }
1016 void _visit(const var_t &obj) override {
1017 auto it = all_vars_.find(obj);
1018 if (it == all_vars_.end()) return;
1019 if (seen_vars_.count(obj) == 0) generate_for_expr(it->second);
1020 }
1021
1022 stmt_t generate() {
1023 ctx_.for_each([&](const expr_t &e) {
1024 auto &cse_var = ctx_.get_var(e);
1025 auto ret = all_vars_.insert({cse_var, e});
1026 ir_assert(ret.second);
1027 MAYBE_UNUSED(ret);
1028 });
1029 ctx_.for_each([&](const expr_t &e) { generate_for_expr(e); });
1030 for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
1031 auto &let = it->as<let_t>();
1032 stmt_ = let_t::make(let.var, let.value, stmt_);
1033 }
1034 return stmt_;
1035 }
1036
1037 private:
1038 void generate_for_expr(const expr_t &e) {
1039 auto &cse_var = ctx_.get_var(e);
1040 if (seen_vars_.count(cse_var) == 1) return;
1041 visit(e);
1042 }
1043
1044 template <typename T>
1045 void visit_expr(const T &obj) {
1046 ir_visitor_t::_visit(obj);
1047 if (ctx_.has(obj) && ctx_.has_var(obj)) {
1048 auto &var = ctx_.get_var(obj);
1049 auto ret = seen_vars_.insert(var);
1050 if (ret.second) lets_.push_back(let_t::make(var, obj));
1051 }
1052 }
1053
1054 const cse_context_t &ctx_;
1055 stmt_t stmt_;
1056
1057 object_map_t<expr_t, expr_t> all_vars_; // Var -> expression.
1058 object_set_t<expr_t> seen_vars_;
1059
1060 std::vector<stmt_t> lets_;
1061 };
1062
1063 // Eliminates common subexpressions from the statement.
1064 class cse_mutator_t : public ir_mutator_t {
1065 public:
1066 cse_mutator_t(cse_context_t &ctx) : ctx_(ctx) {}
1067
1068 object_t _mutate(const binary_op_t &obj) override {
1069 return mutate_expr(obj);
1070 }
1071 object_t _mutate(const shuffle_t &obj) override { return mutate_expr(obj); }
1072 object_t _mutate(const unary_op_t &obj) override {
1073 return mutate_expr(obj);
1074 }
1075
1076 #define HANDLE_IR_OBJECT(type) \
1077 object_t _mutate(const type &obj) override { return mutate_stmt(obj); }
1078
1079 HANDLE_STMT_IR_OBJECTS()
1080
1081 #undef HANDLE_IR_OBJECT
1082
1083 private:
1084 template <typename T>
1085 object_t mutate_expr(const T &obj) {
1086 auto new_obj = ir_mutator_t::_mutate(obj);
1087 if (ctx_.has(obj) && !new_obj.is_equal(obj)) {
1088 ctx_.update_expr(obj, new_obj);
1089 }
1090 if (ctx_.get_refs(new_obj) > 1) {
1091 bool has_var = ctx_.has_var(new_obj);
1092 auto var = ctx_.get_or_assign_var(new_obj);
1093 auto &path = ctx_.get_path(new_obj);
1094 if (!has_var) to_update_[path.back()].push_back(new_obj);
1095 return std::move(var);
1096 }
1097 return new_obj;
1098 }
1099
1100 template <typename T>
1101 object_t mutate_stmt(const T &obj) {
1102 auto new_obj = ir_mutator_t::_mutate(obj);
1103 auto it = to_update_.find(obj);
1104 if (it == to_update_.end()) return new_obj;
1105
1106 cse_context_t local_ctx(ctx_.ir_ctx());
1107 for (auto &e : it->second) {
1108 local_ctx.register_expr(ctx_.find_cse_expr(e));
1109 }
1110 to_update_.erase(it);
1111
1112 auto body = get_stmt_body(new_obj);
1113 cse_let_generator_t g(local_ctx, body);
1114 body = g.generate();
1115 new_obj = replace_stmt_body(new_obj, body);
1116 return new_obj;
1117 }
1118
1119 cse_context_t &ctx_;
1120 object_map_t<stmt_t, std::vector<expr_t>> to_update_;
1121 };
1122
1123 stmt_t eliminate_common_subexprs(const stmt_t &_stmt, ir_context_t &ir_ctx) {
1124 auto stmt = _stmt;
1125
1126 cse_context_t ctx(ir_ctx);
1127
1128 // Collect statistics.
1129 cse_visitor_t visitor(ctx);
1130 visitor.visit(stmt);
1131
1132 #ifndef NDEBUG
1133 // Verify that collected IR paths are correct (cse_expr_t objects are
1134 // defined at those paths).
1135 cse_verifier_t verifier(ctx);
1136 verifier.verify(stmt);
1137 #endif
1138
1139 // Eliminate subexpressions.
1140 cse_mutator_t mutator(ctx);
1141 stmt = mutator.mutate(stmt);
1142
1143 trace_pass("eliminate_common_subexprs", stmt);
1144 return stmt;
1145 }
1146
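// Hoists loop-invariant parts of additive expressions out of the enclosing
// loops by introducing let statements before those loops.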
1147 class hoist_exprs_mutator_t : public ir_mutator_t {
1148 public:
1149 hoist_exprs_mutator_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
1150
1151 ~hoist_exprs_mutator_t() override { ir_assert(let_vars_.empty()); }
1152
1153 object_t _mutate(const func_call_t &obj) override {
1154 if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
1155
1156 std::vector<expr_t> new_args;
1157 for (auto &e : obj.args) {
1158 new_args.push_back(hoist_expr(e));
1159 }
1160
1161 if (ir_utils::is_equal(new_args, obj.args)) return obj;
1162
1163 return func_call_t::make(obj.func, new_args, obj.attr);
1164 }
1165
1166 object_t _mutate(const stmt_group_t &obj) override {
1167 if (obj.body.is<for_t>()) {
1168 loops_.emplace_back(obj.body.as<for_t>().var);
1169 const for_t *for_obj = obj.body.as_ptr<for_t>();
1170 auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
1171 if (body.is_same(obj.body)) return obj;
1172 auto new_obj = stmt_group_t::make(obj.label, body);
1173 return injects_lets_and_pop_loop(new_obj);
1174 }
1175 return ir_mutator_t::_mutate(obj);
1176 }
1177
1178 object_t _mutate(const store_t &obj) override {
1179 auto value = hoist_expr(obj.value);
1180 if (value.is_equal(obj.value)) return obj;
1181 return store_t::make(obj.buf, obj.off, value, obj.stride);
1182 }
1183
1184 object_t _mutate(const for_t &obj) override {
1185 loops_.emplace_back(obj.var);
1186 auto new_obj = ir_mutator_t::_mutate(obj);
1187 return injects_lets_and_pop_loop(new_obj);
1188 }
1189
1190 object_t _mutate(const let_t &obj) override {
1191 bool fully_hoisted = false;
1192 auto new_value = hoist_expr(obj.value, obj.var, &fully_hoisted);
1193 if (fully_hoisted) return mutate(obj.body);
1194 register_let(obj.var, new_value);
1195 auto new_obj = let_t::make(
1196 obj.var, new_value, ir_mutator_t::mutate(obj.body));
1197 unregister_let(obj.var);
1198 return std::move(new_obj);
1199 }
1200
1201 private:
1202 struct loop_info_t {
1203 loop_info_t(const expr_t &var) : var(var) {}
1204
1205 expr_t var;
1206 int var_count = 0;
1207 std::vector<stmt_t> lets;
1208 };
1209
1210 expr_t hoist_expr(const expr_t &expr, const expr_t &expr_var = {},
1211 bool *fully_hoisted = nullptr) {
1212 if (expr.is_empty()) return expr;
1213 if (expr.type().is_ptr()) return expr;
1214 if (expr.type().is_bool()) return expr;
1215 if (is_const(expr) || is_shuffle_const(expr) || is_var(expr))
1216 return expr;
1217
1218 auto hoisted_expr = hoist_expr_with_add(expr, expr_var, fully_hoisted);
1219 if (!hoisted_expr.is_equal(expr)) return hoisted_expr;
1220
1221 // hoist_expr_with_add() doesn't handle casts, so try to hoist them manually.
1222 auto *cast = expr.as_ptr<cast_t>();
1223 if (!cast) return hoisted_expr;
1224
1225 auto hoisted_cast_expr = hoist_expr(cast->expr);
1226 if (!hoisted_cast_expr.is_equal(cast->expr)) {
1227 hoisted_expr = cast_t::make(
1228 cast->type, hoisted_cast_expr, cast->saturate);
1229 }
1230 return hoisted_expr;
1231 }
1232
1233 expr_t hoist_expr_with_add(const expr_t &expr, const expr_t &expr_var = {},
1234 bool *fully_hoisted = nullptr) {
1235 auto cur_expr = nary_op_canonicalize(expr);
1236
1237 auto is_nary_add = [](const expr_t &e) {
1238 auto *nary = e.as_ptr<nary_op_t>();
1239 return nary && (nary->op_kind == op_kind_t::_add);
1240 };
1241
1242 for (size_t i = 0; i < loops_.size(); i++) {
1243 std::vector<expr_t> invariant_args;
1244 std::vector<expr_t> other_args;
1245 std::vector<expr_t> nary_args;
1246 if (is_nary_add(cur_expr)) {
1247 nary_args = cvt_expr_to_nary_op_args(cur_expr);
1248 } else {
1249 nary_args.push_back(cur_expr);
1250 }
1251 for (auto &_a : nary_args) {
1252 auto a = nary_op_back_transform(_a);
1253 bool is_inv_arg = true;
1254 for (size_t j = i; j < loops_.size(); j++) {
1255 if (!is_invariant(a, loops_[j].var)) is_inv_arg = false;
1256 }
1257 if (is_inv_arg) {
1258 invariant_args.push_back(_a);
1259 } else {
1260 other_args.push_back(_a);
1261 }
1262 }
1263 // Nothing to hoist for this loop, continue.
1264 if (invariant_args.empty()) continue;
1265 if (invariant_args.size() == 1 && is_var(invariant_args[0]))
1266 continue;
1267
1268 // Introduce new variable for the invariant sub-expression.
1269 auto inv_expr = nary_op_back_transform(
1270 make_nary_op(op_kind_t::_add, invariant_args));
1271 expr_t inv_var;
1272 if (!expr_var.is_empty() && other_args.empty()) {
1273 // If nothing to hoist further, reuse the old variable and
1274 // return.
1275 inv_var = expr_var;
1276 } else {
1277 inv_var = ir_ctx_.create_tmp_var(inv_expr.type());
1278 }
1279 auto let = let_t::make(inv_var, inv_expr);
1280 register_let(inv_var, inv_expr);
1281 loops_[i].lets.push_back(let);
1282
1283 if (other_args.empty()) {
1284 if (fully_hoisted) *fully_hoisted = true;
1285 return inv_var;
1286 }
1287
1288 other_args.push_back(inv_var);
1289 cur_expr = make_nary_op(op_kind_t::_add, other_args);
1290 }
1291 return nary_op_back_transform(cur_expr);
1292 }
1293
1294 stmt_t injects_lets_and_pop_loop(const stmt_t &_s) {
1295 stmt_t s = _s;
1296 // Inject let statements if any.
1297 auto &lets = loops_.back().lets;
1298 for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
1299 auto &let = it->as<let_t>();
1300 s = let_t::make(let.var, let.value, s);
1301 unregister_let(let.var);
1302 }
1303 loops_.pop_back();
1304 return s;
1305 }
1306
1307 void register_let(const expr_t &var, const expr_t &value) {
1308 let_vars_.insert({var, value});
1309 }
1310
1311 void unregister_let(const expr_t &var) { let_vars_.erase(var); }
1312
1313 bool is_invariant(const expr_t &e, const expr_t &var) const {
1314 if (contains_object(e, var)) return false;
1315 if (!find_objects<load_t>(e).empty()) return false;
1316
1317 // Check value if this is a let variable.
1318 auto it = let_vars_.find(e);
1319 if (it != let_vars_.end()) return is_invariant(it->second, var);
1320
1321 if (is_var(e)) return true;
1322
1323 // Check transitive dependencies.
1324 auto vars = find_unique_objects<var_t>(e);
1325 for (auto &v : vars) {
1326 if (!is_invariant(v, var)) return false;
1327 }
1328 return true;
1329 }
1330
1331 ir_context_t &ir_ctx_;
1332 std::vector<loop_info_t> loops_;
1333
1334 object_map_t<expr_t, expr_t> let_vars_;
1335 };
1336
1337 // Moves invariant expressions out of loops.
1338 stmt_t hoist_exprs(const stmt_t &s, ir_context_t &ir_ctx) {
1339 auto ret = hoist_exprs_mutator_t(ir_ctx).mutate(s);
1340 trace_pass("hoist_exprs", ret);
1341 return ret;
1342 }
1343
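// Implements the loop strength reduction pass: loop-dependent stores are
// split into an initial store before the loop and a post-increment update
// attached to the consuming call (see loop_strength_reduce() below).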
1344 class loop_strength_reducer_t : public ir_mutator_t {
1345 public:
1346 loop_strength_reducer_t() {
1347 // Create top-level dummy loop.
1348 loops_.emplace_back();
1349 }
1350
1351 ~loop_strength_reducer_t() override {
1352 // Sanity check, all stores must be applied.
1353 ir_assert(post_inc_stores.empty());
1354 }
1355
1356 object_t _mutate(const for_t &obj) override {
1357 loops_.emplace_back(obj);
1358 auto new_obj = ir_mutator_t::_mutate(obj);
1359 return inject_stores_and_pop_loop(new_obj);
1360 }
1361
1362 object_t _mutate(const let_t &obj) override {
1363 int loop_level = int(loops_.size()) - 1;
1364 auto ret = lets_.insert(
1365 {obj.var, let_info_t(obj.var, obj.value, loop_level)});
1366 ir_assert(ret.second);
1367 MAYBE_UNUSED(ret);
1368 auto new_obj = ir_mutator_t::_mutate(obj);
1369 lets_.erase(obj.var);
1370 return new_obj;
1371 }
1372
1373 object_t _mutate(const stmt_group_t &obj) override {
1374 if (obj.body.is<for_t>()) {
1375 loops_.emplace_back(obj.body);
1376 const for_t *for_obj = obj.body.as_ptr<for_t>();
1377 auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
1378 if (body.is_same(obj.body)) return obj;
1379 auto new_obj = stmt_group_t::make(obj.label, body);
1380 return inject_stores_and_pop_loop(new_obj);
1381 }
1382 return ir_mutator_t::_mutate(obj);
1383 }
1384
1385 // Pattern to handle:
1386 // for (...) {
1387 // store(buf_ptr, ...) <- Write (producer).
1388 // // ...
1389 // stmt_t(..., buf_ptr, ...) <- Read (consumer).
1390 // }
1391 object_t _mutate(const store_t &obj) override {
1392 if (loops_.size() == 1) return ir_mutator_t::_mutate(obj);
1393
1394 // Try to reduce strength, moving the store up.
1395 int init_store_level = -1;
1396 stmt_t init_store_stmt = obj;
1397 post_inc_store_info_t post_inc_store(obj);
1398 for (int level = int(loops_.size()) - 1; level >= 1; level--) {
1399 auto &loop_info = loops_[level];
1400 int refs = count_object(loop_info.loop, obj.buf);
1401 // Producer and consumer - must be 2 references.
1402 if (refs != 2) break;
1403
1404 // Try to insert the store before level-th loop.
1405 auto &store = init_store_stmt.as<store_t>();
1406 auto &store_value = store.value;
1407 auto &loop_var = loop_info.loop_var();
1408
1409 auto cur_value = substitute_let(store_value, level);
1410 auto next_value = substitute(cur_value, loop_var, loop_var + 1);
1411 auto inc = simplify(next_value - cur_value);
1412
1413 // Cannot eliminate loop variable, break.
1414 if (contains_object(inc, loop_var)) break;
1415
1416 // Success, replace store by post-increment store.
1417 init_store_level = level;
1418
1419 auto new_store_value
1420 = substitute(cur_value, loop_var, loop_info.loop_init());
1421 init_store_stmt = store_t::make(store.buf, store.off,
1422 simplify(new_store_value), store.stride);
1423
1424 post_inc_store.update(loop_info, inc);
1425 }
1426
1427 // Can't do anything, return as is.
1428 if (init_store_level == -1) return ir_mutator_t::_mutate(obj);
1429
1430 // Move this store up, remove from here.
1431 loops_[init_store_level].init_stores.push_back(init_store_stmt);
1432 if (!post_inc_store.is_empty()) {
1433 auto ret = post_inc_stores.insert({obj.buf, post_inc_store});
1434 ir_assert(ret.second);
1435 MAYBE_UNUSED(ret);
1436 }
1437 return stmt_t();
1438 }
1439
1440 object_t _mutate(const func_call_t &obj) override {
1441 for (auto &kv : post_inc_stores) {
1442 int refs = count_object(obj, kv.first);
1443 if (refs == 1) {
1444 auto ret = stmt_seq_t::make(obj, kv.second.stmt());
1445 post_inc_stores.erase(kv.first);
1446 return std::move(ret);
1447 }
1448 }
1449 return ir_mutator_t::_mutate(obj);
1450 }
1451
1452 private:
1453 struct loop_info_t {
1454 loop_info_t(const stmt_t &loop = {}) : loop(loop) {}
1455
1456 const expr_t &loop_var() const { return loop.as<for_t>().var; }
1457
1458 const expr_t &loop_init() const { return loop.as<for_t>().init; }
1459
1460 const expr_t &loop_bound() const { return loop.as<for_t>().bound; }
1461
1462 expr_t loop_extent() const { return loop_bound() - loop_init(); }
1463
1464 // Loop being analyzed.
1465 stmt_t loop;
1466 // Stores to insert before the loop.
1467 std::vector<stmt_t> init_stores;
1468
1469 std::vector<stmt_t> lets;
1470 };
1471
1472 struct let_info_t {
1473 let_info_t(const expr_t &var, const expr_t &value, int loop_level)
1474 : var(var), value(value), loop_level(loop_level) {}
1475
1476 expr_t var;
1477 expr_t value;
1478 int loop_level;
1479 };
1480
1481 struct post_inc_store_info_t {
1482 post_inc_store_info_t(const store_t &obj)
1483 : store(&obj), inc(0), last_iter_cond(true), compensation(0) {}
1484
1485 stmt_t stmt() const {
1486 auto load
1487 = load_t::make(store->value.type(), store->buf, store->off);
1488 return store_t::make(store->buf, store->off, load + inc);
1489 }
1490
1491 bool is_empty() const { return is_zero(inc); }
1492
1493 void update(const loop_info_t &loop, const expr_t &loop_inc) {
1494 inc = simplify(iif_t::make(
1495 last_iter_cond, inc - compensation + loop_inc, inc));
1496 if (last_iter_cond.is_equal(expr_t(true))) {
1497 last_iter_cond = (loop.loop_var() == loop.loop_bound() - 1);
1498 } else {
1499 last_iter_cond = last_iter_cond
1500 & (loop.loop_var() == loop.loop_bound() - 1);
1501 }
1502 compensation = simplify(loop.loop_extent() * loop_inc);
1503 }
1504
1505 const store_t *store;
1506 expr_t inc;
1507
1508 expr_t last_iter_cond;
1509 expr_t compensation;
1510 };
1511
1512 // Recursively substitutes all variables from let statements located under
1513 // the given loop level.
1514 expr_t substitute_let(const expr_t &_e, int loop_level) const {
1515 auto e = _e;
1516 for (;;) {
1517 bool found = false;
1518 auto vars = find_unique_objects<var_t>(e);
1519 for (auto &v : vars) {
1520 auto it = lets_.find(v);
1521 if (it == lets_.end()) continue;
1522 auto &let_info = it->second;
1523 // Do not substitute top-level let variables.
1524 if (let_info.loop_level < loop_level) continue;
1525 found = true;
1526 e = substitute(e, v, let_info.value);
1527 }
1528 if (!found) break;
1529 }
1530 return e;
1531 }
1532
1533 // Injects initial store statements if any.
1534 object_t inject_stores_and_pop_loop(const stmt_t &_s) {
1535 stmt_t s = _s;
1536 auto &stores = loops_.back().init_stores;
1537 for (auto it = stores.rbegin(); it != stores.rend(); ++it) {
1538 s = stmt_seq_t::make(*it, s);
1539 }
1540 loops_.pop_back();
1541 // The top-level dummy loop shouldn't be removed.
1542 ir_assert(loops_.size() >= 1);
1543 return std::move(s);
1544 }
1545
1546 // Loops, ordered from outermost to innermost. The first loop is dummy, to
1547 // represent let statements in the top-level scope.
1548 std::vector<loop_info_t> loops_;
1549
1550 // Buffers whose references are to be updated.
1551 object_map_t<expr_t, post_inc_store_info_t> post_inc_stores;
1552
1553 // Let statements available at the current IR node.
1554 object_map_t<expr_t, let_info_t> lets_;
1555 };
1556
1557 // Detects and converts expensive expression operations inside a loop to less
1558 // expensive operations. Example:
1559 // Before:
1560 // for (int j = 0; j < N; j++) {
1561 // int off = off_i + j * K;
1562 // a[off] = j;
1563 // }
1564 // After:
1565 // int off = off_i;
1566 // for (int j = 0; j < N; j++) {
1567 // a[off] = j;
1568 // off += K;
1569 // }
1570 stmt_t loop_strength_reduce(const stmt_t &s) {
1571 auto ret = loop_strength_reducer_t().mutate(s);
1572 trace_pass("loop_strength_reduce", ret);
1573 return ret;
1574 }
1575
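// Drops let statements and allocations whose variables or buffers are never
// read, and inlines let values that have a single use at the same loop level.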
1576 class alloc_let_optimizer_t : public ir_mutator_t {
1577 public:
1578 // Also track alloc_t and for_t to validate all variable usages.
1579 object_t _mutate(const alloc_t &obj) override {
1580 return mutate_scope(obj, obj.buf);
1581 }
1582
1583 object_t _mutate(const for_t &obj) override {
1584 level_++;
1585 auto new_obj = mutate_scope(obj, obj.var);
1586 level_--;
1587 return new_obj;
1588 }
1589
1590 object_t _mutate(const let_t &obj) override {
1591 return mutate_scope(obj, obj.var);
1592 }
1593
1594 object_t _mutate(const store_t &obj) override {
1595 auto &base = (obj.buf.is<var_t>() ? obj.buf : obj.buf.as<ptr_t>().base);
1596 // Do not count store references. If there are only stores to a buffer
1597 // and no other usages, the buffer can be safely removed.
1598 skip_var_ = base;
1599 auto new_obj = ir_mutator_t::_mutate(obj);
1600 skip_var_ = expr_t();
1601 return new_obj;
1602 }
1603
1604 object_t _mutate(const var_t &obj) override {
1605 ir_assert(refs_.count(obj) == 1)
1606 << "Variable is not defined: " << expr_t(&obj);
1607 if (!skip_var_.is_same(obj)) refs_[&obj].update(increment_, level_);
1608 return ir_mutator_t::_mutate(obj);
1609 }
1610
1611 private:
1612 struct ref_info_t {
1613 ref_info_t(int level = 0)
1614 : refs(0), min_level(level), max_level(level) {}
1615
1616 void update(int increment, int level) {
1617 refs += increment;
1618 max_level = std::max(max_level, level);
1619 }
1620
1621 bool is_same_level() const { return min_level == max_level; }
1622
1623 int refs;
1624 int min_level;
1625 int max_level;
1626 };
1627
1628 template <typename T>
1629     object_t mutate_scope(const T &obj, const expr_t &var) {
1630 auto ret = refs_.insert({var, ref_info_t(level_)});
1631 ir_assert(ret.second) << stmt_t(obj);
1632 MAYBE_UNUSED(ret);
1633
1634 auto new_obj = ir_mutator_t::_mutate(obj);
1635 auto &ref_info = refs_[var];
1636
1637 if (std::is_same<T, let_t>()) {
1638 new_obj = mutate_let(new_obj.template as<let_t>(), ref_info);
1639 } else if (std::is_same<T, alloc_t>()) {
1640 new_obj = mutate_alloc(new_obj.template as<alloc_t>(), ref_info);
1641 }
1642
1643 refs_.erase(var);
1644 return new_obj;
1645 }
1646
1647     object_t mutate_let(const let_t &obj, const ref_info_t &ref_info) {
1648 ir_assert(ref_info.refs >= 1);
1649 if (ref_info.refs == 1) {
1650 // Variable is not used.
1651 remove_refs(obj);
1652 return obj.body;
1653 }
1654         // Check the following conditions to substitute the let value:
1655         // - Exactly 2 references (producer + consumer), i.e. a single usage
1656         // - Consumer and producer are on the same level (same loop)
1657         // - The let has a non-empty value (the variable is not external)
1658 if (ref_info.refs == 2 && ref_info.is_same_level()
1659 && !obj.value.is_empty()) {
1660 return substitute(obj.body, obj.var, obj.value);
1661 }
1662 return obj;
1663 }
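    // A minimal sketch of the substitution above (hypothetical IR; "off",
    // "base" and "buf" are made-up names): a let with a single consumer on
    // the same loop level is inlined into its body, e.g.
    //   let off = base + 16; store(buf, off, x)
    // becomes
    //   store(buf, base + 16, x)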
1664
1665     object_t mutate_alloc(const alloc_t &obj, const ref_info_t &ref_info) {
1666 ir_assert(ref_info.refs >= 1);
1667 // Buffer is not used, single reference from alloc_t itself. Remove
1668 // stores to the buffer if any.
1669 if (ref_info.refs == 1) return remove_stores(obj.body, obj.buf);
1670 return obj;
1671 }
1672
1673     void remove_refs(const let_t &obj) {
1674 increment_ = -1;
1675 mutate(obj.value);
1676 increment_ = 1;
1677 }
1678
1679 // Removes all nested stores to the buffer.
1680     stmt_t remove_stores(const stmt_t &stmt, const expr_t &buf) {
1681 auto ret = stmt;
1682 auto stores = find_objects<store_t>(stmt);
1683 for (auto &_s : stores) {
1684 auto &s = _s.as<store_t>();
1685 auto &base = (s.buf.is<var_t>() ? s.buf : s.buf.as<ptr_t>().base);
1686 if (base.is_same(buf)) ret = substitute(ret, _s, stmt_t());
1687 }
1688 return ret;
1689 }
1690
1691 int increment_ = 1;
1692 int level_ = 0;
1693
1694 expr_t skip_var_;
1695 object_map_t<expr_t, ref_info_t> refs_;
1696 };
1697
1698 stmt_t optimize_alloc_let(const stmt_t &s) {
1699 auto ret = alloc_let_optimizer_t().mutate(s);
1700 trace_pass("optimize_alloc_let", ret);
1701 return ret;
1702 }
1703
1704 class unrolling_updater_t : public ir_mutator_t {
1705 public:
1706     object_t _mutate(const let_t &obj) override {
1707 if (level_ == 0) {
1708 // Skip top-level let statements.
1709 return ir_mutator_t::_mutate(obj);
1710 }
1711 lets_.push_back(&obj);
1712 auto new_body = mutate(obj.body);
1713 if (!lets_.back()) {
1714 // Let was moved to the innermost loop.
1715 lets_.pop_back();
1716 return new_body;
1717 }
1718 lets_.pop_back();
1719 if (new_body.is_same(obj.body)) return obj;
1720 return let_t::make(obj.var, obj.value, new_body);
1721 }
1722
1723     object_t _mutate(const for_t &obj) override {
1724 level_++;
1725 found_loop_ = false;
1726 auto new_obj = ir_mutator_t::_mutate(obj);
1727 level_--;
1728 if (!found_loop_) {
1729 // Innermost loop, inject let statements.
1730 auto body = get_stmt_body(new_obj);
1731 for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
1732 body = let_t::make((*it)->var, (*it)->value, body);
1733 *it = nullptr;
1734 }
1735 new_obj = replace_stmt_body(new_obj, body);
1736 }
1737 found_loop_ = true;
1738 return new_obj;
1739 }
1740
1741 private:
1742 bool found_loop_ = false;
1743 int level_ = 0;
1744 std::vector<const let_t *> lets_;
1745 };
1746
1747 // Eliminates let statements from the outer loops to be able to unroll loop
1748 // nest for SLM buffering or prefetch injection. Example:
1749 // Before:
1750 // for (int i = 0; i < I; i++) {
1751 // int tmp = TMP;
1752 // for (int j = 0; j < J; j++) {
1753 // ...
1754 // }
1755 // }
1756 // After:
1757 // for (int i = 0; i < I; i++) {
1758 // for (int j = 0; j < J; j++) {
1759 // int tmp = TMP;
1760 // ...
1761 // }
1762 // }
1763 stmt_t update_loops_for_unrolling(const stmt_t &s, const conv_config_t &cfg) {
1764 auto ret = s;
1765 if (cfg.do_loop_unroll) ret = unrolling_updater_t().mutate(s);
1766 trace_pass("update_loops_for_unrolling", ret);
1767 return ret;
1768 }
1769
1770 // Helper structure for for_t.
1771 struct loop_info_t {
1772 loop_info_t() = default;
1773
1774     loop_info_t(const stmt_t &s) {
1775 ir_assert(s.is<for_t>()) << s;
1776 auto &loop = s.as<for_t>();
1777 stmt = s;
1778 var = loop.var;
1779 init_ = loop.init;
1780 bound_ = loop.bound;
1781
1782 auto e_size = simplify(bound_ - init_);
1783 ir_assert(is_const(e_size));
1784 size_ = to_cpp<int>(e_size);
1785 }
1786
1787     int init() const {
1788         ir_assert(is_const(init_));
1789         return to_cpp<int>(init_);
1790     }
1791 
1792     int bound() const {
1793         ir_assert(is_const(bound_));
1794         return to_cpp<int>(bound_);
1795     }
1796 
1797     int size() const { return size_; }
1798
1799 stmt_t stmt;
1800 expr_t var;
1801
1802 private:
1803 expr_t init_;
1804 expr_t bound_;
1805 int size_;
1806 };
1807
1808 // Iterates through multiple nested loops with fixed bounds. Used to unroll
1809 // such nested loops.
1810 class multi_loop_iterator_t {
1811 public:
1812 // Ordered from innermost to outermost.
1813     multi_loop_iterator_t(const std::vector<loop_info_t> &loops)
1814 : loops_(loops) {
1815 for (auto &l : loops)
1816 var_values_.push_back(l.init());
1817 }
1818
1819     int var_value(const expr_t &var) const {
1820 for (size_t i = 0; i < loops_.size(); i++) {
1821 if (loops_[i].var.is_same(var)) return var_values_[i];
1822 }
1823 ir_error_not_expected();
1824 return 0;
1825 }
1826
1827     void advance(int n = 1) {
1828 if (loops_.empty()) return;
1829 for (int i_n = 0; i_n < n; i_n++) {
1830 for (size_t i = 0; i < loops_.size(); i++) {
1831 auto &l = loops_[i];
1832 if (++var_values_[i] < l.bound()) break;
1833 var_values_[i] = l.init();
1834 }
1835 ir_assert(var_values_.back() < loops_.back().bound());
1836 }
1837 }
1838
1839     bool is_outer_loop_end() const {
1840 if (loops_.empty()) return true;
1841 for (size_t i = 0; i < loops_.size() - 1; i++) {
1842 auto &l = loops_[i];
1843 if (var_values_[i] != l.bound() - 1) return false;
1844 }
1845 return true;
1846 }
1847
1848     std::string str() const {
1849 std::ostringstream oss;
1850 oss << "multi_loop_iterator_t(";
1851 for (size_t i = 0; i < loops_.size(); i++) {
1852 oss << (i != 0 ? ", " : "");
1853 oss << loops_[i].var << " = " << var_values_[i];
1854 }
1855 oss << ")";
1856 return oss.str();
1857 }
1858
1859 IR_DEFINE_DUMP()
1860
1861 private:
1862 std::vector<loop_info_t> loops_;
1863 std::vector<int> var_values_;
1864 };
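// Usage sketch (hypothetical loops, not taken from this file): with an
// innermost loop j of bound 2 and an outer loop i of bound 3,
//   multi_loop_iterator_t it({loop_j, loop_i});
// starts at (j, i) = (0, 0) and successive it.advance() calls yield
// (1, 0) (0, 1) (1, 1) (0, 2) (1, 2) before wrapping back to (0, 0).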
1865
1866 // Extracts different parts of the compute iteration and verifies the loop nest
1867 // is properly formed and can be further injected with SLM buffering.
1868 class compute_step_visitor_t : public ir_visitor_t {
1869 public:
1870     stmt_t find_stmt_group(const stmt_label_t &label) const {
1871 auto groups = find_stmt_groups(label);
1872 if (groups.empty()) return stmt_t();
1873 ir_assert(groups.size() == 1);
1874 return groups[0];
1875 }
1876
1877     std::vector<stmt_t> find_stmt_groups(const stmt_label_t &label) const {
1878 std::vector<stmt_t> ret;
1879 for (auto &_g : stmt_groups_) {
1880 auto &g = _g.as<stmt_group_t>();
1881 if (g.label == label) ret.push_back(_g);
1882 }
1883 return ret;
1884 }
1885
1886     const std::vector<stmt_t> &inner_let_stmts() const {
1887 return inner_let_stmts_;
1888 }
1889
1890 #define HANDLE_IR_OBJECT(type) \
1891 void _visit(const type &obj) override { visit_stmt(obj); }
1892
1893     HANDLE_STMT_IR_OBJECTS()
1894
1895 #undef HANDLE_IR_OBJECT
1896
1897 template <typename T>
1898 void visit_stmt(const T &obj) {
1899 auto obj_type_id = T::_type_id();
1900 bool is_for = (obj_type_id == for_t::_type_id());
1901 bool is_stmt_group = (obj_type_id == stmt_group_t::_type_id());
1902 bool is_let = (obj_type_id == let_t::_type_id());
1903 bool is_stmt_seq = (obj_type_id == stmt_seq_t::_type_id());
1904
1905 // Loop may contain:
1906 // - Another loop
1907 // - Container statement (stmt_seq_t or stmt_group_t)
1908 // - Let statement (in the innermost loop only)
1909 // - Barrier
1910 if (loop_level_ > 0) {
1911 bool ok = false;
1912 if (is_for || is_let || is_stmt_group || is_stmt_seq) {
1913 ok = true;
1914 } else if (obj_type_id == func_call_t::_type_id()) {
1915 auto &call = obj.template as<func_call_t>();
1916 ok = call.func.is_equal(funcs::barrier_func());
1917 }
1918
1919 if (!ok) {
1920 ir_error_not_expected()
1921 << "Found unexpected statement inside loop.\n"
1922 << stmt_t(obj);
1923 }
1924 }
1925
1926 bool is_compute_loop = false;
1927 if (is_stmt_group) {
1928 auto label = obj.template as<stmt_group_t>().label;
1929 stmt_groups_.push_back(obj);
1930 if (utils::one_of(label, stmt_label_t::g2s_load(),
1931 stmt_label_t::g2s_store(), stmt_label_t::g2r_load(),
1932 stmt_label_t::s2r_load(), stmt_label_t::prefetch(),
1933 stmt_label_t::mul())) {
1934 // Leaf labels, do not visit them.
1935 return;
1936 }
1937 if (label == stmt_label_t::compute_loop()) {
1938 is_compute_loop = true;
1939 in_compute_loop_ = true;
1940 }
1941 }
1942
1943 if (is_for) loop_level_++;
1944 found_loop_ = false;
1945 ir_visitor_t::_visit(obj);
1946 if (in_compute_loop_ && is_let) {
1947 if (found_loop_)
1948 ir_error_not_expected()
1949 << "Let is allowed in the innermost loop only.";
1950
1951 inner_let_stmts_.push_back(replace_stmt_body(obj, stmt_t()));
1952 }
1953 if (is_for) {
1954 loop_level_--;
1955 found_loop_ = true;
1956 }
1957
1958 if (is_compute_loop) in_compute_loop_ = false;
1959 }
1960
1961 private:
1962 bool found_loop_ = false;
1963 bool in_compute_loop_ = false;
1964 int loop_level_ = 0;
1965
1966 std::vector<stmt_t> stmt_groups_;
1967 std::vector<stmt_t> inner_let_stmts_;
1968 };
1969
1970 // Provides access to different parts of the inner compute iteration.
1971 class compute_step_t {
1972 public:
1973     compute_step_t(const stmt_t &parent) {
1974 compute_step_visitor_t v;
1975 v.visit(parent);
1976
1977 compute_loop_ = v.find_stmt_group(stmt_label_t::compute_loop());
1978 g2s_load_ = v.find_stmt_group(stmt_label_t::g2s_load());
1979 g2s_store_ = v.find_stmt_group(stmt_label_t::g2s_store());
1980 prefetch_ = v.find_stmt_group(stmt_label_t::prefetch());
1981 g2r_load_ = v.find_stmt_groups(stmt_label_t::g2r_load());
1982 s2r_load_ = v.find_stmt_groups(stmt_label_t::s2r_load());
1983 mul_ = v.find_stmt_groups(stmt_label_t::mul());
1984 c_zero_out_ = v.find_stmt_group(stmt_label_t::c_zero_out());
1985 inner_let_stmts_ = v.inner_let_stmts();
1986
1987 ir_assert(g2r_load_.size() == mul_.size());
1988 ir_assert(s2r_load_.size() == mul_.size());
1989
1990 // Assign preload/mul tags to let statements.
1991 for (auto &_let : inner_let_stmts_) {
1992 auto &var = _let.as<let_t>().var;
1993 bool is_preload = (count_object(g2s_load_, var) > 0)
1994 || (count_object(prefetch_, var) > 0);
1995 bool is_mul = count_object(g2r_load_, var) > 0;
1996 if (is_preload) preload_lets_.insert(_let);
1997 if (is_mul) mul_lets_.insert(_let);
1998 }
1999
2000 // Propagate preload/mul tags up based on dependencies between let
2001 // statements.
2002 std::vector<let_info_t> let_infos;
2003 object_set_t<stmt_t> seen;
2004 std::function<void(const stmt_t &)> propagate;
2005 propagate = [&](const stmt_t &_let) {
2006 if (seen.count(_let) > 0) return;
2007 auto &let = _let.as<let_t>();
2008 for (auto &_child : inner_let_stmts_) {
2009 auto &child = _child.as<let_t>();
2010 if (_child.is_same(_let)) continue;
2011 if (contains_object(child.value, let.var)) {
2012 // Visit child let statements first.
2013 propagate(_child);
2014 // Propagate child preload/mul values to this let statement.
2015 if (is_preload_let(_child)) preload_lets_.insert(_let);
2016 if (is_mul_let(_child)) mul_lets_.insert(_let);
2017 }
2018 }
2019 auto let_info = create_let_info(
2020 let, is_preload_let(_let), is_mul_let(_let));
2021 let_infos.push_back(let_info);
2022 seen.insert(_let);
2023 };
2024 for (auto &_let : inner_let_stmts_)
2025 propagate(_let);
2026
2027 // Duplicate lets that are used in both preload and mul contexts.
2028 duplicate_lets(let_infos);
2029 }
2030
2031 // See ir_core.hpp for the description.
2032     const stmt_t &compute_loop() const { return compute_loop_; }
2033     const stmt_t &g2s_load() const { return g2s_load_; }
2034     const stmt_t &g2s_store() const { return g2s_store_; }
2035     const stmt_t &prefetch() const { return prefetch_; }
2036     const std::vector<stmt_t> &g2r_load() const { return g2r_load_; }
2037     const std::vector<stmt_t> &s2r_load() const { return s2r_load_; }
2038     const std::vector<stmt_t> &mul() const { return mul_; }
2039     const stmt_t &c_zero_out() const { return c_zero_out_; }
2040     const std::vector<stmt_t> &inner_let_stmts() const {
2041 return inner_let_stmts_;
2042 }
2043
2044     bool is_preload_let(const stmt_t &s) const {
2045         return preload_lets_.count(s) > 0;
2046     }
2047     bool is_mul_let(const stmt_t &s) const { return mul_lets_.count(s) > 0; }
2048
2049 private:
2050 struct let_info_t {
2051         let_info_t(const expr_t &var) : var(var) {}
2052
2053 expr_t var;
2054 expr_t preload_var;
2055 expr_t mul_var;
2056
2057         bool is_preload() const { return !preload_var.is_empty(); }
2058         bool is_mul() const { return !mul_var.is_empty(); }
2059 
2060         bool needs_update() const { return is_preload() && is_mul(); }
2061 };
2062
2063     let_info_t create_let_info(const let_t &let, bool is_preload, bool is_mul) {
2064 let_info_t info(let.var);
2065 if (is_preload && !is_mul) {
2066 info.preload_var = let.var;
2067 } else if (!is_preload && is_mul) {
2068 info.mul_var = let.var;
2069 } else if (is_preload && is_mul) {
2070 info.preload_var = create_var_with_suffix(let.var, "p");
2071 info.mul_var = create_var_with_suffix(let.var, "m");
2072 }
2073 return info;
2074 }
2075
2076     void duplicate_lets(const std::vector<let_info_t> &let_infos) {
2077 int nlets = int(inner_let_stmts_.size());
2078 ir_assert(int(let_infos.size()) == nlets);
2079
2080 std::vector<stmt_t> new_lets;
2081 for (int i = nlets - 1; i >= 0; i--) {
2082 auto &info = let_infos[i];
2083 auto &old_let = inner_let_stmts_[i].as<let_t>();
2084 if (!info.needs_update()) {
2085 auto new_value = update_var(old_let.value, let_infos,
2086 info.is_preload(), info.is_mul());
2087 auto new_let = inner_let_stmts_[i];
2088 if (!new_value.is_same(old_let.value)) {
2089 new_let = let_t::make(old_let.var, new_value, old_let.body);
2090 }
2091 new_lets.push_back(new_let);
2092 continue;
2093 }
2094
2095 preload_lets_.erase(&old_let);
2096 mul_lets_.erase(&old_let);
2097
2098 auto preload_value
2099 = update_var(old_let.value, let_infos, true, false);
2100 auto preload_let = let_t::make(
2101 info.preload_var, preload_value, old_let.body);
2102
2103 auto mul_value = update_var(old_let.value, let_infos, false, true);
2104 auto mul_let = let_t::make(info.mul_var, mul_value, old_let.body);
2105
2106 preload_lets_.insert(preload_let);
2107 new_lets.push_back(preload_let);
2108
2109 mul_lets_.insert(mul_let);
2110 new_lets.push_back(mul_let);
2111
2112 // Update statements.
2113 g2s_load_ = update_var(g2s_load_, let_infos, true, false);
2114 g2s_store_ = update_var(g2s_store_, let_infos, true, false);
2115 prefetch_ = update_var(prefetch_, let_infos, true, false);
2116 g2r_load_ = update_var(g2r_load_, let_infos, false, true);
2117 s2r_load_ = update_var(s2r_load_, let_infos, false, true);
2118 mul_ = update_var(mul_, let_infos, false, true);
2119 }
2120
2121 std::reverse(new_lets.begin(), new_lets.end());
2122 inner_let_stmts_ = new_lets;
2123 }
2124
2125 template <typename T>
2126     static std::vector<T> update_var(const std::vector<T> &vec,
2127 const std::vector<let_info_t> &let_infos, bool is_preload,
2128 bool is_mul) {
2129 std::vector<T> ret;
2130 for (auto &v : vec)
2131 ret.push_back(update_var(v, let_infos, is_preload, is_mul));
2132 return ret;
2133 }
2134
2135     static object_t update_var(const object_t &obj,
2136 const std::vector<let_info_t> &let_infos, bool is_preload,
2137 bool is_mul) {
2138 auto ret = obj;
2139 for (auto &info : let_infos) {
2140 if (!info.needs_update()) continue;
2141 if (!contains_object(ret, info.var)) continue;
2142 if (is_preload) {
2143 ir_assert(info.is_preload());
2144 ret = substitute(ret, info.var, info.preload_var);
2145 } else if (is_mul) {
2146 ir_assert(info.is_mul());
2147 ret = substitute(ret, info.var, info.mul_var);
2148 }
2149 }
2150 return ret;
2151 }
2152
2153     static expr_t create_var_with_suffix(
2154 const expr_t &_var, const std::string &suffix) {
2155 auto &var = _var.as<var_t>();
2156 auto new_name = var.name + "_" + suffix;
2157 return var_t::make(var.type, new_name);
2158 }
2159
2160 stmt_t compute_loop_;
2161 stmt_t g2s_load_;
2162 stmt_t g2s_store_;
2163 stmt_t prefetch_;
2164 std::vector<stmt_t> g2r_load_;
2165 std::vector<stmt_t> s2r_load_;
2166 std::vector<stmt_t> mul_;
2167 stmt_t c_zero_out_;
2168
2169 std::vector<stmt_t> inner_let_stmts_;
2170
2171 // Due to loop unrolling the inner let statements may depend on different
2172 // indices of the outer loops. There are two contexts:
2173 // - "preload" loop iteration, e.g. index I
2174 // - "multiplication" loop iteration, e.g. index (I + nbuf)
2175 // Preloads (either via SLM or via prefetches) for the corresponding
2176 // multiplication are executed several iterations before the real
2177 // multiplication. That's why we need to know exactly in which context the
2178 // given let statement is used. It might be that the same variable is used
2179 // from two different contexts. In this case it is duplicated and
2180 // initialized with different values for each case.
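    // For illustration (hypothetical let, not taken from this kernel): a let
    //   let a_off = i * 64;
    // referenced from both a preload and a multiplication is duplicated into
    //   let a_off_p = <preload copy of i> * 64;
    //   let a_off_m = <mul copy of i> * 64;
    // and the preload/mul statements are rewritten to use a_off_p / a_off_m.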
2181 object_set_t<stmt_t> preload_lets_;
2182 object_set_t<stmt_t> mul_lets_;
2183 };
2184
2185 // Helper class to access the outer loop index after pipelining. Pipelining
2186 // in general requires tracking two versions of a loop index:
2187 // - Multiplication version - corresponding to the iteration that is currently
2188 // used for multiplication
2189 // - Preload version - corresponding to the iteration that is currently used
2190 // for preload for one of the next multiplications
2191 // The multiplication version is a few steps behind the preload version.
2192 class outer_loop_info_t : public loop_info_t {
2193 public:
2194 outer_loop_info_t() = default;
2195
2196     outer_loop_info_t(const stmt_t &s, ir_context_t &ir_ctx) : loop_info_t(s) {
2197         // The outer loop may not be fully unrolled, hence the unrolled
2198         // iterations must not reference its index directly. If they do,
2199         // introduce GRF buffers to represent that variable and apply
2200         // post-increment updates after each outer loop iteration.
2201 if (count_object(s.as<for_t>().body, var) != 0) {
2202 has_var_refs_ = true;
2203 mul_var_buf_ = ir_ctx.create_tmp_var(
2204 type_t::byte_ptr(), var.as<var_t>().name + "_mul_buf");
2205 preload_var_buf_ = ir_ctx.create_tmp_var(
2206 type_t::byte_ptr(), var.as<var_t>().name + "_preload_buf");
2207
2208 auto mul_alloc = alloc_t::make(
2209 mul_var_buf_, var.type().size(), alloc_kind_t::grf);
2210 auto preload_alloc = alloc_t::make(
2211 preload_var_buf_, var.type().size(), alloc_kind_t::grf);
2212 allocs_.push_back(mul_alloc);
2213 allocs_.push_back(preload_alloc);
2214
2215 auto mul_init = store_t::make(mul_var_buf_, 0, init());
2216 auto preload_init = store_t::make(preload_var_buf_, 0, init());
2217 init_stmt_ = mul_init.append(preload_init);
2218
2219 mul_post_inc_stmt_
2220 = store_t::make(mul_var_buf_, 0, mul_var_load() + 1);
2221 preload_post_inc_stmt_ = store_t::make(
2222 preload_var_buf_, 0, preload_var_load() + 1);
2223 }
2224 }
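    // A rough sketch of what the constructor generates for an outer loop
    // variable i that is referenced from the loop body (buffer names follow
    // the "_mul_buf"/"_preload_buf" pattern above):
    //   i_mul_buf[0] = init; i_preload_buf[0] = init;  // init_stmt()
    //   i_mul_buf[0] += 1;                             // mul_post_inc_stmt()
    //   i_preload_buf[0] += 1;                         // preload_post_inc_stmt()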
2225
2226     bool has_var_refs() const { return has_var_refs_; }
2227 
2228     expr_t mul_var_load() const {
2229 return load_t::make(var.type(), mul_var_buf_, 0);
2230 }
2231     expr_t preload_var_load() const {
2232 return load_t::make(var.type(), preload_var_buf_, 0);
2233 }
2234
2235     stmt_t inject_alloc_stmts(const stmt_t &stmt) const {
2236 return jit::inject_alloc_stmts(stmt, allocs_);
2237 }
2238
2239     const stmt_t &init_stmt() const { return init_stmt_; }
2240 
2241     const stmt_t &mul_post_inc_stmt() const { return mul_post_inc_stmt_; }
2242     const stmt_t &preload_post_inc_stmt() const {
2243 return preload_post_inc_stmt_;
2244 }
2245
2246 private:
2247 bool has_var_refs_ = false;
2248
2249 // Helper expressions/statements to partially unroll the loop.
2250 expr_t mul_var_buf_;
2251 expr_t preload_var_buf_;
2252 std::vector<stmt_t> allocs_;
2253 stmt_t init_stmt_;
2254 stmt_t mul_post_inc_stmt_;
2255 stmt_t preload_post_inc_stmt_;
2256 };
2257
2258 // Helper class to work with loop nest of the compute loop.
2259 class compute_loop_nest_t {
2260 public:
2261 compute_loop_nest_t() = default;
2262
2263     compute_loop_nest_t(const stmt_t &root, ir_context_t &ir_ctx)
2264 : root_(root) {
2265 for (auto &l : find_objects<for_t>(root)) {
2266 loops_.emplace_back(l);
2267 }
2268
2269 if (loops_.empty()) {
2270 outer_loop_size_ = 1;
2271 return;
2272 }
2273
2274 outer_loop_ = outer_loop_info_t(loops_.back().stmt, ir_ctx);
2275 outer_loop_size_ = outer_loop_.size();
2276 }
2277
2278     const std::vector<loop_info_t> &loops() const { return loops_; }
2279 
2280     // Number of iterations of all loops.
2281     int size() const {
2282 int ret = 1;
2283 for (auto &l : loops_)
2284 ret *= l.size();
2285 return ret;
2286 }
2287
2288 // Number of iterations in the outermost loop (see comments in ctor).
2289     int outer_loop_size() const { return outer_loop_size_; }
2290 
2291     const outer_loop_info_t &outer_loop_info() const { return outer_loop_; }
2292
2293 template <typename F>
2294     void for_each_loop_var(const F &f) const {
2295 for (auto &l : loops_)
2296 f(l.var);
2297 }
2298
2299 // Number of iterations of all loops except the outermost.
2300     int inner_loops_size() const { return size() / outer_loop_size(); }
2301
2302 private:
2303 stmt_t root_;
2304 std::vector<loop_info_t> loops_;
2305
2306 int outer_loop_size_;
2307 outer_loop_info_t outer_loop_;
2308 };
2309
2310 struct compute_params_t {
2311 compute_params_t() = default;
2312
2313     compute_params_t(int slm_bufs, int gmem_bufs, int slm_buf_size,
2314 int prefetch_bufs, int inner_loops_iters)
2315 : slm_bufs(slm_bufs)
2316 , gmem_bufs(gmem_bufs)
2317 , slm_buf_size(slm_buf_size)
2318 , prefetch_bufs(prefetch_bufs) {
2319 use_slm = (slm_buf_size > 0);
2320 use_prefetch = (prefetch_bufs > 0);
2321 ir_assert(!use_slm || !use_prefetch)
2322 << "Can't have both SLM buffering and prefetch enabled.";
2323 if (use_slm) {
2324 ir_assert(utils::one_of(slm_bufs, 1, 2, 3));
2325 ir_assert(utils::one_of(gmem_bufs, 1, 2));
2326 preload_bufs = slm_bufs;
2327 unroll = math::lcm(slm_bufs * gmem_bufs, inner_loops_iters);
2328 } else if (use_prefetch) {
2329 preload_bufs = prefetch_bufs;
2330 ir_assert(slm_bufs == 0);
2331 ir_assert(gmem_bufs == 0);
2332 unroll = math::lcm(prefetch_bufs, inner_loops_iters);
2333 } else {
2334 preload_bufs = 0;
2335 ir_assert(slm_bufs == 0);
2336 ir_assert(gmem_bufs == 0);
2337 unroll = inner_loops_iters;
2338 }
2339 }
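    // A worked example (hypothetical values): with SLM buffering enabled,
    // slm_bufs = 3, gmem_bufs = 2 and inner_loops_iters = 4 give
    // unroll = lcm(3 * 2, 4) = 12, so the unrolled body covers a whole number
    // of both buffer rotations and inner loop nest traversals.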
2340
2341 int slm_bufs;
2342 int gmem_bufs;
2343 int slm_buf_size;
2344 int prefetch_bufs;
2345 int preload_bufs;
2346 int unroll;
2347
2348 bool use_slm;
2349 bool use_prefetch;
2350 };
2351
2352 // Helper class to implement SLM buffering.
2353 class compute_iterator_t {
2354 public:
2355     compute_iterator_t(const compute_params_t &params,
2356 const compute_loop_nest_t &loop_nest)
2357 : params(params)
2358 , preload_loop_it(loop_nest.loops())
2359 , mul_loop_it(loop_nest.loops()) {
2360
2361 int compute_iters = loop_nest.size();
2362 iters = compute_iters;
2363 ir_assert(iters >= 1) << "Empty loop is not expected.";
2364
2365 iters += std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
2366 ramp_up_iters
2367 = std::max(1, preload_bufs() + std::max(0, gmem_bufs() - 1));
2368 ramp_down_iters = std::min(
2369 std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1),
2370 iters - ramp_up_iters);
2371 body_iters = iters - ramp_up_iters - ramp_down_iters;
2372 body_iters = utils::rnd_dn(body_iters, params.unroll);
2373 ramp_down_iters = iters - ramp_up_iters - body_iters;
2374
2375 ir_assert(ramp_up_iters + body_iters + ramp_down_iters == iters);
2376
2377 iter = 0;
2378 linear_id = 0;
2379 riter = iters - 1;
2380 }
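    // A rough walk-through (hypothetical values): with preload_bufs = 2,
    // gmem_bufs = 2, 8 compute iterations and unroll = 4:
    //   iters           = 8 + 1 + 1 = 10
    //   ramp_up_iters   = max(1, 2 + 1) = 3
    //   ramp_down_iters = min(1 + 1, 10 - 3) = 2
    //   body_iters      = rnd_dn(10 - 3 - 2, 4) = 4
    //   ramp_down_iters = 10 - 3 - 4 = 3 (recomputed after rounding down)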
2381
2382     int unroll() const { return params.unroll; }
2383 
2384     int preload_bufs() const { return params.preload_bufs; }
2385 
2386     int slm_bufs() const { return params.slm_bufs; }
2387 
2388     int gmem_bufs() const { return params.gmem_bufs; }
2389 
2390     compute_iterator_t &operator++() {
2391 if (do_preload()) preload_loop_it.advance();
2392 if (do_mul()) mul_loop_it.advance();
2393 ++iter;
2394 ++linear_id;
2395 --riter;
2396 return *this;
2397 }
2398
2399     void advance(int n) {
2400 if (n == 0) return;
2401
2402 ir_assert(n % params.unroll == 0);
2403 ir_assert(iter + n <= iters);
2404
2405 if (preload_bufs() > 0) ir_assert(do_preload());
2406 ir_assert(do_mul());
2407
2408 iter += n;
2409 riter -= n;
2410
2411 if (preload_bufs() > 0) preload_loop_it.advance(n);
2412 mul_loop_it.advance(n);
2413 }
2414
2415     bool do_mul() const {
2416 return iter >= std::max(0, preload_bufs() - 1)
2417 + std::max(0, gmem_bufs() - 1);
2418 }
2419
2420     bool is_first_mul() const {
2421 return iter
2422 == std::max(0, preload_bufs() - 1)
2423 + std::max(0, gmem_bufs() - 1);
2424 }
2425     bool is_last_mul() const { return riter == 0; }
2426 
2427     bool do_preload() const {
2428 if (preload_bufs() == 0) return false;
2429 return riter >= (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
2430 }
2431
2432     bool do_prefetch() const {
2433 if (!params.use_prefetch) return false;
2434 return do_preload();
2435 }
2436
2437     bool do_g2s_load() const {
2438 if (!params.use_slm) return false;
2439 return do_preload();
2440 }
2441
2442     bool do_s2r_load() const {
2443 if (!params.use_slm) return false;
2444 ir_assert(gmem_bufs() >= 1);
2445 return iter >= (gmem_bufs() - 1) && riter >= (slm_bufs() - 1);
2446 }
2447
2448     int gmem_write_buf_index() const {
2449 ir_assert(do_g2s_load());
2450 return iter % gmem_bufs();
2451 }
2452
2453     int gmem_read_buf_index() const {
2454 ir_assert(do_s2r_load());
2455 return (iter - (gmem_bufs() - 1)) % gmem_bufs();
2456 }
2457
2458     int slm_read_offset_update() const {
2459 ir_assert(params.use_slm);
2460 ir_assert(do_mul());
2461
2462 int slm_iter = iter - (gmem_bufs() - 1) - (slm_bufs() - 1);
2463 int cur_slm_idx = slm_iter % slm_bufs();
2464 int next_slm_idx = (slm_iter + 1) % slm_bufs();
2465 int ret = next_slm_idx * params.slm_buf_size
2466 - cur_slm_idx * params.slm_buf_size;
2467 return ret;
2468 }
2469
2470     int slm_write_offset_update() const {
2471 ir_assert(params.use_slm);
2472 ir_assert(do_s2r_load());
2473
2474 int slm_iter = iter - (gmem_bufs() - 1);
2475 int cur_slm_idx = slm_iter % slm_bufs();
2476 int next_slm_idx = (slm_iter + 1) % slm_bufs();
2477 int ret = next_slm_idx * params.slm_buf_size
2478 - cur_slm_idx * params.slm_buf_size;
2479 return ret;
2480 }
2481
2482 compute_params_t params;
2483 multi_loop_iterator_t preload_loop_it;
2484 multi_loop_iterator_t mul_loop_it;
2485
2486 // ramp_up_iters + body_iters + ramp_down_iters == iters
2487 int iters;
2488 int ramp_up_iters;
2489 int body_iters;
2490 int ramp_down_iters;
2491
2492 // Invariant: iter + riter = iters - 1
2493 int iter;
2494 int riter;
2495
2496 int linear_id;
2497 };
2498
2499 // Basic LRU SBID allocator, tries to use the same SBIDs for the same GRF
2500 // buffers.
2501 class sbid_manager_t {
2502 public:
2503     sbid_manager_t() : tuple_func_(builtin_t::make("tuple")) {}
2504 
2505     ngen_proxy::SBID get_sbid(const expr_t &buf, int index = 0) {
2506 auto key = tuple_func_.call({buf, expr_t(index)});
2507
2508 int free_idx = -1;
2509 for (int i = 0; i < sbid_count; i++) {
2510 auto &e = entries_[i];
2511 if (key.is_equal(e.key)) {
2512 e.time = cur_time_++;
2513 return ngen_proxy::SBID(i);
2514 }
2515 if (free_idx == -1 && e.key.is_empty()) free_idx = i;
2516 }
2517
2518 // Not found but there is a free SBID.
2519 if (free_idx != -1) {
2520 entries_[free_idx] = {key, cur_time_++};
2521 return ngen_proxy::SBID(free_idx);
2522 }
2523
2524 // Find the oldest SBID and use it.
2525 int old_idx = 0;
2526 int old_time = entries_[0].time;
2527 for (int i = 1; i < sbid_count; i++) {
2528 if (entries_[i].time < old_time) {
2529 old_idx = i;
2530 old_time = entries_[i].time;
2531 }
2532 }
2533
2534 entries_[old_idx] = entry_t({key, cur_time_++});
2535 return ngen_proxy::SBID(old_idx);
2536 }
2537
2538 private:
2539 struct entry_t {
2540 stmt_t key;
2541 int time;
2542 };
2543
2544 static const int sbid_count = 16;
2545 std::array<entry_t, sbid_count> entries_;
2546
2547 func_t tuple_func_;
2548 int cur_time_ = 0;
2549 };
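// Usage sketch (hypothetical buffers a_buf and b_buf): repeated requests for
// the same buffer reuse the cached SBID, while new keys take a free slot or
// evict the least recently used entry:
//   sbid_manager_t mgr;
//   auto s0 = mgr.get_sbid(a_buf); // new key -> SBID 0
//   auto s1 = mgr.get_sbid(b_buf); // new key -> SBID 1
//   auto s2 = mgr.get_sbid(a_buf); // cached  -> SBID 0 again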
2550
2551 // Helper to assign SBIDs to IR function calls.
2552 class sbid_assigner_t {
2553 public:
2554 sbid_assigner_t() = default;
2555
2556     sbid_assigner_t(sbid_manager_t &external_sbid_mgr)
2557 : external_sbid_mgr_(&external_sbid_mgr) {}
2558
2559     stmt_t assign(const stmt_t &stmt) {
2560 auto stmt_vec = flatten_statements(stmt);
2561 stmt_t ret = stmt;
2562 int prefetch_idx = 0;
2563 for (auto &_s : stmt_vec) {
2564 if (!_s.is<func_call_t>()) continue;
2565 auto s = _s;
2566 if (is_func_call<send_t>(s)) {
2567 auto &send = s.as<func_call_t>().func.as<send_t>();
2568 int idx = (send.is_prefetch ? prefetch_idx++ : 0);
2569 auto sbid = get_sbid(send_t::arg_reg_buf(s), idx);
2570 s = update_call_with_sbid(s, sbid);
2571 } else if (is_func_call<dpas_t>(s)) {
2572 auto &attr = s.as<func_call_t>().attr;
2573 auto *mod_attr = attr.as_ptr<instruction_modifier_attr_t>();
2574 if (!mod_attr || !mod_attr->mod.is_atomic) {
2575 // Last dpas in Atomic chain.
2576 auto sbid = get_sbid(dpas_t::arg_src1(s));
2577 s = update_call_with_sbid(s, sbid);
2578 }
2579 } else if (s.is<func_call_t>()) {
2580 auto &c = s.as<func_call_t>();
2581 if (c.func.is_equal(funcs::signal_func())
2582 || c.func.is_equal(funcs::slm_fence_func())
2583 || c.func.is_equal(funcs::barrier_func())) {
2584 // Use 0 as the key for signals and SLM fences.
2585 auto sbid = get_sbid(expr_t(0));
2586 s = update_call_with_sbid(s, sbid);
2587 }
2588 } else {
2589 ir_error_not_expected() << s;
2590 }
2591 ret = substitute(ret, _s, s);
2592 }
2593 return ret;
2594 }
2595
2596 private:
2597     ngen_proxy::SBID get_sbid(const expr_t &ptr, int index = 0) {
2598 auto &sbid_mgr
2599 = (external_sbid_mgr_ ? *external_sbid_mgr_ : local_sbid_mgr_);
2600 return sbid_mgr.get_sbid(ptr, index);
2601 }
2602
2603     static stmt_t update_call_with_sbid(
2604 const stmt_t &s, const ngen_proxy::SBID &sbid) {
2605 return instruction_modifier_attr_t::make(
2606 ngen_proxy::InstructionModifier().with_sbid(sbid))
2607 .apply_to(s);
2608 }
2609
2610 sbid_manager_t local_sbid_mgr_;
2611 sbid_manager_t *external_sbid_mgr_ = nullptr;
2612 };
2613
2614 class simple_slm_buffering_injector_t {
2615 public:
2616     simple_slm_buffering_injector_t(ngen::HW hw, const stmt_t &root,
2617 const conv_config_t &cfg, ir_context_t &ir_ctx, int ab_slm_size)
2618 : hw_(hw)
2619 , cfg_(cfg)
2620 , ir_ctx_(ir_ctx)
2621 , ab_slm_size_(ab_slm_size)
2622 , root_(root)
2623 , alloc_mgr_(root_)
2624 , step_(root)
2625 , loop_nest_(root, ir_ctx) {}
2626
2627     stmt_t inject() {
2628 ir_assert(cfg_.gmem_bufs == 1) << "GRF buffering is not supported.";
2629 if (utils::one_of(cfg_.slm_bufs, 0, 1)) return root_;
2630
2631 ir_assert(cfg_.use_a_slm == cfg_.use_b_slm)
2632 << "Mixed SLM/GMEM loads are not supported.";
2633
2634 auto loop = step_.compute_loop();
2635
2636 // SLM indices are allocated as follows:
2637 // slm_idx[0] -> slm_buf_store
2638 // slm_idx[1] -> slm_buf_compute
2639 // slm_idx[2] -> slm_counter
2640 auto slm_idx_buf
2641 = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "slm_idx");
2642 int slm_idx_size = type_t::s32().size();
2643
2644 auto slm_idx_load = [&](int off, int elems) {
2645 return load_t::make(
2646 type_t::s32(elems), slm_idx_buf, slm_idx_size * off);
2647 };
2648
2649 // Initialize slm_idx.
2650 int off = 0;
2651 auto store0 = store_t::make(slm_idx_buf, off, 0);
2652 off += slm_idx_size;
2653
2654 auto store1 = store_t::make(slm_idx_buf, off, 1);
2655 off += slm_idx_size;
2656
2657 auto store2 = store_t::make(
2658 slm_idx_buf, off, int_imm_t::make(0, type_t::s32()));
2659
2660 auto slm_idx_init = store0.append(store1).append(store2);
2661
2662 auto slm_idx_load2 = slm_idx_load(0, 2);
2663 auto slm_idx_load4 = slm_idx_load(0, 4);
2664 auto slm_idx_store = store_t::make(slm_idx_buf, 0,
2665 slm_idx_load4 + shuffle_t::make_broadcast(1, 4));
2666
2667 // Update slm_idx.
2668 auto mask = (slm_idx_load2
2669 == shuffle_t::make_broadcast(cfg_.slm_bufs, 2));
2670 auto slm_idx_store_fix = store_t::make(slm_idx_buf, 0,
2671 shuffle_t::make_broadcast(int_imm_t::make(0, type_t::s32()), 2),
2672 store_t::default_stride, mask);
2673
2674 auto slm_idx_update = slm_idx_store.append(slm_idx_store_fix);
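        // For example, with cfg_.slm_bufs == 3 the (store, compute) index pair
        // cycles through (0, 1) -> (1, 2) -> (2, 0) -> (0, 1) -> ..., while the
        // third dword keeps counting completed iterations.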
2675
2676 loop = slm_idx_init.append(loop);
2677
2678 auto &g2s_store_orig = step_.g2s_store();
2679 auto &s2r_load = step_.s2r_load();
2680 auto &mul = step_.mul();
2681
2682 auto g2s_store = g2s_store_orig;
2683
2684 ir_assert(s2r_load.size() == mul.size());
2685
2686 stmt_t s2r_mul;
2687 for (int i = 0; i < int(mul.size()); i++) {
2688 s2r_mul = s2r_mul.append(s2r_load[i]);
2689 loop = substitute(loop, s2r_load[i], stmt_t(), 1);
2690 s2r_mul = s2r_mul.append(mul[i]);
2691 loop = substitute(loop, mul[i], stmt_t(), 1);
2692 }
2693
2694 loop = remove_synchronization(loop);
2695
2696 s2r_mul = sub_slm_bufs(s2r_mul, slm_idx_load(1, 1));
2697 g2s_store = sub_slm_bufs(g2s_store, slm_idx_load(0, 1));
2698 g2s_store = g2s_store.append(slm_idx_update);
2699
2700 auto s2r_mul_body = s2r_mul;
2701 auto s2r_mul_tail = s2r_mul;
2702 auto slm_counter = slm_idx_load(2, 1);
2703 auto cond = (slm_counter >= cfg_.slm_bufs - 1);
2704
2705 if (cfg_.slm_bufs == 2) {
2706 s2r_mul_body = if_t::make(cond, s2r_mul_body);
2707 g2s_store = g2s_store.append(funcs::barrier());
2708 } else {
2709             // In general we have to use an SLM fence before the signal to
2710             // flush all previous SLM stores. However, any SLM load acts as
2711             // an implicit SLM fence for all previous SLM stores, so no
2712             // explicit SLM fence is needed when an SLM load/multiplication
2713             // is performed before the signal.
2714 auto fence_signal = funcs::slm_fence().append(funcs::signal());
2715 s2r_mul_body = s2r_mul_body.append(funcs::signal());
2716 s2r_mul_body = if_t::make(cond, s2r_mul_body, fence_signal);
2717 s2r_mul_body = funcs::barrier_wait().append(s2r_mul_body);
2718 }
2719
2720 loop = substitute(
2721 loop, g2s_store_orig, s2r_mul_body.append(g2s_store), 1);
2722
2723 if (cfg_.slm_bufs == 3) {
2724 // Emit initial signal, to match wait-signal pairs in the loop.
2725 loop = funcs::signal().append(loop);
2726 }
2727
2728 // Complete the remaining iterations.
2729 int rem_iters = cfg_.slm_bufs - 1;
2730 int mul_start = std::max(0, rem_iters - loop_nest_.size());
2731 for (int i = 0; i < rem_iters; i++) {
2732 if (cfg_.slm_bufs == 3) loop = loop.append(funcs::barrier_wait());
2733 if (i >= mul_start) {
2734 // SLM load/multiplication works as implicit SLM fence.
2735 loop = loop.append(s2r_mul_tail);
2736 } else {
2737 loop = loop.append(funcs::slm_fence());
2738 }
2739 loop = loop.append(slm_idx_update);
2740 if (cfg_.slm_bufs == 3 && i + 1 < rem_iters)
2741 loop = loop.append(funcs::signal());
2742 }
2743
2744 if (cfg_.assign_sbids) loop = sbid_assigner_t().assign(loop);
2745
2746 const auto grf_size = ngen::GRF::bytes(hw_);
2747 loop = alloc_t::make(
2748 slm_idx_buf, grf_size, alloc_kind_t::grf, {}, loop);
2749
2750 alloc_updater_t alloc_updater;
2751
2752 auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
2753 ir_assert(slm_buffers.size() == 1);
2754 auto &slm_buf = slm_buffers[0];
2755 int non_ab_slm_size = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
2756 alloc_updater.resize(
2757 slm_buf, non_ab_slm_size + ab_slm_size_ * cfg_.slm_bufs);
2758
2759 auto ret = substitute(root_, step_.compute_loop(), loop, 1);
2760 ret = alloc_updater.update(ret);
2761 return ret;
2762 }
2763
2764     static stmt_t remove_synchronization(const stmt_t &s) {
2765 auto ret = s;
2766 for (auto &_c : find_objects<func_call_t>(s)) {
2767 auto &c = _c.as<func_call_t>();
2768 if (c.func.is_equal(funcs::signal_func())
2769 || c.func.is_equal(funcs::slm_fence_func())
2770 || c.func.is_equal(funcs::barrier_func())) {
2771 ret = substitute(ret, _c, stmt_t(), 1);
2772 }
2773 }
2774 return ret;
2775 }
2776
2777     stmt_t sub_slm_bufs(const stmt_t &stmt, const expr_t &slm_idx) const {
2778 auto stmt_vec = flatten_statements(stmt);
2779
2780 stmt_t ret = stmt;
2781 for (auto &s : stmt_vec) {
2782 if (!is_func_call<send_t>(s)) continue;
2783
2784 auto &send = s.as<func_call_t>().func.as<send_t>();
2785
2786             // This is not a send to SLM, skip.
2787 if (send.address_model != ngen_proxy::AddressModel::ModelSLM)
2788 continue;
2789
2790 auto new_args = s.as<func_call_t>().args;
2791 send_t::arg_mem_off(new_args) += ab_slm_size_ * slm_idx;
2792 auto new_send = send.call(new_args);
2793 ret = substitute(ret, s, new_send, 1);
2794 }
2795
2796 return ret;
2797 }
2798
2799 ngen::HW hw_;
2800 const conv_config_t &cfg_;
2801 ir_context_t &ir_ctx_;
2802 int ab_slm_size_;
2803
2804 stmt_t root_;
2805 alloc_manager_t alloc_mgr_;
2806 compute_step_t step_;
2807 compute_loop_nest_t loop_nest_;
2808 };
2809
2810 // Injects SLM buffering without unrolling based on the config.
2811 stmt_t inject_simple_slm_buffering(ngen::HW hw, const stmt_t &s,
2812 const conv_config_t &cfg, ir_context_t &ir_ctx, int ab_slm_size) {
2813 auto ret = simple_slm_buffering_injector_t(hw, s, cfg, ir_ctx, ab_slm_size)
2814 .inject();
2815 trace_pass("inject_simple_slm_buffering", ret);
2816 return ret;
2817 }
2818
2819 class unrolling_injector_t {
2820 public:
2821     unrolling_injector_t(const stmt_t &root, const conv_config_t &cfg,
2822 ir_context_t &ir_ctx, int ab_slm_size)
2823 : cfg_(cfg)
2824 , ir_ctx_(ir_ctx)
2825 , ab_slm_size_(ab_slm_size)
2826 , root_(root)
2827 , alloc_mgr_(root_)
2828 , step_(root)
2829 , loop_nest_(root, ir_ctx) {
2830 int inner_iters = loop_nest_.inner_loops_size();
2831 params_ = compute_params_t(cfg_.slm_bufs, cfg_.gmem_bufs, ab_slm_size,
2832 cfg_.prefetch_bufs, inner_iters);
2833 if (params_.use_slm) {
2834 for (auto &b :
2835 find_send_buffers(step_.g2s_load(), /*is_mem=*/false)) {
2836 g2s_reg_bufs_.emplace_back(b, alloc_mgr_.alloc_size(b));
2837 }
2838 }
2839 }
2840
2841     stmt_t inject() {
2842 compute_iterator_t it(params_, loop_nest_);
2843 stmt_t body;
2844
2845 sbid_manager_t sbid_mgr;
2846
2847 auto &outer_loop_info = loop_nest_.outer_loop_info();
2848
2849 auto append_outer_post_inc = [&](const stmt_t &_s) {
2850 auto &mul = outer_loop_info.mul_post_inc_stmt();
2851 auto &preload = outer_loop_info.preload_post_inc_stmt();
2852 auto s = _s;
2853 if (it.mul_loop_it.is_outer_loop_end() && it.do_mul()) {
2854 s = s.append(mul);
2855 }
2856 if (it.preload_loop_it.is_outer_loop_end() && it.do_preload()) {
2857 s = s.append(preload);
2858 }
2859 return s;
2860 };
2861
2862 // Ramp-up.
2863 for (int i = 0; i < it.ramp_up_iters; i++) {
2864 body = stmt_seq_t::make(body, create_iteration(it, sbid_mgr));
2865 body = append_outer_post_inc(body);
2866 ++it;
2867 }
2868
2869 // Body.
2870 if (it.body_iters > 0) {
2871 int extent = it.body_iters / it.unroll();
2872 bool has_loop = (extent > 1);
2873
2874 stmt_t loop_body;
2875 for (int i = 0; i < it.unroll(); i++) {
2876 loop_body = loop_body.append(create_iteration(
2877 it, sbid_mgr, /*in_loop_body=*/has_loop));
2878 ir_assert(it.do_mul());
2879 ir_assert(it.do_preload());
2880 loop_body = append_outer_post_inc(loop_body);
2881 ++it;
2882 }
2883 if (!has_loop) {
2884 body = body.append(loop_body);
2885 } else {
2886 ir_assert(extent > 0);
2887 auto for_var = ir_ctx_.create_tmp_var(type_t::s32(), "i");
2888 body = body.append(for_t::make(for_var, 0, extent, loop_body));
2889 }
2890 it.advance(it.body_iters - it.unroll());
2891 }
2892
2893 // Ramp-down.
2894 for (int i = 0; i < it.ramp_down_iters; i++) {
2895 ir_assert(it.do_mul());
2896 body = body.append(create_iteration(it, sbid_mgr));
2897 body = append_outer_post_inc(body);
2898 ++it;
2899 }
2900
2901 if (outer_loop_info.has_var_refs()) {
2902 body = outer_loop_info.init_stmt().append(body);
2903 body = outer_loop_info.inject_alloc_stmts(body);
2904 }
2905
2906 auto ret = substitute(root_, step_.compute_loop(), body, 1);
2907 if (params_.use_slm) {
2908 alloc_updater_t alloc_updater;
2909
2910 // Update buffer sizes.
2911 for (auto &b : g2s_reg_bufs_) {
2912 alloc_updater.resize(
2913 b.buf, alloc_mgr_.alloc_size(b.buf) * cfg_.gmem_bufs);
2914 }
2915
2916 auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
2917 if (!slm_buffers.empty()) {
2918 ir_assert(slm_buffers.size() == 1);
2919
2920 auto &slm_buf = slm_buffers[0];
2921 int non_ab_slm_size
2922 = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
2923 alloc_updater.resize(slm_buf,
2924 non_ab_slm_size + ab_slm_size_ * cfg_.slm_bufs);
2925 }
2926
2927 ret = alloc_updater.update(ret);
2928 }
2929
2930 // Remove zero-out statement for C (handled by sub_fma_acc_with_zero).
2931 ret = substitute(ret, step_.c_zero_out(), stmt_t(), 1);
2932
2933 return ret;
2934 }
2935
2936 private:
2937 struct buffer_info_t {
2938         buffer_info_t(const expr_t &buf, int size) : buf(buf), size(size) {}
2939
2940 expr_t buf;
2941 int size;
2942 };
2943
2944     stmt_t create_iteration(const compute_iterator_t &it,
2945 sbid_manager_t &sbid_mgr, bool in_loop_body = false) const {
2946 auto g2s_load = step_.g2s_load();
2947 auto g2s_store = step_.g2s_store();
2948 auto prefetch = step_.prefetch();
2949 auto g2r_load = step_.g2r_load();
2950 auto s2r_load = step_.s2r_load();
2951 auto mul = step_.mul();
2952 auto lets = step_.inner_let_stmts();
2953 auto &outer_loop_info = loop_nest_.outer_loop_info();
2954
2955 loop_nest_.for_each_loop_var([&](const expr_t &v) {
2956 g2s_load = const_fold(substitute(
2957 g2s_load, v, expr_t(it.preload_loop_it.var_value(v))));
2958 g2s_store = const_fold(substitute(
2959 g2s_store, v, expr_t(it.preload_loop_it.var_value(v))));
2960 prefetch = const_fold(substitute(
2961 prefetch, v, expr_t(it.preload_loop_it.var_value(v))));
2962 expr_t mul_var_value;
2963 expr_t preload_var_value;
2964 if (v.is_same(outer_loop_info.var) && in_loop_body
2965 && outer_loop_info.has_var_refs()) {
2966 mul_var_value = outer_loop_info.mul_var_load();
2967 preload_var_value = outer_loop_info.preload_var_load();
2968 } else {
2969 mul_var_value = it.mul_loop_it.var_value(v);
2970 preload_var_value = it.preload_loop_it.var_value(v);
2971 }
2972 for (auto &s : g2r_load) {
2973 s = const_fold(substitute(s, v, mul_var_value));
2974 }
2975 for (auto &s : s2r_load) {
2976 s = const_fold(substitute(s, v, preload_var_value));
2977 }
2978 for (int i = 0; i < int(lets.size()); i++) {
2979 auto &let = lets[i];
2980 auto &orig_let = step_.inner_let_stmts()[i];
2981 expr_t var_value;
2982 bool is_preload_let = step_.is_preload_let(orig_let);
2983 bool is_mul_let = step_.is_mul_let(orig_let);
2984 if (is_preload_let && !is_mul_let) {
2985 var_value = preload_var_value;
2986 } else if (is_mul_let && !is_preload_let) {
2987 var_value = mul_var_value;
2988 } else {
2989 ir_assert(count_object(let.as<let_t>().value, v) == 0)
2990 << "Unexpected reference to variable " << v
2991 << " from " << let;
2992 continue;
2993 }
2994 let = const_fold(substitute(let, v, var_value));
2995 }
2996 });
2997
2998 if (params_.use_slm) {
2999 g2s_load = sub_gmem_bufs(g2s_load, it, /*is_read=*/false);
3000 g2s_store = sub_gmem_bufs(g2s_store, it, /*is_read=*/true);
3001
3002 g2s_store = sub_slm_bufs(g2s_store, it, /*is_read=*/false);
3003 for (auto &s : s2r_load) {
3004 s = sub_slm_bufs(s, it, /*is_read=*/true);
3005 }
3006 }
3007
3008 if (it.is_first_mul()) {
3009 for (auto &m : mul) {
3010 m = sub_fma_acc_with_zero(m);
3011 }
3012 }
3013
3014 stmt_t iter_stmt;
3015 if (it.slm_bufs() == 3 && it.do_mul()) {
3016 iter_stmt = iter_stmt.append(funcs::barrier_wait());
3017 }
3018
3019 if (it.do_g2s_load()) iter_stmt = iter_stmt.append(g2s_load);
3020
3021 if (it.slm_bufs() == 3 && it.iter == it.gmem_bufs()) {
3022 iter_stmt = iter_stmt.append(funcs::slm_fence());
3023 iter_stmt = iter_stmt.append(funcs::signal());
3024 }
3025
3026 if (it.do_s2r_load() && it.slm_bufs() == 1) {
3027 iter_stmt = iter_stmt.append(funcs::barrier());
3028 iter_stmt = iter_stmt.append(g2s_store);
3029 iter_stmt = iter_stmt.append(funcs::barrier());
3030 }
3031
3032 if (it.do_prefetch()) iter_stmt = iter_stmt.append(prefetch);
3033
3034 if (it.do_mul()) {
3035 for (size_t i = 0; i < mul.size(); i++) {
3036 iter_stmt = iter_stmt.append(g2r_load[i]);
3037 iter_stmt = iter_stmt.append(s2r_load[i]);
3038 iter_stmt = iter_stmt.append(mul[i]);
3039 }
3040 if (it.slm_bufs() == 3 && !it.is_last_mul()) {
3041 iter_stmt = iter_stmt.append(funcs::signal());
3042 }
3043 }
3044 if (it.do_s2r_load() && it.slm_bufs() >= 2) {
3045 iter_stmt = iter_stmt.append(g2s_store);
3046 if (it.slm_bufs() == 2) {
3047 iter_stmt = iter_stmt.append(funcs::barrier());
3048 }
3049 }
3050
3051 if (cfg_.assign_sbids)
3052 iter_stmt = sbid_assigner_t(sbid_mgr).assign(iter_stmt);
3053
3054 iter_stmt = inject_local_let(iter_stmt, lets, it.linear_id);
3055
3056 return iter_stmt;
3057 }
3058
3059     stmt_t sub_gmem_bufs(const stmt_t &stmt, const compute_iterator_t &it,
3060 bool is_read) const {
3061 if (it.slm_bufs() == 0) return stmt;
3062 if (is_read && !it.do_s2r_load()) return stmt;
3063 if (!is_read && !it.do_g2s_load()) return stmt;
3064
3065 int buf_idx = (is_read ? it.gmem_read_buf_index()
3066 : it.gmem_write_buf_index());
3067 if (buf_idx == 0) return stmt;
3068
3069 auto ret = stmt;
3070 for (auto &b : g2s_reg_bufs_) {
3071 ret = substitute(ret, b.buf, b.buf[buf_idx * b.size]);
3072 }
3073 return ret;
3074 }
3075
3076     stmt_t sub_slm_bufs(const stmt_t &stmt, const compute_iterator_t &it,
3077 bool is_read) const {
3078 if (it.slm_bufs() <= 1) return stmt;
3079 if (is_read && !it.do_mul()) return stmt;
3080 if (!is_read && !it.do_s2r_load()) return stmt;
3081
3082 int upd = (is_read ? it.slm_read_offset_update()
3083 : it.slm_write_offset_update());
3084
3085 auto stmt_vec = flatten_statements(stmt);
3086
3087 stmt_t ret = stmt;
3088 for (auto &s : stmt_vec) {
3089 auto *call = s.as_ptr<func_call_t>();
3090 if (!call) continue;
3091 auto *func = call->func.as_ptr<send_t>();
3092 if (!func) continue;
3093
3094 auto &send = call->func.as<send_t>();
3095 auto &args = call->args;
3096 auto &mem_buf = send_t::arg_mem_buf(args);
3097 auto &header_buf = send_t::arg_mem_off(args);
3098
3099             // This is not a send to SLM, skip.
3100 if (send.address_model != ngen_proxy::AddressModel::ModelSLM)
3101 continue;
3102
3103 // May have signed offset.
3104 auto store_obj = send.create_offset_store(
3105 header_buf, mem_buf, upd, /*is_signed_offset=*/true);
3106 auto &store = store_obj.as<store_t>();
3107 expr_t old_value
3108 = load_t::make(send.address_type(), store.buf, store.off);
3109 auto post_inc_store = store_t::make(
3110 store.buf, store.off, old_value + store.value);
3111 ret = substitute(ret, s, stmt_seq_t::make(s, post_inc_store), 1);
3112 }
3113
3114 return ret;
3115 }
3116
3117     static stmt_t sub_fma_acc_with_zero(const stmt_t &stmt) {
3118 auto stmt_vec = flatten_statements(stmt);
3119
3120 object_eq_set_t<expr_t> seen_dst;
3121 stmt_t ret = stmt;
3122 for (auto &s : stmt_vec) {
3123 if (is_func_call<dpas_t>(s)) {
3124 auto &call = s.as<func_call_t>();
3125
3126 auto &dst = dpas_t::arg_dst(s);
3127 auto src0 = expr_t(0); // Will be translated to null register.
3128 auto &src1 = dpas_t::arg_src1(s);
3129 auto &src2 = dpas_t::arg_src2(s);
3130
3131 auto new_call = func_call_t::make(
3132 call.func, {dst, src0, src1, src2}, call.attr);
3133 ret = substitute(ret, s, new_call, 1);
3134 } else if (is_func_call<mad_t>(s)) {
3135 auto &call = s.as<func_call_t>();
3136
3137 auto &dst = mad_t::arg_dst(s);
3138 auto src0 = expr_t(0); // Will be translated to null register.
3139 auto &src1 = mad_t::arg_src1(s);
3140 auto &src2 = mad_t::arg_src2(s);
3141
3142 if (!seen_dst.insert(dst).second) continue;
3143
3144 auto new_call = func_call_t::make(
3145 call.func, {dst, src0, src1, src2}, call.attr);
3146 ret = substitute(ret, s, new_call, 1);
3147 }
3148 }
3149 return ret;
3150 }
3151
3152 // Returns memory buffers if is_mem is true and register buffers otherwise.
3153     static object_set_t<expr_t> find_send_buffers(
3154 const stmt_t &s, bool is_mem) {
3155 object_set_t<expr_t> ret;
3156 auto calls = find_objects<func_call_t>(s);
3157 for (auto &_c : calls) {
3158 auto &c = _c.as<func_call_t>();
3159 if (!c.func.is<send_t>()) continue;
3160 auto &buf = (is_mem ? send_t::arg_mem_buf(_c)
3161 : send_t::arg_reg_buf(_c));
3162 ret.insert(buf.as<ptr_t>().base);
3163 }
3164 return ret;
3165 }
3166
3167     static stmt_t inject_local_let(const stmt_t &_s,
3168 const std::vector<stmt_t> &enclosed_lets, int id) {
3169 auto s = _s;
3170
3171 // Inject let statements from the innermost loop.
3172 for (auto &_let : enclosed_lets) {
3173 auto &let = _let.as<let_t>();
3174 s = let_t::make(let.var, let.value, s);
3175 }
3176
3177 // Substitute variables to avoid clashing.
3178 auto lets = find_objects<let_t>(s);
3179 for (auto &_let : lets) {
3180 auto &let = _let.as<let_t>();
3181 auto &var = let.var.as<var_t>();
3182 auto local_var = var_t::make(
3183 var.type, var.name + "_" + std::to_string(id));
3184 s = substitute(s, let.var, local_var);
3185 }
3186 return s;
3187 }
3188
3189 const conv_config_t &cfg_;
3190 ir_context_t &ir_ctx_;
3191 int ab_slm_size_;
3192
3193 stmt_t root_;
3194 alloc_manager_t alloc_mgr_;
3195 compute_step_t step_;
3196 compute_loop_nest_t loop_nest_;
3197 compute_params_t params_;
3198
3199 std::vector<buffer_info_t> g2s_reg_bufs_; // For SLM buffering.
3200 };
3201
3202 // Injects loop unrolling based on the config. Possible options:
3203 // - Without preload (no SLM buffering, no prefetch)
3204 // - With SLM buffering
3205 // - With prefetch
3206 stmt_t inject_unrolling(const stmt_t &s, const conv_config_t &cfg,
3207 ir_context_t &ir_ctx, int ab_slm_size) {
3208 auto ret = unrolling_injector_t(s, cfg, ir_ctx, ab_slm_size).inject();
3209 trace_pass("inject_unrolling", ret);
3210 return ret;
3211 }
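// Which of the three modes is used is decided by compute_params_t above: a
// non-zero SLM buffer size selects SLM buffering, a non-zero prefetch_bufs
// selects prefetch, and with neither the loop nest is still unrolled by its
// inner iteration count, just without any preload.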
3212
3213 class store_splitter_t : public ir_mutator_t {
3214 public:
3215     store_splitter_t(ngen::HW hw) : hw_(hw) {}
3216 
3217     object_t _mutate(const store_t &obj) override {
3218 int elems = obj.value.type().elems();
3219 int elem_size = obj.value.type().scalar().size();
3220 int stride = (obj.has_default_stride() ? 1 : obj.stride / elem_size);
3221 int store_size = elem_size * stride * elems;
3222 const auto grf_size = ngen::GRF::bytes(hw_);
3223 if (store_size <= 2 * grf_size) return ir_mutator_t::_mutate(obj);
3224
3225 int step = 2 * grf_size / (stride * elem_size);
3226 stmt_t new_stmt;
3227 for (int i = 0; i < elems; i += step) {
3228 int cur_elems = std::min(step, elems - i);
3229 ir_assert(math::is_pow2(cur_elems));
3230 int off = i * stride * elem_size;
3231 auto store = store_t::make(obj.buf, obj.off + off,
3232 split_expr(obj.value, i, i + cur_elems), obj.stride);
3233 new_stmt = new_stmt.append(store);
3234 }
3235 return std::move(new_stmt);
3236 }
3237
3238 private:
3239     static expr_t split_expr(const expr_t &e, int beg, int end) {
3240 auto *shuffle = e.as_ptr<shuffle_t>();
3241 if (shuffle) return shuffle_t::make(shuffle, beg, end);
3242
3243 auto *binary = e.as_ptr<binary_op_t>();
3244 if (binary) {
3245 auto a = split_expr(binary->a, beg, end);
3246 auto b = split_expr(binary->b, beg, end);
3247 return binary_op_t::make(binary->op_kind, a, b);
3248 }
3249 ir_error_not_expected();
3250 return expr_t();
3251 }
3252
3253 ngen::HW hw_;
3254 };
3255
3256 // Splits wide GRF stores that are otherwise unsupported in HW.
3257 stmt_t split_wide_stores(ngen::HW hw, const stmt_t &s) {
3258 auto ret = store_splitter_t(hw).mutate(s);
3259 trace_pass("split_wide_stores", ret);
3260 return ret;
3261 }
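// A worked example (assuming a 32-byte GRF): a dense store of 32 f32 values
// (128 bytes) exceeds two GRFs, so step = 2 * 32 / 4 = 16 and the store is
// split into two 16-element stores at byte offsets 0 and 64.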
3262
3263 class peephole_optimizer_t : public ir_mutator_t {
3264 public:
3265     object_t _mutate(const binary_op_t &obj) override {
3266 auto old_obj = ir_mutator_t::_mutate(obj);
3267 auto new_obj
3268 = simplify_rewrite_with_ternary(old_obj, /*recursive=*/false);
3269 auto *ternary = new_obj.as_ptr<ternary_op_t>();
3270 if (!ternary) return std::move(new_obj);
3271
3272 switch (ternary->op_kind) {
3273 case op_kind_t::_add3: {
3274 bool ok = true;
3275 // Allowed form: add3(dword/word, dword/word, dword/word).
3276 ok &= add3_type_ok(ternary->a);
3277 ok &= add3_type_ok(ternary->b);
3278 ok &= add3_type_ok(ternary->c);
3279 ok &= !is_const(ternary->a);
3280 ok &= !is_const(ternary->b);
3281 if (!ok) new_obj = old_obj;
3282 break;
3283 }
3284 case op_kind_t::_mad: {
3285 auto a_type = real_type(ternary->a);
3286 auto b_type = real_type(ternary->b);
3287 auto c_type = real_type(ternary->c);
3288 bool ok = true;
3289 // Allowed form: mad(dword, dword, word).
3290 ok &= utils::one_of(a_type, type_t::s32(), type_t::u32());
3291 ok &= utils::one_of(b_type, type_t::s32(), type_t::u32());
3292 ok &= utils::one_of(c_type, type_t::s16(), type_t::u16());
3293 if (!ok) new_obj = old_obj;
3294 break;
3295 }
3296 default: ir_error_not_expected();
3297 }
3298 return std::move(new_obj);
3299 }
3300
3301 private:
3302 static type_t real_type(const expr_t &e) {
3303 auto *imm = e.as_ptr<int_imm_t>();
3304 if (!imm) return e.type();
3305 if (int_imm_t::try_shrink_type<int16_t>(imm->value))
3306 return type_t::s16();
3307 if (int_imm_t::try_shrink_type<int32_t>(imm->value))
3308 return type_t::s32();
3309 return type_t::s64();
3310 }
3311
3312 static bool add3_type_ok(const expr_t &e) {
3313 auto t = real_type(e);
3314 if (!t.is_scalar()) return false;
3315 switch (t.kind()) {
3316 case type_kind_t::s32:
3317 case type_kind_t::u32: return !is_const(e);
3318 case type_kind_t::s16:
3319 case type_kind_t::u16: return true;
3320 default: return false;
3321 }
3322 }
3323 };
3324
3325 stmt_t optimize_peephole(const stmt_t &s) {
3326 auto ret = peephole_optimizer_t().mutate(s);
3327 trace_pass("optimize_peephole", ret);
3328 return ret;
3329 }
3330
3331 class if_condition_fixer_t : public ir_mutator_t {
3332 public:
3333 if_condition_fixer_t(int simd_size) : simd_size_(simd_size) {}
3334
3335 object_t _mutate(const if_t &obj) override {
3336 auto _new_obj = ir_mutator_t::_mutate(obj);
3337 auto &new_obj = _new_obj.as<if_t>();
3338 auto cond = shuffle_t::make_broadcast(new_obj.cond, simd_size_);
3339 return if_t::make(cond, new_obj.body, new_obj.else_body);
3340 }
3341
3342 private:
3343 int simd_size_;
3344 };
3345
3346 // Injects broadcasts for scalar if conditions. Example:
3347 // Before:
3348 // if (cond) { ... }
3349 // After (for SIMD8):
3350 // if (bcast8(cond)) { ... }
3351 stmt_t fixup_if_conditions(const stmt_t &s, const conv_config_t &cfg) {
3352 auto ret = if_condition_fixer_t(cfg.simd_size).mutate(s);
3353 trace_pass("fixup_if_conditions", ret);
3354 return ret;
3355 }
3356
3357 class loop_unroller_t : public ir_mutator_t {
3358 public:
3359 loop_unroller_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
3360
3361 object_t _mutate(const for_t &obj) override {
3362 auto new_obj = ir_mutator_t::_mutate(obj);
3363 auto &_for = new_obj.as<for_t>();
3364 // No unrolling.
3365 if (_for.unroll == 1) return new_obj;
3366
3367 ir_assert(is_const(obj.init))
3368 << "Can't unroll loop with non-const bound: " << obj.init;
3369 ir_assert(is_const(obj.bound))
3370 << "Can't unroll loop with non-const bound: " << obj.bound;
3371
3372 auto init = to_cpp<int>(obj.init);
3373 auto bound = to_cpp<int>(obj.bound);
3374
3375 ir_assert(_for.unroll == (bound - init))
3376 << "Only full loop unroll is supported.";
3377
3378 stmt_t ret;
3379 for (int i = init; i < bound; i++) {
3380 auto iter_stmt
3381 = substitute(obj.body, obj.var, to_expr(i, obj.var.type()));
3382 iter_stmt = rename_let_alloc(iter_stmt, i - init);
3383 ret = ret.append(iter_stmt);
3384 }
3385 return std::move(ret);
3386 }
3387
3388 private:
3389 stmt_t rename_let_alloc(const stmt_t &s, int idx) {
3390 auto lets = find_objects<let_t>(s);
3391 auto ret = s;
3392 for (auto &_let : lets) {
3393 auto &let = _let.as<let_t>();
3394 auto &var = let.var.as<var_t>();
3395 auto new_var = ir_ctx_.create_tmp_var(var.type, var.name);
3396 ret = substitute(ret, let.var, new_var);
3397 }
3398 auto allocs = find_objects<alloc_t>(s);
3399 for (auto &_alloc : allocs) {
3400 auto &alloc = _alloc.as<alloc_t>();
3401 auto &buf = alloc.buf.as<var_t>();
3402 auto new_buf = ir_ctx_.create_tmp_var(buf.type, buf.name);
3403 ret = substitute(ret, alloc.buf, new_buf);
3404 }
3405 return ret;
3406 }
3407
3408 ir_context_t &ir_ctx_;
3409 };
3410
3411 // Unrolls loops according to their unroll attribute.
3412 // Before:
3413 // for (int i = 0; i < 2; i++) [unroll: 2] {
3414 // body(i);
3415 // }
3416 // After:
3417 // body(0);
3418 // body(1);
3419 stmt_t unroll_loops(const stmt_t &s, ir_context_t &ir_ctx) {
3420 auto ret = loop_unroller_t(ir_ctx).mutate(s);
3421 trace_pass("unroll_loops", ret);
3422 return ret;
3423 }
3424
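// Creates a statement that reorders data from src_buf (described by the src
// layout) to dst_buf (described by the dst layout). A hypothetical usage
// sketch with illustrative names:
//     auto stmt = create_reorder_stmt(reg_layout, f32_layout, reg_buf, f32_buf);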
3425 stmt_t create_reorder_stmt(const layout_t &src, const layout_t &dst,
3426 const expr_t &src_buf, const expr_t &dst_buf) {
3427 ir_assert(src.ndims() == dst.ndims()) << "Layouts are incompatible.";
3428 ir_assert(src.elems() == dst.elems()) << "Layouts are incompatible.";
3429 auto func = reorder_t::make(src, dst);
3430 return func.call({dst_buf, src_buf});
3431 }
3432
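// Creates a statement that reduces src into dst. Bits set in reduction_mask
// select the src dimensions that are kept in dst; the remaining dimensions
// are reduced. For example, B reduction below passes (1 << 1) for a KxN
// tensor to keep N and reduce across K.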
3433 stmt_t create_reduce_stmt(const layout_t &src, const layout_t &dst,
3434 const expr_t &src_buf, const expr_t &dst_buf, const tensor_t &_sub_tile,
3435 uint32_t reduction_mask) {
3436 auto sub_tile = _sub_tile;
3437 if (sub_tile.is_empty()) sub_tile = tensor_t(src.dims());
3438 ir_assert(src.ndims() == sub_tile.ndims());
3439 int ndims = src.ndims();
3440
3441 // Align dst layout with src layout according to the mask.
3442 std::vector<int> dst2src(dst.ndims());
3443 int dst_dim_idx = 0;
3444 for (int i = 0; i < ndims; i++) {
3445 if ((reduction_mask & (1 << i)) != 0) {
3446 dst2src[dst_dim_idx] = i;
3447 dst_dim_idx++;
3448 }
3449 }
3450 ir_assert(dst_dim_idx == dst.ndims()) << "Incompatible reduction mask.";
3451
3452 auto dst_blocks = dst.blocks();
3453 for (auto &b : dst_blocks)
3454 b.dim_idx = dst2src[b.dim_idx];
3455
3456 // Create final layouts.
3457 auto dst_aligned = layout_t(dst.type(), ndims, dst.offset(), dst_blocks);
3458
3459 std::vector<dim_t> dst_tile_dims = sub_tile.dims();
3460 std::vector<expr_t> dst_tile_start = sub_tile.start();
3461 for (int i = 0; i < ndims; i++) {
3462 if ((reduction_mask & (1 << i)) == 0) {
3463 dst_tile_dims[i] = 1;
3464 dst_tile_start[i] = expr_t(0);
3465 continue;
3466 }
3467 }
3468 dst_aligned = dst_aligned.map(tensor_t(dst_tile_dims, dst_tile_start));
3469
3470 auto func = reduce_t::make(src, dst_aligned);
3471 return func.call({dst_buf, src_buf});
3472 }
3473
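// Generates stores that fill the first `size` bytes of a register buffer
// with zeros, writing up to two registers per store.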
3474 stmt_t create_zero_out_stmt(ngen::HW hw, const expr_t &buf, int size) {
3475 stmt_t ret;
3476 int step_bytes = 2 * ngen::GRF::bytes(hw);
3477 for (int i = 0; i < size; i += step_bytes) {
3478 int cur_step_bytes = std::min(step_bytes, size - i);
3479 ret = ret.append(store_t::make(buf, i,
3480 shuffle_t::make_broadcast(
3481 expr_t(0.0f), cur_step_bytes / sizeof(float))));
3482 }
3483 return ret;
3484 }
3485
3486 // Generates loads or stores to move data between memory (global or SLM) and
3487 // GRF. The memory layout is provided by the caller; the GRF layout is
3488 // deduced automatically, according to the decomposition into messages.
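// A hypothetical usage sketch (buffer and view names are illustrative):
//     read_builder_t read(hw, ir_ctx, cset, mem_view, mem_buf, reg_buf,
//             /*is_slm=*/false);
//     stmt_t load_stmt = read.stmt();            // Generated load.
//     layout_t grf_layout = read.reg_layout();   // Deduced GRF layout.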
3489 class access_builder_t {
3490 public:
3491 access_builder_t() = default;
3492
3493 access_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3494 const constraint_set_t &cset, const view_t &mem_view,
3495 const expr_t &mem_buf, const expr_t ®_buf, bool is_slm,
3496 bool is_prefetch, bool is_load, ngen_proxy::AtomicOp atomic_op)
3497 : hw_(hw)
3498 , ir_ctx_(&ir_ctx)
3499 , cset_(&cset)
3500 , mem_view_(mem_view)
3501 , mem_buf_(mem_buf)
3502 , reg_buf_(reg_buf)
3503 , is_slm_(is_slm)
3504 , is_prefetch_(is_prefetch)
3505 , is_load_(is_load)
3506 , atomic_op_(atomic_op) {
3507 build();
3508 }
3509
3510 bool is_slm() const { return is_slm_; }
3511
3512 bool is_prefetch() const { return is_prefetch_; }
3513
3514 const layout_t &reg_layout() const { return reg_layout_; }
3515
3516 int reg_buf_size() const { return reg_buf_size_; }
3517
3518 const stmt_t &stmt() const { return stmt_; }
3519
3520 std::string str() const {
3521 const auto grf_size = ngen::GRF::bytes(hw_);
3522 std::ostringstream oss;
3523 oss << "Memory view: " << mem_view_ << std::endl;
3524 oss << "Register layout: " << reg_layout_ << std::endl;
3525 oss << "Register buffer: " << reg_buf_ << std::endl;
3526 oss << "Register buffer size: " << reg_buf_size_ << " ("
3527 << reg_buf_size_ / grf_size << " regs)" << std::endl;
3528 oss << "Statement: " << std::endl << stmt_;
3529 return oss.str();
3530 }
3531
3532 private:
3533 void build() {
3534 auto send_list = get_send_list(mem_view_.type());
3535
3536 auto mask_tensor = mem_view_.create_mask_tensor(*cset_);
3537
3538 // Find the first send candidate matching the layout.
3539 func_t _send;
3540 tensor_t send_tensor;
3541 for (auto &_s_base : send_list) {
3542 auto &s_base = _s_base.as<send_t>();
3543 int type_size = mem_view_.type().size();
3544 int block_bytes_base = s_base.block_size();
3545 if (block_bytes_base % type_size != 0) continue;
3546 int elems_per_block_base = block_bytes_base / type_size;
3547
3548 dim_t elems_per_block = elems_per_block_base;
3549 dim_t slots = s_base.slots;
3550
3551 // Check if the view can be decomposed for this send.
3552 auto tensor
3553 = mem_view_.split_into_dense_tile(elems_per_block, slots);
3554 if (tensor.is_empty()) continue;
3555
3556 auto _s = s_base.adjust(
3557 int(elems_per_block * type_size), int(slots));
3558 if (_s.is_empty()) continue;
3559 auto &s = _s.as<send_t>();
3560
3561 // Check if this send supports the required mask.
3562 if (!has_compatible_mask(s, mem_view_, tensor, mask_tensor))
3563 continue;
3564
3565 // TODO: Check alignment requirements.
3566
3567 // Success, send is found, stop iterating.
3568 _send = _s;
3569 send_tensor = tensor;
3570 break;
3571 }
3572 // Support for prefetch messages is limited. If a message is not found,
3573 // skip prefetch generation.
3574 if (_send.is_empty() && is_prefetch()) return;
3575 ir_assert(!_send.is_empty()) << "Can't decompose view into messages.";
3576
3577 auto &send = _send.as<send_t>();
3578 reg_layout_ = create_register_layout_for_message(
3579 send, mem_view_, reg_buf_size_);
3580
3581 mem_view_.for_each_tile(
3582 send_tensor, [&](const std::vector<dim_t> &start) {
3583 auto tile = tensor_t(send_tensor.dims(), start);
3584 auto sub_view = mem_view_.create_sub_view(tile);
3585 auto sub_mask_tensor = mask_tensor.map(tile);
3586 auto reg_sub_buf = (is_prefetch()
3587 ? expr_t()
3588 : reg_buf_[reg_layout_(start)
3589 * reg_layout_.type().size()]);
3590 stmt_ = stmt_seq_t::make(stmt_,
3591 create_send_stmt(*ir_ctx_, send, mem_buf_,
3592 reg_sub_buf, sub_view, sub_mask_tensor));
3593 });
3594 }
3595
3596 // Returns a list of send functions that can be used for the access.
3597 std::vector<func_t> get_send_list(const type_t &data_type) const {
3598 using namespace ngen_proxy;
3599 bool is_atomic = (atomic_op_ != AtomicOp::undef);
3600 Access access_type = (is_load_ ? Access::Read : Access::Write);
3601 // TODO: use stateless access on XeHPC until driver fix
3602 bool use_stateful_msgs = is_atomic && hw_ < ngen::HW::XeHPC;
3603 AddressModel address_model
3604 = (is_slm() ? AddressModel::ModelSLM
3605 : use_stateful_msgs ? AddressModel::ModelBTS
3606 : AddressModel::ModelA64);
3607 auto send_list = send_t::get_all(hw_, data_type, access_type,
3608 address_model, atomic_op_, is_prefetch_);
3609 return send_list;
3610 }
3611
3612 ngen::HW hw_;
3613 ir_context_t *ir_ctx_;
3614 const constraint_set_t *cset_;
3615
3616 view_t mem_view_;
3617 expr_t mem_buf_;
3618 layout_t reg_layout_;
3619 expr_t reg_buf_;
3620 int reg_buf_size_;
3621 bool is_slm_;
3622 bool is_prefetch_;
3623 bool is_load_;
3624 stmt_t stmt_;
3625 ngen_proxy::AtomicOp atomic_op_;
3626 };
3627
3628 class read_builder_t : public access_builder_t {
3629 public:
3630 read_builder_t() = default;
3631
3632 read_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3633 const constraint_set_t &cset, const view_t &view,
3634 const expr_t &mem_buf, const expr_t ®_buf, bool is_slm,
3635 bool is_prefetch = false)
3636 : access_builder_t(hw, ir_ctx, cset, view, mem_buf, reg_buf, is_slm,
3637 is_prefetch, /*is_load=*/true, ngen_proxy::AtomicOp::undef) {}
3638 };
3639
3640 class write_builder_t : public access_builder_t {
3641 public:
3642 write_builder_t() = default;
3643
3644 write_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3645 const constraint_set_t &cset, const view_t &view,
3646 const expr_t &mem_buf, const expr_t ®_buf, bool is_slm,
3647 ngen_proxy::AtomicOp atomic_op = ngen_proxy::AtomicOp::undef)
3648 : access_builder_t(hw, ir_ctx, cset, view, mem_buf, reg_buf, is_slm,
3649 /*is_prefetch=*/false, /*is_load=*/false, atomic_op) {}
3650 };
3651
3652 // Generates loads to the post-op buffer and applies a single post-op.
3653 // There are two types of post-ops:
3654 // - Eltwise: lhs = F(lhs)
3655 // - Binary: lhs = F(lhs, rhs)
3656 // Binary requires an rhs load, which may be either:
3657 // - Pre-loaded and used for all updates (preferred approach)
3658 // - Loaded for every tile
3659 // Right-hand side tensor supports implicit broadcasting: value is broadcasted
3660 // across a size one dimension.
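// For example, a binary add with a 1xCx1x1 rhs tensor adds one value per
// channel and broadcasts it across the remaining lhs dimensions.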
3661 class post_op_builder_t {
3662 public:
3663 post_op_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3664 const constraint_set_t &cset, const post_op_t &post_op,
3665 int &available_pre_load_size)
3666 : hw_(hw), ir_ctx_(ir_ctx), cset_(cset), post_op_(post_op) {
3667 if (!post_op_.needs_load()) return;
3668
3669 // Estimate the buffer size required to load the full rhs; skip
3670 // pre-loading if it requires too much GRF memory.
3671 int estimated_rhs_bytes = 0;
3672
3673 estimated_rhs_bytes
3674 += int(post_op.rhs_view().create_dense_vlayout().size());
3675
3676 if (needs_rhs_convert()) {
3677 estimated_rhs_bytes += int(post_op.rhs_view()
3678 .create_dense_vlayout()
3679 .retype(type_t::f32())
3680 .size());
3681 }
3682
3683 if (estimated_rhs_bytes <= available_pre_load_size) {
3684 available_pre_load_size -= estimated_rhs_bytes;
3685 do_preload_ = true;
3686 }
3687 }
3688
3689 // Pre-loads rhs data for the whole update.
3690 stmt_t build_pre_load() {
3691 if (!do_preload_) return stmt_t();
3692
3693 auto rhs_load_reg_buf = make_tmp_rhs_buffer();
3694 read_builder_t read(hw_, ir_ctx_, cset_, post_op_.rhs_view(),
3695 post_op_.rhs_buf(), rhs_load_reg_buf, /*is_slm=*/false);
3696 pre_load_rhs_reg_buf_ = rhs_load_reg_buf;
3697 pre_load_rhs_reg_layout_ = read.reg_layout();
3698 if (!needs_rhs_convert()) rhs_reg_buf_ = rhs_load_reg_buf;
3699 update_rhs_buf_size(rhs_load_reg_buf, read.reg_buf_size());
3700 return read.stmt();
3701 }
3702
3703 // Converts the pre-loaded rhs data to f32.
3704 stmt_t build_pre_convert() {
3705 if (!do_preload_ || !needs_rhs_convert()) return stmt_t();
3706
3707 auto rhs_f32_reg_buf = make_tmp_rhs_buffer();
3708 auto f32_layout
3709 = pre_load_rhs_reg_layout_.make_dense().retype(type_t::f32());
3710 update_rhs_buf_size(rhs_f32_reg_buf, int(f32_layout.size()));
3711
3712 // Reorder to f32.
3713 auto ret = create_reorder_stmt(pre_load_rhs_reg_layout_, f32_layout,
3714 pre_load_rhs_reg_buf_, rhs_f32_reg_buf);
3715
3716 // Now rhs is converted to f32.
3717 pre_load_rhs_reg_layout_ = f32_layout;
3718 rhs_reg_buf_ = rhs_f32_reg_buf;
3719
3720 return ret;
3721 }
3722
3723 // Loads rhs data for one tile.
3724 stmt_t build_tile_load(const tensor_t &tile) {
3725 if (!post_op_.needs_load()) return stmt_t();
3726
3727 stmt_t stmt;
3728 auto rhs_tile = post_op_.apply_mask(tile);
3729 if (post_op_.needs_load() && !do_preload_) {
3730 // Load and convert now.
3731 auto po = post_op_.create_sub_post_op(rhs_tile);
3732 auto rhs_load_reg_buf = make_tmp_rhs_buffer();
3733 read_builder_t read(hw_, ir_ctx_, cset_, po.rhs_view(),
3734 po.rhs_buf(), rhs_load_reg_buf,
3735 /*is_slm=*/false);
3736 stmt = stmt.append(read.stmt());
3737
3738 update_rhs_buf_size(rhs_load_reg_buf, read.reg_buf_size());
3739
3740 if (needs_rhs_convert()) {
3741 auto rhs_f32_reg_buf = make_tmp_rhs_buffer();
3742 auto f32_layout
3743 = read.reg_layout().make_dense().retype(type_t::f32());
3744 update_rhs_buf_size(rhs_f32_reg_buf, int(f32_layout.size()));
3745 // Reorder to f32.
3746 stmt = stmt.append(create_reorder_stmt(read.reg_layout(),
3747 f32_layout, rhs_load_reg_buf, rhs_f32_reg_buf));
3748
3749 // Now rhs is converted to f32.
3750 rhs_reg_layout_ = f32_layout;
3751 rhs_reg_buf_ = rhs_f32_reg_buf;
3752 } else {
3753 rhs_reg_layout_ = read.reg_layout();
3754 rhs_reg_buf_ = rhs_load_reg_buf;
3755 }
3756 } else {
3757 // Already pre-loaded and pre-converted.
3758 rhs_reg_layout_ = pre_load_rhs_reg_layout_.map(rhs_tile);
3759 }
3760 return stmt;
3761 }
3762
3763 // Applies post-op for a single tile.
3764 stmt_t build_tile_stmt(const tensor_t &tile, const layout_t &lhs_reg_layout,
3765 const expr_t &lhs_buf) {
3766 auto po = post_op_.create_sub_post_op(tile);
3767 if (!po.has_rhs()) {
3768 // Apply eltwise post-op.
3769 int lhs_size = lhs_reg_layout.size();
3770 int lhs_elems = lhs_size / int(sizeof(float));
3771 return po.eltwise().call({expr_t(lhs_elems), lhs_buf});
3772 }
3773
3774 auto lhs_layout = lhs_reg_layout;
3775 auto rhs_layout = (po.needs_load()
3776 ? rhs_reg_layout_
3777 : lhs_layout.map(
3778 tensor_t(std::vector<dim_t>(tile.ndims(), 1))));
3779
3780 int inner_dim_idx = lhs_layout.blocks().front().dim_idx;
3781 bool do_broadcast = po.is_broadcast_dim(inner_dim_idx);
3782 if (!do_broadcast) layout_t::align_layouts(lhs_layout, rhs_layout);
3783
3784 auto lhs_blocks = lhs_layout.blocks();
3785 auto rhs_blocks = rhs_layout.blocks();
3786
3787 auto &lhs0 = lhs_blocks[0];
3788
3789 ir_assert(lhs0.dim_idx == inner_dim_idx);
3790 ir_assert(dim_t(lhs0.stride) == 1);
3791
3792 if (!do_broadcast) {
3793 auto &rhs0 = rhs_blocks[0];
3794 ir_assert(lhs0.dim_idx == rhs0.dim_idx);
3795 ir_assert(lhs0.block == rhs0.block);
3796 MAYBE_UNUSED(rhs0);
3797 }
3798
3799 std::vector<dim_t> inner_tile_dims(tile.ndims(), 1);
3800 inner_tile_dims[inner_dim_idx] = lhs0.block;
3801
3802 auto &lhs_type = lhs_layout.type();
3803 auto &rhs_type = rhs_layout.type();
3804 ir_assert(lhs_type == type_t::f32());
3805 ir_assert(rhs_type == type_t::f32());
3806
3807 // Handle one inner tile at a time. Inner tile covers a single block
3808 // with a single dimension.
3809 stmt_t stmt;
3810 lhs_layout.for_each_tile(tensor_t(inner_tile_dims),
3811 [&](const std::vector<dim_t> &lhs_start) {
3812 auto rhs_start = po.apply_mask(lhs_start, 0);
3813 int lhs_off0 = lhs_layout(lhs_start) * lhs_type.size();
3814 int rhs_off0 = rhs_layout(rhs_start) * rhs_type.size();
3815
3816 int elems = lhs0.block;
3817 int step = (elems < 16 ? 8 : 16);
3818 for (int i = 0; i < elems; i += step) {
3819 int cur_elems = std::min(step, elems - i);
3820 ir_assert(math::is_pow2(cur_elems));
3821 auto lhs_vec_type = lhs_type.with_elems(cur_elems);
3822 auto rhs_vec_type = rhs_type.with_elems(
3823 do_broadcast ? 1 : cur_elems);
3824
3825 int lhs_off = lhs_off0 + i * lhs_type.size();
3826 int rhs_off = rhs_off0;
3827 if (!do_broadcast) rhs_off += i * rhs_type.size();
3828
3829 auto lhs = load_t::make(lhs_vec_type, lhs_buf, lhs_off);
3830 expr_t rhs;
3831 if (po.needs_load()) {
3832 int stride
3833 = (do_broadcast ? load_t::default_stride
3834 : int(rhs_blocks[0].stride)
3835 * rhs_type.size());
3836 rhs = load_t::make(rhs_vec_type, rhs_reg_buf_,
3837 rhs_off, stride);
3838 } else {
3839 // rhs is scalar and passed in the kernel arguments.
3840 rhs = po.rhs_buf();
3841 ir_assert(rhs.type().is_scalar());
3842 }
3843
3844 if (rhs.type().elems() != cur_elems) {
3845 rhs = shuffle_t::make_broadcast(rhs, cur_elems);
3846 }
3847
3848 if (po.rhs_scale() != 1) {
3849 // Scale rhs first.
3850 rhs = binary_op_t::make(op_kind_t::_mul, rhs,
3851 shuffle_t::make_broadcast(
3852 po.rhs_scale(), cur_elems));
3853 }
3854
3855 auto new_lhs
3856 = binary_op_t::make(po.op_kind(), lhs, rhs);
3857 if (new_lhs.type().is_bool()) {
3858 // Apply bool -> f32 cast when binary is a comparison op.
3859 new_lhs = cast(new_lhs, type_t::f32(cur_elems));
3860 }
3861 auto store = store_t::make(lhs_buf, lhs_off, new_lhs);
3862 stmt = stmt.append(store);
3863 }
3864 });
3865
3866 // Reset rhs layout.
3867 rhs_reg_layout_ = layout_t();
3868 return stmt;
3869 }
3870
3871 std::vector<stmt_t> allocs() const {
3872 std::vector<stmt_t> allocs;
3873 for (auto &kv : rhs_bufs_)
3874 allocs.push_back(
3875 alloc_t::make(kv.first, kv.second, alloc_kind_t::grf));
3876 return allocs;
3877 }
3878
3879 private:
3880 expr_t make_tmp_rhs_buffer() const {
3881 auto &rhs_name = post_op_.rhs_buf().as<var_t>().name;
3882 return ir_ctx_.create_tmp_var(type_t::byte_ptr(), "tmp_" + rhs_name);
3883 }
3884
3885 void update_rhs_buf_size(const expr_t &buf, int size) {
3886 rhs_bufs_[buf] = std::max(rhs_bufs_[buf], size);
3887 }
3888
3889 bool needs_rhs_convert() const {
3890 if (!post_op_.has_rhs()) return false;
3891 return post_op_.rhs_view().type() != type_t::f32();
3892 }
3893
3894 ngen::HW hw_;
3895 ir_context_t &ir_ctx_;
3896 const constraint_set_t &cset_;
3897 post_op_t post_op_;
3898
3899 bool do_preload_ = false;
3900
3901 expr_t pre_load_rhs_reg_buf_;
3902 layout_t pre_load_rhs_reg_layout_;
3903
3904 expr_t rhs_reg_buf_;
3905 layout_t rhs_reg_layout_;
3906
3907 object_map_t<expr_t, int> rhs_bufs_;
3908 };
3909
3910 // Zero pads a register buffer of f32 type.
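// For example, when the number of output channels is padded up to a block
// size, the tail elements of the buffer are overwritten with zeros so that
// later post-ops and stores observe correctly zero-padded data.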
3911 class zero_pad_builder_t {
3912 public:
3913 zero_pad_builder_t(const constraint_set_t &cset,
3914 const post_op_context_t &post_op_ctx, const view_t &mem_view,
3915 const layout_t ®_layout, const expr_t ®_buf)
3916 : cset_(cset)
3917 , post_op_ctx_(post_op_ctx)
3918 , mem_view_(mem_view)
3919 , reg_layout_(reg_layout)
3920 , reg_buf_(reg_buf) {
3921 ir_assert(mem_view_.nvdims() == reg_layout_.ndims())
3922 << "Incompatible view/layout.";
3923 build();
3924 }
3925
3926 const stmt_t &stmt() const { return stmt_; }
3927
3928 private:
3929 void build() {
3930 int max_step = 16; // Handle 16 elements at most in one step.
3931 auto base_tile = reg_layout_.split_into_max_tile(
3932 max_step, /*is_dense_tile=*/true);
3933 reg_layout_.for_each_tile(
3934 base_tile, [&](const std::vector<dim_t> &start) {
3935 tensor_t tile(base_tile.dims(), start);
3936 auto sub_layout = reg_layout_.map(tile);
3937 auto sub_view = mem_view_.create_sub_view(tile);
3938 int elems = tile.elems();
3939 int off = reg_layout_(start) * reg_layout_.type().size();
3940 auto mask_tensor = create_mask(sub_view, sub_layout);
3941 auto mask = mask_tensor.to_expr(elems);
3942 auto store = store_t::make(reg_buf_, off,
3943 shuffle_t::make_broadcast(expr_t(0.0f), elems),
3944 store_t::default_stride, -mask);
3945 stmt_ = stmt_.append(store);
3946 });
3947 }
3948
3949 mask_tensor_t create_mask(
3950 const view_t &view, const layout_t &layout) const {
3951 mask_tensor_t mask_tensor(layout);
3952 std::vector<dim_t> args(layout.ndims());
3953 fill_mask_impl(mask_tensor, 0, args, view, layout);
3954 mask_tensor.simplify(cset_);
3955 return mask_tensor;
3956 }
3957
3958 void fill_mask_impl(mask_tensor_t &mask_tensor, int idx,
3959 std::vector<dim_t> &args, const view_t &view,
3960 const layout_t &layout) const {
3961 if (idx == layout.ndims()) {
3962 expr_t mask = bool_imm_t::make(true);
3963 for (int i = 0; i < layout.ndims(); i++) {
3964 if (!post_op_ctx_.is_lhs_dim_zero_padded(i)) continue;
3965 mask &= (view.vstart(i) + args[i] < post_op_ctx_.lhs_dim(i));
3966 }
3967 auto off = layout.offset(args, /*ignore_offset=*/true);
3968 mask_tensor.set_mask(off, mask);
3969 return;
3970 }
3971
3972 for (int i = 0; i < int(layout.dims()[idx]); i++) {
3973 args[idx] = i;
3974 fill_mask_impl(mask_tensor, idx + 1, args, view, layout);
3975 }
3976 }
3977
3978 const constraint_set_t &cset_;
3979 const post_op_context_t &post_op_ctx_;
3980
3981 view_t mem_view_;
3982 layout_t reg_layout_;
3983 expr_t reg_buf_;
3984
3985 stmt_t stmt_;
3986 };
3987
3988 // Performs the following steps after the computation:
3989 // - Conversion
3990 // - Applying post-ops
3991 // - GRF reorder to match the memory layout
3992 // - Store to the destination
3993 class epilogue_builder_t {
3994 public:
3995 epilogue_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
3996 const constraint_set_t &cset, const post_op_context_t &post_op_ctx,
3997 const tensor_t &tile, const view_t &mem_view,
3998 const layout_t ®_layout, const expr_t &mem_buf,
3999 const expr_t ®_buf)
4000 : cfg_(cfg)
4001 , ir_ctx_(ir_ctx)
4002 , cset_(cset)
4003 , post_op_ctx_(post_op_ctx)
4004 , mem_view_(mem_view)
4005 , reg_layout_(reg_layout)
4006 , mem_buf_(mem_buf)
4007 , reg_buf_(reg_buf) {
4008
4009 int pre_load_size = pre_load_max_size_;
4010 for (auto &po : post_op_ctx_.post_ops()) {
4011 auto sub_po = po.create_sub_post_op(tile);
4012 post_op_builders_.emplace_back(
4013 cfg.hw, ir_ctx, cset_, sub_po, pre_load_size);
4014 }
4015 build();
4016 }
4017
4018 const stmt_t &stmt() const { return stmt_; }
4019
4020 private:
4021 // Represents one stage in the flow between multiplication and storing the
4022 // updated result to memory.
4023 //
4024 // Flow with post-ops:
4025 // Multiplication ->
4026 // M_x -> [R_f32] -> P0_f32 -> ... -> Pn_f32 -> [Z_f32] -> S_y ->
4027 // GMEM
4028 // Flow without post-ops:
4029 // Multiplication ->
4030 // M_x -> S_y ->
4031 // GMEM
4032 // Where:
4033 // - x is data type after multiplication
4034 // - y is destination data type
4035 // - M_x is a stage after multiplication
4036 // - R_f32 is a stage after reordering from M_x to f32 (optional)
4037 // - Pi_f32 is a stage after applying Pi post-op
4038 // - Z_f32 is a stage after restoring zero padding (optional)
4039 // - S_y is a stage before storing data to destination
4040 struct stage_t {
4041 stage_t(const layout_t &layout, const expr_t &buf,
4042 const stmt_t &stmt = stmt_t())
4043 : layout(layout), buf(buf), stmt(stmt) {}
4044
4045 void set_next(ngen::HW hw, ir_context_t &ir_ctx, stage_t *next,
4046 bool force_reorder) {
4047 if (!next) return;
4048 bool do_reorder
4049 = !layout.is_equal(next->layout, /*compare_offset=*/false);
4050 if (force_reorder) do_reorder = true;
4051 if (do_reorder) {
4052 ir_assert(stmt.is_empty());
4053 // Generate reorder between stages.
4054 stmt = create_reorder_stmt(
4055 layout, next->layout, buf, next->buf);
4056 } else {
4057 // Reuse the same GRF buffer for the next stage.
4058 int this_off = to_cpp<int>(layout.offset_in_bytes());
4059 int next_off = to_cpp<int>(next->layout.offset_in_bytes());
4060 ir_assert(next_off == 0);
4061 MAYBE_UNUSED(next_off);
4062 next->set_buf(buf[this_off]);
4063 }
4064 }
4065
4066 void set_buf(const expr_t &buf) {
4067 // Replace old buffer if there is an assigned statement.
4068 if (!stmt.is_empty()) { stmt = substitute(stmt, this->buf, buf); }
4069 this->buf = buf;
4070 }
4071
4072 const expr_t &buf_base() const {
4073 if (buf.is<var_t>()) return buf;
4074 return buf.as<ptr_t>().base;
4075 }
4076
4077 int buf_size() const {
4078 ir_assert(buf.is_same(buf_base()))
4079 << "Size must be queried from another stage.";
4080 return int(layout.size());
4081 }
4082
4083 void prepend_stmt(const stmt_t &stmt) {
4084 this->stmt = stmt.append(this->stmt);
4085 }
4086
4087 layout_t layout;
4088 expr_t buf;
4089 stmt_t stmt;
4090 };
4091
4092 void build() {
4093 for (auto &po_builder : post_op_builders_) {
4094 stmt_ = stmt_.append(po_builder.build_pre_load());
4095 }
4096
4097 for (auto &po_builder : post_op_builders_) {
4098 stmt_ = stmt_.append(po_builder.build_pre_convert());
4099 }
4100
4101 auto tmp_type = (post_op_builders_.empty() ? mem_view_.type()
4102 : type_t::f32());
4103 int tmp_buf_elems = tmp_buf_size_ / tmp_type.size();
4104 auto base_tile = mem_view_.split_into_max_tile(
4105 tmp_buf_elems, /*is_dense=*/false);
4106 mem_view_.for_each_tile(
4107 base_tile, [&](const std::vector<dim_t> &start) {
4108 build_tile(tensor_t(base_tile.dims(), start));
4109 });
4110
4111 // Generate alloc statements for rhs post-op buffers.
4112 std::vector<stmt_t> allocs;
4113 for (auto &po_builder : post_op_builders_) {
4114 auto po_allocs = po_builder.allocs();
4115 allocs.insert(allocs.end(), po_allocs.begin(), po_allocs.end());
4116 }
4117 stmt_ = jit::inject_alloc_stmts(stmt_, allocs, /*put_innermost=*/true);
4118 }
4119
4120 // Builds statements for a tile iterating through all stages (see stage_t
4121 // description).
4122 void build_tile(const tensor_t &tile) {
4123 auto mem_sub_view = mem_view_.create_sub_view(tile);
4124 auto reg_sub_layout = reg_layout_.map(tile);
4125
4126 auto tmp_reg_buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp");
4127 bool restore_zero_padding = post_op_ctx_.need_to_restore_zero_padding();
4128
4129 // S_y -> GMEM.
4130 ngen_proxy::AtomicOp atomic_op
4131 = (cfg_.do_atomic_update ? ngen_proxy::AtomicOp::fadd
4132 : ngen_proxy::AtomicOp::undef);
4133 write_builder_t r2g(cfg_.hw, ir_ctx_, cset_, mem_sub_view, mem_buf_,
4134 tmp_reg_buf,
4135 /*is_slm=*/false, /*atomic_op=*/atomic_op);
4136
4137 // Initialize stages.
4138 std::vector<stage_t> stages;
4139 stages.emplace_back(reg_sub_layout, reg_buf_); // M_x
4140 if (!post_op_builders_.empty()) {
4141 auto po_layout
4142 = r2g.reg_layout().retype(type_t::f32()).make_dense();
4143 for (int i = 0; i < int(post_op_builders_.size()); i++) {
4144 auto buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp");
4145 stages.emplace_back(po_layout, buf); // Pi_f32
4146 }
4147 if (restore_zero_padding) {
4148 auto &last = stages.back();
4149 stages.emplace_back(last.layout, last.buf); // Z_f32.
4150 }
4151 }
4152 stages.emplace_back(r2g.reg_layout(), tmp_reg_buf, r2g.stmt()); // S_y
4153
4154 int nstages = int(stages.size());
4155 int npost_ops = int(post_op_builders_.size());
4156
4157 bool is_dpasw = (cfg_.fma_kind == fma_kind_t::dpasw);
4158
4159 // Generate reorders between stages and create buffers.
4160 for (int i = 0; i < nstages; i++) {
4161 auto *next_stage = (i + 1 < nstages ? &stages[i + 1] : nullptr);
4162 // Always perform reorder when dpasw is used. This is to ensure
4163 // that C is properly restored and permuted after dpasw.
4164 stages[i].set_next(cfg_.hw, ir_ctx_, next_stage,
4165 /*force_reorder=*/i == 0 && is_dpasw);
4166 }
4167
4168 std::vector<stmt_t> tile_load_stmts;
4169 for (int i = 0; i < npost_ops; i++) {
4170 auto &po_builder = post_op_builders_[i];
4171 // Generate load for post-op.
4172 tile_load_stmts.push_back(po_builder.build_tile_load(tile));
4173
4174 // Generate post-op statement.
4175 auto &s = stages[i + 1];
4176 s.prepend_stmt(po_builder.build_tile_stmt(tile, s.layout, s.buf));
4177 }
4178
4179 // Restore zero padding if needed.
4180 if (restore_zero_padding) {
4181 auto &s = stages[nstages - 2];
4182 zero_pad_builder_t builder(
4183 cset_, post_op_ctx_, mem_sub_view, s.layout, s.buf);
4184 s.prepend_stmt(builder.stmt());
4185 }
4186
4187 stmt_t tile_stmt;
4188
4189 // Add stage statements. Emit stages in blocks to reduce GRF
4190 // consumption.
4191 int stage_blk = 8;
4192 for (int i = 0; i < nstages; i += stage_blk) {
4193 int stage_beg = i;
4194 int stage_end = std::min(nstages, i + stage_blk);
4195 int po_beg = std::max(0, i - 1);
4196 int po_end = std::min(npost_ops, i + stage_blk - 1);
4197 stmt_t blk_stmt;
4198 for (int j = po_beg; j < po_end; j++) {
4199 blk_stmt = blk_stmt.append(tile_load_stmts[j]);
4200 }
4201 for (int j = stage_beg; j < stage_end; j++) {
4202 blk_stmt = blk_stmt.append(stages[j].stmt);
4203 }
4204 tile_stmt = tile_stmt.append(blk_stmt);
4205 }
4206
4207 // Generate alloc statements for stage buffers.
4208 object_set_t<expr_t> seen;
4209 for (int i = 0; i < nstages; i++) {
4210 auto &s = stages[i];
4211 auto &buf = s.buf_base();
4212 auto ret = seen.insert(buf);
4213 if (i == 0 || !ret.second) continue;
4214 tile_stmt = alloc_t::make(
4215 buf, s.buf_size(), alloc_kind_t::grf, {}, tile_stmt);
4216 }
4217
4218 stmt_ = stmt_.append(tile_stmt);
4219 }
4220
4221 const conv_config_t &cfg_;
4222 ir_context_t &ir_ctx_;
4223 const constraint_set_t &cset_;
4224 const post_op_context_t &post_op_ctx_;
4225
4226 view_t mem_view_;
4227 layout_t reg_layout_;
4228
4229 expr_t mem_buf_;
4230 expr_t reg_buf_;
4231
4232 // Tile size in bytes. The tile data type is:
4233 // - the destination data type without post-ops
4234 // - f32 with post-ops
4235 static const int tmp_buf_size_ = 128;
4236 static const int pre_load_max_size_ = 256;
4237
4238 std::vector<post_op_builder_t> post_op_builders_;
4239
4240 stmt_t stmt_;
4241 };
4242
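// Builds the multiplication statement for the given A/B views, decomposing
// it into dpas/dpasw or mad instructions according to the configured FMA
// kind.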
4243 class multiply_builder_t {
4244 public:
4245 multiply_builder_t() = default;
4246
4247 multiply_builder_t(const conv_config_t &cfg,
4248 const bmnk_mapper_t &bmnk_mapper, const view_t &a_view,
4249 const view_t &b_view, const expr_t &a_buf, const expr_t &b_buf,
4250 const expr_t &c_buf)
4251 : hw_(cfg.hw)
4252 , simd_size_(cfg.simd_size)
4253 , bmnk_mapper_(bmnk_mapper)
4254 , a_view_(a_view)
4255 , b_view_(b_view)
4256 , a_buf_(a_buf)
4257 , b_buf_(b_buf)
4258 , c_buf_(c_buf) {
4259 switch (cfg.fma_kind) {
4260 case fma_kind_t::dpasw:
4261 case fma_kind_t::dpas:
4262 if (try_build_dpas()) return;
4263 break;
4264 case fma_kind_t::mad:
4265 if (try_build_mad()) return;
4266 break;
4267 default: ir_error_not_expected() << "Unknown FMA kind.";
4268 }
4269
4270 ir_error_not_expected()
4271 << "Can't decompose into multiplication instructions.";
4272 }
4273
4274 const stmt_t &stmt() const { return stmt_; }
4275
4276 const layout_t &c_layout() const { return c_layout_; }
4277
4278 ngen_proxy::Bundle a_grf_bundle() {
4279 if (!do_transpose_) return ngen_proxy::Bundle();
4280 return ngen_proxy::Bundle(1, ngen_proxy::Bundle::any);
4281 }
4282
4283 ngen_proxy::Bundle b_grf_bundle() {
4284 if (do_transpose_) return ngen_proxy::Bundle();
4285 return ngen_proxy::Bundle(1, ngen_proxy::Bundle::any);
4286 }
4287
4288 ngen_proxy::Bundle c_grf_bundle() {
4289 return ngen_proxy::Bundle(0, ngen_proxy::Bundle::any);
4290 }
4291
4292 std::string str() const {
4293 std::ostringstream oss;
4294 oss << "A view: " << a_view_ << std::endl;
4295 oss << "B view: " << b_view_ << std::endl;
4296 oss << "C layout: " << c_layout_ << std::endl;
4297 oss << "Statement: " << std::endl << stmt_;
4298 return oss.str();
4299 }
4300
4301 private:
4302 struct loop_info_t {
4303 loop_info_t() = default;
4304
4305 loop_info_t(const expr_t &var, bmnk_kind_t bmnk_kind, int dim)
4306 : var(var), bmnk_kind(bmnk_kind), dim(dim) {}
4307
4308 expr_t var;
4309 bmnk_kind_t bmnk_kind;
4310
4311 int dim;
4312 int a_idx = -1;
4313 int b_idx = -1;
4314 int c_idx = -1;
4315 int block = 1;
4316 };
4317
4318 bool try_build_dpas() {
4319 ir_assert(a_view_.can_convert_to_vlayout())
4320 << "Views are not supported with dpas/dpasw.";
4321 ir_assert(b_view_.can_convert_to_vlayout())
4322 << "Views are not supported with dpas/dpasw.";
4323
4324 auto a_layout = a_view_.create_vlayout();
4325 auto b_layout = b_view_.create_vlayout();
4326
4327 bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper_);
4328 from_bmnk_mapper.push_blocks(abc_kind_t::a, a_layout.blocks());
4329 from_bmnk_mapper.push_blocks(abc_kind_t::b, b_layout.blocks());
4330
4331 // Convert to MNK layouts.
4332 a_layout = bmnk_mapper_.map_to_bmnk(
4333 abc_kind_t::a, {bmnk_kind_t::m, bmnk_kind_t::k}, a_layout);
4334 b_layout = bmnk_mapper_.map_to_bmnk(
4335 abc_kind_t::b, {bmnk_kind_t::k, bmnk_kind_t::n}, b_layout);
4336
4337 multiply_desc_t desc(a_layout, b_layout, /*force_c_upconvert=*/true);
4338 if (!dpas_t::matches_types(
4339 hw_, desc.a_type(), desc.b_type(), desc.c_type()))
4340 return false;
4341
4342 int sdepth = 8;
4343 int rcount = std::min(utils::rnd_up_pow2(desc.n()), 8);
4344 auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_, sdepth,
4345 rcount, desc.c_type(), desc.a_type(), desc.b_type());
4346 if (_dpas.as<dpas_t>().matches(desc)) {
4347 build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
4348 return true;
4349 }
4350
4351 // Try to transpose and flip: C += A * B -> C^T = B^T * A^T.
4352 rcount = std::min(utils::rnd_up_pow2(desc.m()), 8);
4353 desc = multiply_desc_t(
4354 b_layout.transpose(), a_layout.transpose(), true);
4355 _dpas = dpas_t::make(/*is_dpasw=*/false, /*exec_size=*/simd_size_,
4356 sdepth, rcount, desc.c_type(), desc.a_type(), desc.b_type());
4357
4358 if (_dpas.as<dpas_t>().matches(desc)) {
4359 do_transpose_ = true;
4360 build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
4361 return true;
4362 }
4363 return false;
4364 }
4365
4366 void build_dpas(const bmnk_block_mapper_t &from_bmnk_mapper,
4367 const dpas_t &dpas, const multiply_desc_t &desc) {
4368 int m_blk = dpas.simd_size;
4369 int n_blk = dpas.rcount;
4370 int k_blk = dpas.sdepth * 4 / dpas.src1_type.size();
4371
4372 c_layout_ = compute_dpas_c_layout(m_blk, n_blk, dpas.c_layout(), desc);
4373
4374 expr_t a_buf = a_buf_;
4375 expr_t b_buf = b_buf_;
4376 if (do_transpose_) std::swap(a_buf, b_buf);
4377
4378 for (int i_k = 0; i_k < desc.k(); i_k += k_blk) {
4379 for (int i_m = 0; i_m < desc.m(); i_m += m_blk) {
4380 for (int i_n = 0; i_n < desc.n(); i_n += n_blk) {
4381 std::vector<int> a_args = {i_m, 0};
4382 std::vector<int> b_args = {0, i_n};
4383 std::vector<int> c_args = {i_m, i_n};
4384 auto a = a_buf[desc.a_layout()(a_args)
4385 * desc.a_type().size()];
4386 auto b = b_buf[desc.b_layout()(b_args)
4387 * desc.b_type().size()];
4388 auto c = c_buf_[c_layout_(c_args) * desc.c_type().size()];
4389 stmt_ = stmt_.append(dpas(c, c, a, b));
4390 }
4391 }
4392 }
4393
4394 // Transpose C layout back if needed.
4395 if (do_transpose_) c_layout_ = c_layout_.transpose();
4396
4397 // Convert C layout back to problem notation.
4398 c_layout_ = from_bmnk_mapper.map_from_bmnk(
4399 abc_kind_t::c, {bmnk_kind_t::m, bmnk_kind_t::n}, c_layout_);
4400 }
4401
4402 static layout_t compute_dpas_c_layout(int m_blk, int n_blk,
4403 const layout_t &blk_layout, const multiply_desc_t &desc) {
4404 auto c_layout = blk_layout;
4405 c_layout = c_layout.add_outer_block(1, desc.n() / n_blk);
4406 c_layout = c_layout.add_outer_block(0, desc.m() / m_blk);
4407 return c_layout;
4408 }
4409
4410 bool try_build_mad() {
4411 auto loops = create_loop_nest();
4412
4413 if (try_build_mad_kmn_block_by_n(loops)) return true;
4414 if (try_build_mad_kmn_block_by_b(loops)) return true;
4415
4416 return false;
4417 }
4418
4419 std::vector<loop_info_t> create_loop_nest() const {
4420 object_map_t<expr_t, loop_info_t> loops;
4421 for (auto *view : {&a_view_, &b_view_}) {
4422 abc_kind_t abc_kind
4423 = (view == &a_view_ ? abc_kind_t::a : abc_kind_t::b);
4424 for (int i = 0; i < view->nvdims(); i++) {
4425 auto &var = bmnk_mapper_.var(abc_kind, i);
4426 int dim = int(view->vdims()[i]);
4427 if (dim == 1) continue;
4428
4429 if (loops.count(var) > 0) continue;
4430 loops[var] = loop_info_t(var, bmnk_mapper_.bmnk_kind(var), dim);
4431 }
4432 }
4433
4434 std::vector<loop_info_t> ret;
4435 for (auto &kv : loops) {
4436 auto &loop = kv.second;
4437 loop.a_idx = bmnk_mapper_.dim_idx(abc_kind_t::a, loop.var);
4438 loop.b_idx = bmnk_mapper_.dim_idx(abc_kind_t::b, loop.var);
4439 loop.c_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loop.var);
4440 ret.push_back(kv.second);
4441 }
4442 return ret;
4443 }
4444
4445 // Order of loops: BKMN, block by N.
4446 bool try_build_mad_kmn_block_by_n(std::vector<loop_info_t> &_loops) {
4447 return try_build_mad_impl(_loops,
4448 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
4449 bmnk_kind_t::n},
4450 bmnk_kind_t::n);
4451 }
4452
4453 // Order of loops: BKMN, block by B.
4454 bool try_build_mad_kmn_block_by_b(std::vector<loop_info_t> &_loops) {
4455 return try_build_mad_impl(_loops,
4456 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
4457 bmnk_kind_t::n},
4458 bmnk_kind_t::b);
4459 }
4460
4461 bool try_build_mad_impl(std::vector<loop_info_t> &_loops,
4462 const std::vector<bmnk_kind_t> &loop_order,
4463 bmnk_kind_t block_bmnk_kind) {
4464 auto loops = _loops;
4465 int nloops = int(loops.size());
4466 std::sort(loops.begin(), loops.end(),
4467 [&](const loop_info_t &a, const loop_info_t &b) {
4468 int a_key = ir_utils::find_index(loop_order, a.bmnk_kind);
4469 int b_key = ir_utils::find_index(loop_order, b.bmnk_kind);
4470 ir_assert(a_key != -1);
4471 ir_assert(b_key != -1);
4472 return a_key < b_key;
4473 });
4474
4475 int block_idx = -1;
4476 for (int i = 0; i < nloops; i++) {
4477 auto &l = loops[i];
4478 if (l.bmnk_kind == block_bmnk_kind) {
4479 ir_assert(block_idx == -1) << "Can't block 2+ dimensions.";
4480 block_idx = i;
4481 }
4482 }
4483
4484 // Couldn't find the blocking dimension, try a different blocking scheme.
4485 if (block_idx == -1) return false;
4486
4487 auto &block_loop = loops[block_idx];
4488
4489 int block = simd_size_;
4490 while (block >= 1) {
4491 if (block_loop.dim % block == 0) break;
4492 block /= 2;
4493 }
4494
4495 ir_assert(block >= 1) << "Invalid block size.";
4496 block_loop.block = block;
4497
4498 int a_stride = 0;
4499 int b_stride = 0;
4500
4501 // Ensure that the A tile has a single regularly-strided block.
4502 if (block_loop.a_idx != -1) {
4503 std::vector<dim_t> tile_dims(a_view_.nvdims(), 1);
4504 tile_dims[block_loop.a_idx] = block;
4505 auto layout = a_view_.create_pseudo_vlayout();
4506 auto tile = layout.map(tensor_t(tile_dims));
4507 if (!is_1d_strided(tile)) return false;
4508 a_stride = tile.blocks()[0].stride;
4509 }
4510
4511 // Ensure that the B tile has a single regularly-strided block.
4512 if (block_loop.b_idx != -1) {
4513 std::vector<dim_t> tile_dims(b_view_.nvdims(), 1);
4514 tile_dims[block_loop.b_idx] = block;
4515 auto layout = b_view_.create_pseudo_vlayout();
4516 auto tile = layout.map(tensor_t(tile_dims));
4517 if (!is_1d_strided(tile)) return false;
4518 b_stride = tile.blocks()[0].stride;
4519 }
4520
4521 build_mad(loops, block_loop, a_stride, b_stride);
4522 return true;
4523 }
4524
4525 static bool is_1d_strided(const layout_t &layout) {
4526 auto &blocks = layout.blocks();
4527 if (blocks.size() > 1) return false;
4528 return true;
4529 }
4530
4531 void build_mad(const std::vector<loop_info_t> &loops,
4532 const loop_info_t &block_loop, int a_stride, int b_stride) {
4533 ir_assert(utils::one_of(
4534 block_loop.bmnk_kind, bmnk_kind_t::b, bmnk_kind_t::n))
4535 << "Unsupported blocking (expected blocking by B or N).";
4536
4537 auto &a_type = a_view_.type();
4538 auto &b_type = b_view_.type();
4539 auto c_type = multiply_desc_t::get_c_type(a_type, b_type,
4540 /*force_c_upconvert=*/false);
4541
4542 int block = block_loop.block;
4543 auto _mad = mad_t::make(
4544 c_type, block, a_type, a_stride, b_type, b_stride);
4545 auto &mad = _mad.as<mad_t>();
4546
4547 c_layout_ = compute_mad_c_layout(c_type, loops, block_loop);
4548
4549 int nloops = int(loops.size());
4550 std::vector<int> bounds(loops.size());
4551 for (int i = 0; i < nloops; i++) {
4552 bounds[i] = loops[i].dim / loops[i].block;
4553 }
4554 std::vector<int> a_idx(a_view_.nvdims());
4555 std::vector<int> b_idx(b_view_.nvdims());
4556 std::vector<int> c_idx(c_layout_.ndims());
4557 ir_utils::for_each(bounds, [&](const std::vector<int> &idx) {
4558 for (int i = 0; i < nloops; i++) {
4559 int full_idx = idx[i] * loops[i].block;
4560 auto &loop = loops[i];
4561 if (loop.a_idx != -1) a_idx[loop.a_idx] = full_idx;
4562 if (loop.b_idx != -1) b_idx[loop.b_idx] = full_idx;
4563 if (loop.c_idx != -1) c_idx[loop.c_idx] = full_idx;
4564 }
4565 int a_off = a_view_(a_idx) * a_type.size();
4566 int b_off = b_view_(b_idx) * b_type.size();
4567 int c_off = c_layout_(c_idx) * c_type.size();
4568 stmt_ = stmt_.append(mad(c_buf_[c_off], c_buf_[c_off],
4569 a_buf_[a_off], b_buf_[b_off]));
4570 });
4571 }
4572
4573 layout_t compute_mad_c_layout(const type_t &c_type,
4574 const std::vector<loop_info_t> &loops,
4575 const loop_info_t &block_loop) const {
4576 layout_t c_layout(c_type, bmnk_mapper_.ndims(abc_kind_t::c), 0, {});
4577
4578 int c_dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, block_loop.var);
4579 c_layout = c_layout.add_outer_block(c_dim_idx, block_loop.block);
4580
4581 for (size_t i = 0; i < loops.size(); i++) {
4582 if (loops[i].bmnk_kind == bmnk_kind_t::k) continue;
4583 int dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loops[i].var);
4584 int bound = loops[i].dim / loops[i].block;
4585 c_layout = c_layout.add_outer_block(dim_idx, bound);
4586 }
4587 return c_layout;
4588 }
4589
4590 ngen::HW hw_;
4591 int simd_size_;
4592 bmnk_mapper_t bmnk_mapper_;
4593
4594 bool do_transpose_ = false;
4595
4596 view_t a_view_;
4597 view_t b_view_;
4598 layout_t c_layout_;
4599
4600 expr_t a_buf_;
4601 expr_t b_buf_;
4602 expr_t c_buf_;
4603
4604 stmt_t stmt_;
4605 };
4606
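// Returns a dpas-friendly layout (in BMNK notation) for the A or B operand,
// i.e. the layout expected by the corresponding dpas source operand for the
// given SIMD size and types.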
4607 layout_t get_fma_friendly_layout(abc_kind_t abc_kind, int simd_size,
4608 const layout_t &bmnk_layout, const type_t &a_type,
4609 const type_t &b_type) {
4610 bool is_a = (abc_kind == abc_kind_t::a);
4611 int mn_idx = (is_a ? 0 : 1);
4612 int k_idx = (is_a ? 1 : 0);
4613
4614 dim_t mn_blk = bmnk_layout.dim(mn_idx);
4615 dim_t k_blk = bmnk_layout.dim(k_idx);
4616
4617 // Cannot calculate correct r_count when !is_a, but rcount is effectively
4618 // ignored in that case as rcount mainly affects b_layout.
4619 int rcount = is_a && mn_blk < 8 ? utils::rnd_up_pow2(mn_blk) : 8;
4620 auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size, /*sdepth=*/8,
4621 rcount, type_t::undef(), b_type, a_type);
4622 auto &dpas = _dpas.as<dpas_t>();
4623
4624 auto dpas_layout = (is_a ? dpas.b_layout() : dpas.a_layout());
4625 dpas_layout = dpas_layout.transpose();
4626
4627 ir_assert(dpas_layout.dim(k_idx) == k_blk);
4628 MAYBE_UNUSED(k_blk);
4629
4630 dim_t dpas_mn_blk = dpas_layout.dim(mn_idx);
4631 dpas_layout = dpas_layout.add_outer_block(mn_idx, mn_blk / dpas_mn_blk);
4632
4633 return dpas_layout;
4634 }
4635
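// For mad, retypes the layout so that the operand matches what mad supports:
// x8 inputs become strided s16 and bf16 or mixed f16/f32 inputs become f32.
// Other FMA kinds return the layout unchanged.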
4636 layout_t convert_to_fma_friendly_type(const conv_config_t &cfg,
4637 abc_kind_t abc_kind, const layout_t &layout, const type_t &a_type,
4638 const type_t &b_type, bool *changed = nullptr) {
4639 if (changed) *changed = false;
4640 if (cfg.fma_kind != fma_kind_t::mad) return layout;
4641
4642 if (a_type.is_x8() && b_type.is_x8()) {
4643 if (changed) *changed = true;
4644 return layout.retype(type_t::s16()).make_strided(2);
4645 }
4646 // f16/bf16 mixed mode mad requires src2 to be f32
4647 if (a_type.is_bf16() || b_type.is_bf16()
4648 || (a_type.is_f32() && b_type.is_f16())
4649 || (a_type.is_f16() && b_type.is_f32())) {
4650 if (changed) *changed = true;
4651 return layout.retype(type_t::f32());
4652 }
4653 return layout;
4654 }
4655
4656 layout_t convert_to_fma_friendly_layout(const conv_config_t &cfg,
4657 abc_kind_t abc_kind, const bmnk_mapper_t &bmnk_mapper,
4658 const layout_t &layout, const type_t &a_type, const type_t &b_type,
4659 bool *changed = nullptr) {
4660 if (changed) *changed = false;
4661 if (!cfg.allow_grf_reorder) return layout;
4662
4663 // GRF reorder is only supported with dpas/dpasw.
4664 if (!utils::one_of(cfg.fma_kind, fma_kind_t::dpas, fma_kind_t::dpasw)) {
4665 // mad may require type conversion.
4666 return convert_to_fma_friendly_type(
4667 cfg, abc_kind, layout, a_type, b_type, changed);
4668 }
4669
4670 std::vector<bmnk_kind_t> bmnk_kinds;
4671 if (abc_kind == abc_kind_t::a) {
4672 bmnk_kinds.push_back(bmnk_kind_t::m);
4673 bmnk_kinds.push_back(bmnk_kind_t::k);
4674 } else {
4675 bmnk_kinds.push_back(bmnk_kind_t::k);
4676 bmnk_kinds.push_back(bmnk_kind_t::n);
4677 }
4678
4679 auto bmnk_layout = bmnk_mapper.map_to_bmnk(abc_kind, bmnk_kinds, layout);
4680
4681 auto dpas_layout = get_fma_friendly_layout(
4682 abc_kind, cfg.simd_size, bmnk_layout, a_type, b_type);
4683 if (dpas_layout == bmnk_layout) return layout;
4684
4685 if (changed) *changed = true;
4686
4687 bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper);
4688 from_bmnk_mapper.push_blocks(abc_kind, layout.blocks());
4689
4690 auto fma_layout
4691 = from_bmnk_mapper.map_from_bmnk(abc_kind, bmnk_kinds, dpas_layout);
4692 fma_layout = fma_layout.make_dense();
4693 return fma_layout;
4694 }
4695
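// Holds the state needed to reduce B (a KxN tensor) across K and to
// atomically accumulate the reduced result into global memory under an
// optional condition.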
4696 class b_reduce_context_t {
4697 public:
4698 b_reduce_context_t(const conv_config_t &cfg)
4699 : cfg_(cfg), reduce_condition_(true) {
4700 if (cfg.do_b_reduction) b_reduced_reg_buf_ = make_buffer("b_reduced");
4701 }
4702
4703 // Setters for B reduced memory buffer/view.
4704 void set_b_reduced_mem_buf(const expr_t &buf) { b_reduced_mem_buf_ = buf; }
4705 void set_b_reduced_view(const view_t &v) { b_reduced_view_ = v; }
4706
4707 // Sets the condition to update B reduced output. Reduction is done across
4708 // K for B (KxN tensor) so M dimension should be checked before the update.
4709 void set_reduce_condition(const expr_t &cond) { reduce_condition_ = cond; }
4710
4711 // Global memory buffer.
4712 const expr_t &b_reduced_mem_buf() const { return b_reduced_mem_buf_; }
4713
4714 // Register buffer.
4715 const expr_t &b_reduced_reg_buf() const { return b_reduced_reg_buf_; }
4716 int b_reduced_size() const { return b_reduced_size_; }
4717
4718 // Memory view.
4719 const view_t &b_reduced_thr_view() const { return b_reduced_thr_view_; }
4720
4721 // Register layout.
4722 const layout_t &b_reduced_reg_layout() const {
4723 return b_reduced_reg_layout_;
4724 }
4725
4726 void init_reduced_thr_view(
4727 const tensor_t &b_thr_tile, const expr_t &cond = expr_t()) {
4728 ir_assert(b_reduced_thr_view_.is_empty()) << "Can't initialize twice.";
4729
4730 auto b_reduced_thr_tile = b_to_b_reduced_tile(b_thr_tile);
4731 b_reduced_thr_view_
4732 = b_reduced_view_.create_sub_view(b_reduced_thr_tile);
4733 b_reduced_reg_layout_ = b_reduced_thr_view_.create_dense_vlayout();
4734 b_reduced_size_ = b_reduced_reg_layout_.size();
4735 b_reduced_size_ = utils::rnd_up(b_reduced_size_, cfg_.grf_size());
4736
4737 if (!cond.is_empty()) reduce_condition_ &= cond;
4738 }
4739
4740 stmt_t create_reduce_stmt(const layout_t &b_layout, const expr_t &b_buf,
4741 const tensor_t &sub_tile = tensor_t()) {
4742 auto reduction_stmt
4743 = jit::create_reduce_stmt(b_layout, b_reduced_reg_layout_,
4744 b_buf, b_reduced_reg_buf_, sub_tile, (1 << 1));
4745 return reduction_stmt;
4746 }
4747
4748 stmt_t create_store_stmt(
4749 ir_context_t &ir_ctx, const constraint_set_t &cset) const {
4750 write_builder_t r2g(cfg_.hw, ir_ctx, cset, b_reduced_thr_view_,
4751 b_reduced_mem_buf_, b_reduced_reg_buf_, /*is_slm=*/false,
4752 ngen_proxy::AtomicOp::fadd);
4753 // TODO: Check that layouts match.
4754 auto ret = r2g.stmt();
4755 if (!reduce_condition_.is_empty()) {
4756 ret = if_t::make(reduce_condition_, ret);
4757 }
4758 return ret;
4759 }
4760
4761 private:
4762 tensor_t b_to_b_reduced_tile(const tensor_t &b_tile) const {
4763 std::vector<dim_t> dims;
4764 std::vector<expr_t> start;
4765 for (int i = 0; i < b_tile.ndims(); i++) {
4766 if ((reduction_mask_ & (1 << i)) != 0) {
4767 dims.push_back(b_tile(i));
4768 start.push_back(b_tile.start(i));
4769 }
4770 }
4771 return tensor_t(dims, start);
4772 }
4773
4774 const conv_config_t &cfg_;
4775
4776 expr_t reduce_condition_;
4777
4778 expr_t b_reduced_mem_buf_;
4779 expr_t b_reduced_reg_buf_;
4780
4781 view_t b_reduced_view_;
4782 view_t b_reduced_thr_view_;
4783
4784 layout_t b_reduced_reg_layout_;
4785 int b_reduced_size_ = 0;
4786
4787 uint32_t reduction_mask_ = (1 << 1);
4788 };
4789
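// Builds the load-multiply part of the kernel: loads A and B sub-tiles from
// global memory or SLM into registers and multiplies them, accumulating the
// result into the C register buffer.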
4790 class load_multiply_builder_t {
4791 public:
4792 load_multiply_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
4793 const constraint_set_t &cset, const gemm_schedule_t &gemm_schedule,
4794 b_reduce_context_t &b_reduce_ctx, const expr_t &ap_buf,
4795 const expr_t &a_slm_buf, const expr_t &bp_buf,
4796 const expr_t &b_slm_buf, const view_t &ap_x_view,
4797 const view_t &bp_x_view)
4798 : cfg_(cfg)
4799 , ir_ctx_(ir_ctx)
4800 , cset_(cset)
4801 , gemm_schedule_(gemm_schedule)
4802 , b_reduce_ctx_(b_reduce_ctx)
4803 , ap_buf_(ap_buf)
4804 , a_slm_buf_(a_slm_buf)
4805 , bp_buf_(bp_buf)
4806 , b_slm_buf_(b_slm_buf) {
4807 ir_assert(cfg_.a_sub_tiles == 1 || cfg_.b_sub_tiles == 1)
4808 << "At most one tensor can be tiled.";
4809
4810 ab_tmp_buf_ = make_buffer("ab_tmp");
4811 a_buf_ = make_buffer("a");
4812 b_buf_ = make_buffer("b");
4813 c_buf_ = make_buffer("c");
4814
4815 // Views to multiply by a thread.
4816 a_thr_view_ = ap_x_view.create_sub_view(gemm_schedule_.a_thr_tile());
4817 b_thr_view_ = bp_x_view.create_sub_view(gemm_schedule_.b_thr_tile());
4818
4819 // Initialize view for reduced B.
4820 if (cfg_.do_b_reduction && !cfg_.use_b_slm) {
4821 b_reduce_ctx_.init_reduced_thr_view(
4822 gemm_schedule_.b_thr_tile(/*is_relative=*/false));
4823 }
4824
4825 // TODO: Specify loops over sub-tiles in the schedule, use unrolling.
4826 // Sub-tile indices.
4827 a_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "a_idx");
4828 b_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "b_idx");
4829
4830 // Sub-tile views.
4831 a_i_view_ = create_sub_tile_view(abc_kind_t::a, a_thr_view_,
4832 cfg_.a_sub_tiles, a_idx_, bmnk_kind_t::m, &a_i_outer_blocks_,
4833 a_i_tile_);
4834 b_j_view_ = create_sub_tile_view(abc_kind_t::b, b_thr_view_,
4835 cfg_.b_sub_tiles, b_idx_, bmnk_kind_t::n, &b_j_outer_blocks_,
4836 b_j_tile_);
4837
4838 build();
4839 }
4840
4841 const std::vector<stmt_t> &allocs() const { return allocs_; }
4842
4843 const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
4844
4845 const expr_t &c_buf() const { return c_buf_; }
4846
4847 const layout_t &c_reg_layout() const { return c_reg_layout_; }
4848
4849 const alloc_attr_t &c_attr() const { return c_attr_; }
4850
4851 private:
4852 struct sub_tile_info_t {
4853 bool is_loaded = false;
4854 view_t reg_view;
4855 int reg_buf_size;
4856 };
4857
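// Splits the thread view into `sub_tiles` pieces along the given BMNK
// dimension (M for A, N for B). The returned view is parameterized by the
// sub-tile index variable `idx`; blocks that fall outside the sub-tile are
// reported through `outer_blocks` so the caller can add them back to the
// C layout.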
4858 view_t create_sub_tile_view(abc_kind_t abc_kind, const view_t &thr_view,
4859 int sub_tiles, const expr_t &idx, bmnk_kind_t bmnk_kind,
4860 std::vector<block_t> *outer_blocks, tensor_t &sub_tile) const {
4861 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4862 auto layout = thr_view.create_pseudo_vlayout();
4863 dim_t mn_dim = 1;
4864 for (auto &b : layout.blocks()) {
4865 auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
4866 if (b_bmnk_kind == bmnk_kind) mn_dim *= b.block;
4867 }
4868
4869 std::vector<dim_t> sub_tile_dims(thr_view.nvdims(), 1);
4870 dim_t mn_sub_tile_dim = ir_utils::safe_divide(mn_dim, dim_t(sub_tiles));
4871 for (auto &b : layout.blocks()) {
4872 auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
4873 if (b_bmnk_kind == bmnk_kind) {
4874 if (mn_sub_tile_dim == 1) continue;
4875 dim_t next_block;
4876 if (mn_sub_tile_dim % b.block == 0) {
4877 next_block = b.block;
4878 } else {
4879 next_block
4880 = ir_utils::safe_divide(b.block, mn_sub_tile_dim);
4881 }
4882 sub_tile_dims[b.dim_idx] *= next_block;
4883 mn_sub_tile_dim /= next_block;
4884 } else {
4885 sub_tile_dims[b.dim_idx] *= b.block;
4886 }
4887 }
4888 grid_info_t grid({sub_tiles}, {idx});
4889 sub_tile = layout.split(tensor_t(sub_tile_dims), grid, outer_blocks);
4890 return thr_view.create_sub_view(sub_tile);
4891 }
4892
4893 const type_t &a_type() const { return a_i_view_.type(); }
4894 const type_t &b_type() const { return b_j_view_.type(); }
4895
4896 void build() {
4897 a_sub_tiles_.resize(cfg_.a_sub_tiles);
4898 b_sub_tiles_.resize(cfg_.b_sub_tiles);
4899 for (int i = 0; i < cfg_.a_sub_tiles; i++) {
4900 for (int j = 0; j < cfg_.b_sub_tiles; j++) {
4901 build_sub_tile(i, j);
4902 }
4903 }
4904
4905 if (tmp_buf_size_ > 0) {
4906 register_buffer(ab_tmp_buf_, tmp_buf_size_, alloc_kind_t::grf);
4907 }
4908
4909 // C layout in problem notation.
4910 auto c_layout = c_sub_tile_layout_;
4911
4912 // Add outer blocks coming from A/B sub-tiles.
4913 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4914 for (auto &b : a_i_outer_blocks_) {
4915 auto &var = bmnk_mapper.var(abc_kind_t::a, b.dim_idx);
4916 int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
4917 c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
4918 }
4919 for (auto &b : b_j_outer_blocks_) {
4920 auto &var = bmnk_mapper.var(abc_kind_t::b, b.dim_idx);
4921 int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
4922 c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
4923 }
4924
4925 c_reg_layout_ = c_layout;
4926 }
4927
4928 void build_sub_tile(int i, int j) {
4929 bool is_first = (i == 0 && j == 0);
4930
4931 stmt_t ab_s2r_load;
4932 stmt_t ab_g2r_load;
4933 load_sub_tile(abc_kind_t::a, i, ab_s2r_load, ab_g2r_load);
4934 load_sub_tile(abc_kind_t::b, j, ab_s2r_load, ab_g2r_load);
4935
4936 load_mul_stmt_ = load_mul_stmt_.append(
4937 stmt_group_t::make(stmt_label_t::g2r_load(i + j), ab_g2r_load));
4938 load_mul_stmt_ = load_mul_stmt_.append(
4939 stmt_group_t::make(stmt_label_t::s2r_load(i + j), ab_s2r_load));
4940
4941 auto &a_i_view = a_sub_tiles_[i].reg_view;
4942 auto &b_j_view = b_sub_tiles_[j].reg_view;
4943
4944 // Multiply C_i_j += A_i x B_j in GEMM notation.
4945 multiply_builder_t mul_builder(cfg_, gemm_schedule_.bmnk_mapper(),
4946 a_i_view, b_j_view, a_buf_, b_buf_, c_buf_[c_buf_off_]);
4947 c_sub_tile_layout_ = mul_builder.c_layout();
4948 c_buf_off_ += c_sub_tile_layout_.size();
4949 ir_trace() << "Multiply (" << i << ", " << j << "):\n"
4950 << mul_builder.str() << std::endl;
4951
4952 load_mul_stmt_ = load_mul_stmt_.append(stmt_group_t::make(
4953 stmt_label_t::mul(i + j), mul_builder.stmt()));
4954
4955 if (!is_first) {
4956 ir_assert(mul_builder.c_layout() == c_sub_tile_layout_)
4957 << "Sub-tile layouts must be equal.";
4958 return;
4959 }
4960
4961 c_attr_ = grf_alloc_attr_t::make(mul_builder.c_grf_bundle());
4962
4963 auto a_attr = grf_alloc_attr_t::make(mul_builder.a_grf_bundle());
4964 register_buffer(a_buf_, a_sub_tiles_[i].reg_buf_size, alloc_kind_t::grf,
4965 a_attr);
4966
4967 auto b_attr = grf_alloc_attr_t::make(mul_builder.b_grf_bundle());
4968 register_buffer(b_buf_, b_sub_tiles_[j].reg_buf_size, alloc_kind_t::grf,
4969 b_attr);
4970 }
4971
4972 // Loads A_i or B_j sub-tile.
4973 void load_sub_tile(abc_kind_t abc_kind, int i, stmt_t &ab_s2r_load,
4974 stmt_t &ab_g2r_load) {
4975 bool is_a = (abc_kind == abc_kind_t::a);
4976 auto &info = (is_a ? a_sub_tiles_[i] : b_sub_tiles_[i]);
4977 if (info.is_loaded) return;
4978
4979 auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4980
4981 auto &x_view = (is_a ? a_i_view_ : b_j_view_);
4982 auto &x_tile = (is_a ? a_i_tile_ : b_j_tile_);
4983 auto &x_idx = (is_a ? a_idx_ : b_idx_);
4984
4985 auto view = x_view.substitute(x_idx, i);
4986 auto tile = x_tile.substitute(x_idx, i);
4987
4988 bool use_x_slm = (is_a ? cfg_.use_a_slm : cfg_.use_b_slm);
4989 auto &x_slm_buf = (is_a ? a_slm_buf_ : b_slm_buf_);
4990 auto &x_gmem_buf = (is_a ? ap_buf_ : bp_buf_);
4991 auto &x_buf = (use_x_slm ? x_slm_buf : x_gmem_buf);
4992 auto &x_reg_buf = (is_a ? a_buf_ : b_buf_);
4993
4994 layout_t load_layout;
4995 view_t reg_view;
4996 stmt_t stmt;
4997 load_sub_tile_impl(abc_kind, i, view, x_buf, x_reg_buf, use_x_slm,
4998 load_layout, reg_view, stmt);
4999
5000 auto reg_layout = load_layout;
5001
5002 if (!is_a && cfg_.do_b_reduction && !cfg_.use_b_slm) {
5003 auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt(
5004 reg_layout, b_buf_, tile);
5005 stmt = stmt.append(reduce_stmt);
5006 }
5007
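// The loaded layout is not necessarily consumable by the chosen FMA
// instruction. If a conversion is required, load into a temporary buffer
// and reorder into the FMA-friendly layout.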
5008 bool changed;
5009 auto fma_layout = convert_to_fma_friendly_layout(cfg_, abc_kind,
5010 bmnk_mapper, reg_layout, a_type(), b_type(), &changed);
5011
5012 if (changed) {
5013 if (fma_layout.type() != reg_layout.type()) {
5014 reg_view = reg_view.retype(fma_layout.type());
5015 }
5016 reg_layout = fma_layout;
5017 reg_view.set_tlayout(reg_layout);
5018 stmt = substitute(stmt, x_reg_buf, ab_tmp_buf_);
5019 stmt = stmt.append(create_reorder_stmt(
5020 load_layout, reg_layout, ab_tmp_buf_, x_reg_buf));
5021 tmp_buf_size_ = std::max(tmp_buf_size_, int(load_layout.size()));
5022 }
5023
5024 if (use_x_slm) {
5025 ab_s2r_load = ab_s2r_load.append(stmt);
5026 } else {
5027 ab_g2r_load = ab_g2r_load.append(stmt);
5028 }
5029 info.is_loaded = true;
5030 info.reg_view = reg_view;
5031 info.reg_buf_size = reg_layout.size();
5032 }
5033
5034 void load_sub_tile_impl(abc_kind_t abc_kind, int sub_tile_idx,
5035 const view_t &_mem_view, const expr_t &buf, const expr_t ®_buf,
5036 bool is_slm, layout_t ®_layout, view_t ®_view, stmt_t &stmt) {
5037 bool is_a = (abc_kind == abc_kind_t::a);
5038
5039 view_t mem_view;
5040 bool load_buffered = false;
5041
5042 // Using a buffered view is enabled only when:
5043 // - Loading directly from global memory
5044 // - FMA kind is mad (the dpas implementation is more strict and requires
5045 // layouts, not views)
5046 // - Loading the A tensor (A holds activations for FWD/BWD_D, where
5047 // neighboring KW blocks may overlap when applying KW blocking)
5048 if (!is_slm && is_a && cfg_.fma_kind == fma_kind_t::mad) {
5049 load_buffered
5050 = _mem_view.try_create_buffer_view(mem_view, reg_view);
5051 }
5052
5053 if (!load_buffered) mem_view = _mem_view;
5054
5055 read_builder_t read(
5056 cfg_.hw, ir_ctx_, cset_, mem_view, buf, reg_buf, is_slm);
5057 ir_trace() << (is_a ? "A" : "B") << " GMEM/SLM to GRF load #"
5058 << sub_tile_idx << ":\n"
5059 << read.str() << std::endl;
5060
5061 if (load_buffered) {
5062 reg_view.set_tlayout(read.reg_layout());
5063 } else {
5064 reg_view = view_t(read.reg_layout());
5065 }
5066
5067 reg_layout = read.reg_layout();
5068 stmt = read.stmt();
5069 }
5070
5071 void register_buffer(const stmt_t &alloc) {
5072 ir_assert(alloc.is<alloc_t>());
5073 allocs_.push_back(alloc);
5074 }
5075
5076 void register_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5077 const alloc_attr_t &attr = {}) {
5078 register_buffer(alloc_t::make(buf, size, kind, attr));
5079 }
5080
5081 const conv_config_t &cfg_;
5082 ir_context_t ir_ctx_;
5083 const constraint_set_t &cset_;
5084 const gemm_schedule_t &gemm_schedule_;
5085 b_reduce_context_t &b_reduce_ctx_;
5086
5087 expr_t ap_buf_;
5088 expr_t a_slm_buf_;
5089
5090 expr_t bp_buf_;
5091 expr_t b_slm_buf_;
5092
5093 layout_t c_reg_layout_;
5094
5095 expr_t ab_tmp_buf_;
5096 expr_t a_buf_;
5097 expr_t b_buf_;
5098 expr_t c_buf_;
5099
5100 int tmp_buf_size_ = 0;
5101
5102 // Per-thread views to multiply.
5103 view_t a_thr_view_;
5104 view_t b_thr_view_;
5105
5106 // Sub-tile indices.
5107 expr_t a_idx_;
5108 expr_t b_idx_;
5109
5110 // Sub-tile views.
5111 view_t a_i_view_;
5112 view_t b_j_view_;
5113
5114 tensor_t a_i_tile_;
5115 tensor_t b_j_tile_;
5116
5117 std::vector<sub_tile_info_t> a_sub_tiles_;
5118 std::vector<sub_tile_info_t> b_sub_tiles_;
5119
5120 std::vector<block_t> a_i_outer_blocks_;
5121 std::vector<block_t> b_j_outer_blocks_;
5122
5123 std::vector<stmt_t> allocs_;
5124
5125 stmt_t load_mul_stmt_;
5126
5127 int c_buf_off_ = 0;
5128 layout_t c_sub_tile_layout_;
5129 alloc_attr_t c_attr_;
5130 };
5131
5132 class slm_reduce_builder_t {
5133 public:
5134 slm_reduce_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
5135 const constraint_set_t &cset, const grid_info_t &tg_grid,
5136 const expr_t ®_buf, const layout_t ®_layout,
5137 const tensor_t &thr_tile)
5138 : cfg_(cfg)
5139 , ir_ctx_(ir_ctx)
5140 , cset_(cset)
5141 , tg_grid_(tg_grid)
5142 , reg_buf_(reg_buf)
5143 , reg_layout_(reg_layout)
5144 , thr_tile_(thr_tile) {
5145 ir_assert(tg_grid_.dim(2) > 1);
5146
5147 tmp_reg_buf_ = ir_ctx_.create_tmp_var(type_t::byte_ptr());
5148 slm_buf_ = make_buffer("reduce_slm");
5149
5150 build();
5151 }
5152
5153 layout_t reg_layout() const { return reg_layout_; }
5154
5155 tensor_t thr_tile() const { return thr_tile_; }
5156
5157 stmt_t stmt() const { return stmt_; }
5158
5159 private:
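// Reduction across the thread group via SLM: each thread writes its GRF
// tile to a per-thread slot in SLM (indexed by the thread-group
// coordinates), then, after a barrier, reads back all slices along the
// third (k) grid dimension and accumulates them into its now-smaller
// local tile.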
5160 void build() {
5161 int ndims = reg_layout_.ndims();
5162 int tg_ndims = tg_grid_.ndims();
5163
5164 // Create SLM layout to store all intermediate buffers from the thread
5165 // group.
5166 layout_t slm_layout(reg_layout_.type(), ndims + tg_ndims,
5167 reg_layout_.offset(), reg_layout_.blocks());
5168 for (int i = tg_ndims - 1; i >= 0; i--) {
5169 slm_layout = slm_layout.add_outer_block(ndims + i, tg_grid_.dim(i));
5170 }
5171
5172 slm_buf_size_ = slm_layout.size();
5173
5174 // Write thread tile to SLM.
5175 std::vector<dim_t> write_dims = reg_layout_.dims();
5176 std::vector<expr_t> write_start(ndims + tg_ndims, 0);
5177 write_dims.resize(ndims + tg_ndims, 1);
5178 for (int i = tg_ndims - 1; i >= 0; i--) {
5179 write_start[ndims + i] = tg_grid_.idx(i);
5180 }
5181 auto write_tile = tensor_t(write_dims, write_start);
5182 write_builder_t write(cfg_.hw, ir_ctx_, cset_,
5183 view_t(slm_layout.map(write_tile)), slm_buf_, reg_buf_,
5184 /*is_slm=*/true);
5185 stmt_ = stmt_.append(funcs::barrier());
5186 stmt_ = stmt_.append(write.stmt());
5187 stmt_ = stmt_.append(funcs::barrier());
5188
5189 auto &write_layout = write.reg_layout();
5190 ir_assert(write_layout == reg_layout_) << "Incompatible layouts.";
5191
5192 // Redistribute the layout to read/reduce all k-axis tiles from every
5193 // thread.
5194 auto local_thr_tile = reg_layout_.split(tg_grid_.sub_grid({2}));
5195 reg_layout_ = reg_layout_.map(tensor_t(local_thr_tile.dims()));
5196
5197 stmt_t reduce_stmt;
5198 std::vector<dim_t> read_dims(ndims + tg_ndims, 1);
5199 std::vector<expr_t> read_start(ndims + tg_ndims);
5200 for (int i = 0; i < ndims; i++) {
5201 read_dims[i] = local_thr_tile(i);
5202 read_start[i] = local_thr_tile.start(i);
5203 }
5204 read_start[ndims + 0] = tg_grid_.idx(0);
5205 read_start[ndims + 1] = tg_grid_.idx(1);
5206 for (int i = 0; i < tg_grid_.dim(2); i++) {
5207 read_start[ndims + 2] = i;
5208 tensor_t read_tile(read_dims, read_start);
5209 read_builder_t read(cfg_.hw, ir_ctx_, cset_,
5210 view_t(slm_layout.map(read_tile)), slm_buf_, tmp_reg_buf_,
5211 /*is_slm=*/true);
5212 reduce_stmt = reduce_stmt.append(read.stmt());
5213
5214 tmp_reg_buf_size_
5215 = std::max(tmp_reg_buf_size_, read.reg_buf_size());
5216 auto read_layout = read.reg_layout();
5217 for (int j = 0; j < 2; j++)
5218 ir_assert(read_layout.dim(ndims + j) == 1);
5219 read_layout = layout_t(read_layout.type(), ndims + 1,
5220 read_layout.offset(), read_layout.blocks());
5221 reduce_stmt = reduce_stmt.append(
5222 create_reduce_stmt(read_layout, reg_layout_, tmp_reg_buf_,
5223 reg_buf_, tensor_t(), reduction_mask()));
5224 }
5225
5226 stmt_ = stmt_.append(
5227 create_zero_out_stmt(cfg_.hw, reg_buf_, reg_layout_.size()));
5228 stmt_ = stmt_.append(reduce_stmt);
5229
5230 stmt_ = alloc_t::make(
5231 slm_buf_, slm_buf_size_, alloc_kind_t::slm, {}, stmt_);
5232 stmt_ = alloc_t::make(
5233 tmp_reg_buf_, tmp_reg_buf_size_, alloc_kind_t::grf, {}, stmt_);
5234
5235 thr_tile_ = thr_tile_.create_sub_tensor(local_thr_tile);
5236 }
5237
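// All dimensions are preserved except the appended k dimension
// (at index ndims), which is reduced over.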
5238 uint32_t reduction_mask() const {
5239 int k_dim_idx = reg_layout_.ndims();
5240 uint32_t mask = 0xFFFFFFFF;
5241 mask &= ~(1 << k_dim_idx);
5242 return mask;
5243 }
5244
5245 const conv_config_t &cfg_;
5246 ir_context_t &ir_ctx_;
5247 const constraint_set_t &cset_;
5248 grid_info_t tg_grid_;
5249
5250 expr_t reg_buf_;
5251 layout_t reg_layout_;
5252 tensor_t thr_tile_;
5253
5254 expr_t tmp_reg_buf_;
5255 int tmp_reg_buf_size_ = 0;
5256
5257 expr_t slm_buf_;
5258 int slm_buf_size_ = 0;
5259
5260 stmt_t stmt_;
5261 };
5262
5263 class compute_builder_t {
5264 public:
5265 compute_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
5266 constraint_set_t &cset)
5267 : cfg_(cfg)
5268 , ir_ctx_(ir_ctx)
5269 , cset_(cset)
5270 , b_reduce_ctx_(cfg)
5271 , g2s_ctx_(ir_ctx) {}
5272
5273 int ab_slm_size() const { return ab_slm_size_; }
5274
5275 const stmt_t &c_zero_out_stmt() const { return c_zero_out_stmt_; }
5276 const stmt_t &b_reduced_zero_out_stmt() const {
5277 return b_reduced_zero_out_stmt_;
5278 }
5279
5280 stmt_t zero_out_stmt() const {
5281 stmt_t ret;
5282 ret = ret.append(c_zero_out_stmt());
5283 ret = ret.append(b_reduced_zero_out_stmt());
5284 return ret;
5285 }
5286
5287 stmt_t iter_stmt() const {
5288 stmt_t stmt;
5289 bool use_prefetch = !prefetch_stmt_.is_empty();
5290 bool use_slm = !g2s_load_stmt_.is_empty();
5291 if (use_prefetch) {
5292 stmt = stmt.append(stmt_group_t::make(
5293 stmt_label_t::prefetch(), prefetch_stmt_));
5294 } else if (use_slm) {
5295 stmt = stmt.append(stmt_group_t::make(
5296 stmt_label_t::g2s_load(), g2s_load_stmt_));
5297 stmt = stmt.append(funcs::barrier());
5298 stmt = stmt.append(stmt_group_t::make(
5299 stmt_label_t::g2s_store(), g2s_store_stmt_));
5300 stmt = stmt.append(funcs::barrier());
5301 }
5302 stmt = stmt.append(load_mul_stmt_);
5303 return stmt;
5304 }
5305
5306 const stmt_t &c_store_stmt() const { return c_store_stmt_; }
5307 const stmt_t &b_reduced_store_stmt() const { return b_reduced_store_stmt_; }
5308
5309 stmt_t inject_compute_alloc_stmts(const stmt_t &stmt) const {
5310 return jit::inject_alloc_stmts(stmt, compute_allocs_);
5311 }
5312
5313 stmt_t inject_out_alloc_stmts(const stmt_t &stmt) const {
5314 return jit::inject_alloc_stmts(stmt, out_allocs_);
5315 }
5316
5317 stmt_t inject_let_stmts(const stmt_t &stmt) const {
5318 return jit::inject_let_stmts(stmt, g2s_ctx_.grid_idx_lets);
5319 }
5320
5321 void set_gemm_schedule(const gemm_schedule_t &gemm_schedule) {
5322 gemm_schedule_ = gemm_schedule;
5323 }
5324
5325 // Setters for original AP/BP/CP buffers (P - problem notation).
5326 void set_ap_buf(const expr_t &buf) { ap_buf_ = buf; }
5327 void set_bp_buf(const expr_t &buf) { bp_buf_ = buf; }
5328 void set_cp_buf(const expr_t &buf) { cp_buf_ = buf; }
5329 void set_b_reduced_mem_buf(const expr_t &buf) {
5330 b_reduce_ctx_.set_b_reduced_mem_buf(buf);
5331 }
5332
5333 void set_b_reduced_view(const view_t &v) {
5334 b_reduce_ctx_.set_b_reduced_view(v);
5335 }
5336
5337 void set_post_op_context(const post_op_context_t &post_op_ctx) {
5338 post_op_ctx_ = post_op_ctx;
5339 }
5340
5341 void set_reduce_condition(const expr_t &cond) {
5342 b_reduce_ctx_.set_reduce_condition(cond);
5343 }
5344
5345 void build() {
5346 // Initialize SLM buffers.
5347 expr_t a_slm_buf = make_buffer("a_slm");
5348 expr_t b_slm_buf = make_buffer("b_slm");
5349
5350 view_t ap_gmem_view = gemm_schedule_.a_tg_view();
5351 view_t bp_gmem_view = gemm_schedule_.b_tg_view();
5352
5353 // Views to multiply by a thread group (either GMEM or SLM).
5354 view_t ap_x_view;
5355 view_t bp_x_view;
5356 prepare_gmem_to_slm("A", cfg_.use_a_slm, gemm_schedule_.a_tg_tile(),
5357 ap_gmem_view, ap_buf_, a_slm_buf, ap_x_view, g2s_ctx_);
5358 prepare_gmem_to_slm("B", cfg_.use_b_slm, gemm_schedule_.b_tg_tile(),
5359 bp_gmem_view, bp_buf_, b_slm_buf, bp_x_view, g2s_ctx_);
5360 prepare_prefetch("A", cfg_.use_prefetch, ap_gmem_view, ap_buf_);
5361 prepare_prefetch("B", cfg_.use_prefetch, bp_gmem_view, bp_buf_);
5362
5363 if (ap_x_view.is_empty()) ap_x_view = ap_gmem_view;
5364 if (bp_x_view.is_empty()) bp_x_view = bp_gmem_view;
5365
5366 for (auto &bi : g2s_ctx_.bufs) {
5367 register_compute_buffer(bi.buf, bi.size, alloc_kind_t::grf);
5368 }
5369
5370 load_multiply_builder_t load_mul_builder(cfg_, ir_ctx_, cset_,
5371 gemm_schedule_, b_reduce_ctx_, ap_buf_, a_slm_buf, bp_buf_,
5372 b_slm_buf, ap_x_view, bp_x_view);
5373
5374 load_mul_stmt_ = load_mul_builder.load_mul_stmt();
5375 compute_allocs_.insert(compute_allocs_.end(),
5376 load_mul_builder.allocs().begin(),
5377 load_mul_builder.allocs().end());
5378
5379 auto c_buf = load_mul_builder.c_buf();
5380 auto c_attr = load_mul_builder.c_attr();
5381 int c_size = load_mul_builder.c_reg_layout().size();
5382 register_out_buffer(c_buf, c_size, alloc_kind_t::grf, c_attr);
5383
5384 auto c_thr_reg_layout = load_mul_builder.c_reg_layout();
5385 auto thr_tile = gemm_schedule_.c_thr_tile(/*is_relative=*/false);
5386
5387 if (gemm_schedule_.with_thread_group_k_slicing()) {
5388 slm_reduce_builder_t slm_reduce_builder(cfg_, ir_ctx_, cset_,
5389 gemm_schedule_.tg_grid(), c_buf, c_thr_reg_layout,
5390 thr_tile);
5391 c_store_stmt_ = c_store_stmt_.append(slm_reduce_builder.stmt());
5392 c_thr_reg_layout = slm_reduce_builder.reg_layout();
5393 thr_tile = slm_reduce_builder.thr_tile();
5394 }
5395
5396 auto c_thr_mem_view = gemm_schedule_.c_view().create_sub_view(thr_tile);
5397 epilogue_builder_t c_m2g(cfg_, ir_ctx_, cset_, post_op_ctx_, thr_tile,
5398 c_thr_mem_view, c_thr_reg_layout, cp_buf_, c_buf);
5399 ir_trace() << "C GRF to GMEM store:\n" << c_m2g.stmt() << std::endl;
5400
5401 c_zero_out_stmt_ = stmt_group_t::make(stmt_label_t::c_zero_out(),
5402 create_zero_out_stmt(cfg_.hw, c_buf, c_size));
5403 c_store_stmt_ = c_store_stmt_.append(c_m2g.stmt());
5404
5405 if (cfg_.do_b_reduction) {
5406 auto &ctx = b_reduce_ctx_;
5407 b_reduced_zero_out_stmt_ = create_zero_out_stmt(
5408 cfg_.hw, ctx.b_reduced_reg_buf(), ctx.b_reduced_size());
5409 b_reduced_store_stmt_ = ctx.create_store_stmt(ir_ctx_, cset_);
5410 register_out_buffer(ctx.b_reduced_reg_buf(), ctx.b_reduced_size(),
5411 alloc_kind_t::grf);
5412 }
5413
5414 // Replace DPAS by DPASW when applicable.
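// (dpasw splits the src2 operand across the two fused EUs of a pair,
// which should roughly halve its GRF footprint and read bandwidth per EU;
// this rationale is an assumption based on hardware documentation, not on
// anything stated in this file.)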
5415 if (cfg_.fma_kind == fma_kind_t::dpasw) {
5416 alloc_updater_t alloc_updater;
5417 inject_dpasw(cfg_.hw, load_mul_stmt_, c_buf, c_store_stmt_,
5418 alloc_updater, gemm_schedule_.tg_grid().idx(0));
5419 for (auto &a : compute_allocs_) {
5420 a = alloc_updater.update(a);
5421 }
5422 for (auto &a : out_allocs_) {
5423 a = alloc_updater.update(a);
5424 }
5425 }
5426
5427 // Assign {Atomic} for DPAS(W) when applicable.
5428 load_mul_stmt_ = inject_atomic(load_mul_stmt_);
5429 }
5430
5431 private:
5432 struct buf_info_t {
5433 buf_info_t(const std::string &tag, const expr_t &buf)
5434 : tag(tag), buf(buf) {}
5435
5436 std::string tag;
5437 expr_t buf;
5438 int size = 0;
5439 };
5440
5441 struct g2s_context_t {
5442 g2s_context_t(ir_context_t &ir_ctx) : ir_ctx(ir_ctx) {}
5443
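// Returns a temporary GRF buffer for a GMEM -> SLM load. When buffer
// reuse is enabled (or forced), an existing buffer with the same tag is
// returned instead of creating a new one.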
5444 expr_t create_buf(const char *tag, bool force_reuse = false) {
5445 if (reuse_buffers || force_reuse) {
5446 for (auto &bi : bufs) {
5447 if (bi.tag == tag) return bi.buf;
5448 }
5449 }
5450 auto buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), tag);
5451 bufs.emplace_back(tag, buf);
5452 return buf;
5453 }
5454
5455 void set_buf_size(const expr_t &buf, int size) {
5456 for (auto &bi : bufs) {
5457 if (bi.buf.is_same(buf)) bi.size = std::max(bi.size, size);
5458 }
5459 }
5460
5461 expr_t create_tmp_grid_idx() {
5462 auto var = ir_ctx.create_tmp_var(type_t::s32(), "idx");
5463 tmp_grid_idxs.insert({var, expr_t()});
5464 return var;
5465 }
5466
5467 void set_grid_idx_value(const expr_t &idx, const expr_t &value) {
5468 auto &old = tmp_grid_idxs[idx];
5469 ir_assert(old.is_empty());
5470 old = substitute_grid_idx_value(value);
5471 }
5472
5473 expr_t substitute_grid_idx_value(const expr_t &_e) {
5474 auto e = _e;
5475 auto vars = find_unique_objects<var_t>(e);
5476 for (auto &v : vars) {
5477 auto it = tmp_grid_idxs.find(v);
5478 if (it == tmp_grid_idxs.end()) continue;
5479 e = substitute(e, v, it->second);
5480 }
5481 return e;
5482 }
5483
5484 void register_grid(const grid_info_t &grid) {
5485 for (int i = 0; i < grid.ndims(); i++) {
5486 auto &idx = grid.idx(i);
5487 auto it = tmp_grid_idxs.find(idx);
5488 if (it == tmp_grid_idxs.end()) continue;
5489 grid_idx_lets.emplace_back(let_t::make(idx, it->second));
5490 }
5491 }
5492
5493 ir_context_t &ir_ctx;
5494 grid_info_t prev_load_grid;
5495 bool reuse_buffers = false;
5496 std::vector<buf_info_t> bufs;
5497
5498 object_map_t<expr_t, expr_t> tmp_grid_idxs;
5499 std::vector<stmt_t> grid_idx_lets;
5500 };
5501
5502 void register_compute_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5503 const alloc_attr_t &attr = {}) {
5504 compute_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
5505 }
5506
5507 void register_out_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5508 const alloc_attr_t &attr = {}) {
5509 out_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
5510 }
5511
5512 // Handles GMEM to SLM load for A and B. Done in two steps:
5513 // 1. Load: GMEM -> GRF (temporary)
5514 // 2. Store: GRF (temporary) -> SLM
5515 void prepare_gmem_to_slm(const char *tag, bool use_x_slm,
5516 const tensor_t &tg_tile, const view_t &x_gmem_view,
5517 const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
5518 g2s_context_t &g2s_ctx) {
5519 if (!use_x_slm) return;
5520
5521 grid_info_t load_grid = gemm_schedule_.tg_grid();
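// Try to split the GMEM -> SLM load across the full thread-group grid
// first; if a thread's SLM region is not dense, halve the grid along one
// dimension and retry, masking the idle threads with the grid slice
// condition.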
5522 for (;;) {
5523 bool ok = prepare_gmem_to_slm_impl(tag, use_x_slm, tg_tile,
5524 x_gmem_view, xp_buf, x_slm_buf, x_slm_view, load_grid,
5525 g2s_ctx);
5526 if (ok) {
5527 g2s_ctx.prev_load_grid = load_grid;
5528 g2s_ctx.register_grid(load_grid);
5529 return;
5530 }
5531
5532 // Reduce grid and try again.
5533 auto grid_idx = g2s_ctx.create_tmp_grid_idx();
5534 int dim_idx;
5535 expr_t grid_idx_value;
5536 auto new_load_grid
5537 = load_grid.halven(grid_idx, dim_idx, grid_idx_value);
5538 if (new_load_grid.is_empty()) break;
5539
5540 if (new_load_grid == g2s_ctx.prev_load_grid) {
5541 new_load_grid = load_grid.halven(
5542 grid_idx, dim_idx, grid_idx_value, /*first=*/false);
5543 g2s_ctx.reuse_buffers = true;
5544 }
5545 g2s_ctx.set_grid_idx_value(grid_idx, grid_idx_value);
5546
5547 cset_.add_constraint(grid_idx >= 0);
5548 cset_.add_constraint(grid_idx < new_load_grid.dim(dim_idx));
5549
5550 load_grid = new_load_grid;
5551 }
5552 ir_error_not_expected() << "Can't create GMEM -> SLM loads/stores.";
5553 }
5554
5555 bool prepare_gmem_to_slm_impl(const char *tag, bool use_x_slm,
5556 const tensor_t &tg_tile, const view_t &x_gmem_view,
5557 const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
5558 const grid_info_t &load_grid, g2s_context_t &g2s_ctx) {
5559 bool is_a = (tag[0] == 'A');
5560
5561 auto xp_slm_layout = create_slm_layout(
5562 x_gmem_view, is_a ? abc_kind_t::a : abc_kind_t::b, load_grid);
5563
5564 auto grid_cond = load_grid.slice_condition();
5565
5566 tensor_t thr_tile;
5567 // Per-thread view to load from GMEM to SLM.
5568 auto x_g2s_view = x_gmem_view.split(load_grid, thr_tile);
5569 auto slm_thr_layout = xp_slm_layout.map(thr_tile);
5570
5571 // Ensure that each thread writes a dense region to SLM. If the layout
5572 // is not dense, return and retry with a smaller grid.
5573 if (!slm_thr_layout.is_dense()) return false;
5574
5575 register_compute_buffer(
5576 x_slm_buf, xp_slm_layout.size(), alloc_kind_t::slm);
5577 ab_slm_size_ += xp_slm_layout.size();
5578
5579 // Temporary GRF buffer.
5580 expr_t x_g2s_reg_buf = g2s_ctx.create_buf("g2s");
5581
5582 // GMEM -> GRF load.
5583 read_builder_t x_read(cfg_.hw, ir_ctx_, cset_, x_g2s_view, xp_buf,
5584 x_g2s_reg_buf, /*is_slm=*/false);
5585 ir_trace() << tag << " GMEM to GRF load:\n"
5586 << x_read.str() << std::endl;
5587
5588 g2s_ctx.set_buf_size(x_g2s_reg_buf, x_read.reg_buf_size());
5589
5590 auto load_stmt = x_read.stmt();
5591 if (!grid_cond.is_empty()) load_stmt = if_t::make(grid_cond, load_stmt);
5592 g2s_load_stmt_ = g2s_load_stmt_.append(load_stmt);
5593
5594 // GRF -> SLM store.
5595 write_builder_t x_write(cfg_.hw, ir_ctx_, cset_, view_t(slm_thr_layout),
5596 x_slm_buf, x_g2s_reg_buf, /*is_slm=*/true);
5597 ir_trace() << tag << " GRF to SLM store:\n"
5598 << x_write.str() << std::endl;
5599 auto store_stmt = x_write.stmt();
5600
5601 auto &read_layout = x_read.reg_layout();
5602 auto &write_layout = x_write.reg_layout();
5603 if (read_layout != write_layout) {
5604 if (cfg_.allow_grf_reorder) {
5605 // Temporary GRF buffer.
5606 expr_t tmp_buf
5607 = g2s_ctx.create_buf("g2s_tmp", /*force_reuse=*/true);
5608 auto reorder_stmt = create_reorder_stmt(
5609 read_layout, write_layout, x_g2s_reg_buf, tmp_buf);
5610 g2s_ctx.set_buf_size(tmp_buf, x_write.reg_buf_size());
5611 store_stmt = substitute(store_stmt, x_g2s_reg_buf, tmp_buf);
5612 store_stmt = reorder_stmt.append(store_stmt);
5613 } else {
5614 ir_error_not_expected() << "Requested register layouts for "
5615 << tag << " do not match: "
5616 << "read: " << read_layout
5617 << ", write: " << write_layout;
5618 }
5619 }
5620 // Generate reduction statement for B.
5621 if (!is_a && cfg_.do_b_reduction) {
5622 auto absolute_thr_tile = tg_tile.create_sub_tensor(thr_tile);
5623 b_reduce_ctx_.init_reduced_thr_view(absolute_thr_tile, grid_cond);
5624 auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt(
5625 read_layout, x_g2s_reg_buf);
5626 store_stmt = reduce_stmt.append(store_stmt);
5627 }
5628 if (!grid_cond.is_empty())
5629 store_stmt = if_t::make(grid_cond, store_stmt);
5630 g2s_store_stmt_ = g2s_store_stmt_.append(store_stmt);
5631
5632 x_slm_view = view_t(xp_slm_layout);
5633
5634 return true;
5635 }
5636
5637 void prepare_prefetch(const char *tag, bool use_prefetch,
5638 const view_t &x_gmem_view, const expr_t &xp_buf) {
5639 if (!use_prefetch) return;
5640
5641 // Per-thread view to prefetch from GMEM.
5642 auto thr_view = x_gmem_view.split(gemm_schedule_.tg_grid());
5643
5644 // GMEM prefetch.
5645 read_builder_t x_prefetch(cfg_.hw, ir_ctx_, cset_, thr_view, xp_buf,
5646 expr_t(), /*is_slm=*/false, /*is_prefetch=*/true);
5647 ir_trace() << tag << " GMEM prefetch:\n"
5648 << x_prefetch.str() << std::endl;
5649
5650 prefetch_stmt_ = prefetch_stmt_.append(x_prefetch.stmt());
5651 }
5652
5653 layout_t create_slm_layout(const view_t &tg_view, abc_kind_t abc_kind,
5654 const grid_info_t &load_grid) const {
5655 auto layout = tg_view.create_dense_vlayout();
5656 auto &a_type = gemm_schedule_.a_view().type();
5657 auto &b_type = gemm_schedule_.b_view().type();
5658 auto ret = convert_to_fma_friendly_layout(cfg_, abc_kind,
5659 gemm_schedule_.bmnk_mapper(), layout, a_type, b_type);
5660 if (cfg_.pad_slm) ret = pad_slm_layout(ret, load_grid);
5661 return ret;
5662 }
5663
5664 // SLM has 65 dword-granularity banks (Xe_HP):
5665 // banks: [bank 0] [bank 1] [bank 2] ... [bank 0]
5666 // byte offsets: | 0 | 4 | 8 ... | 4 * 65
5667 // SLM reads don't have conflicts. During SLM writes each fused EU writes
5668 // 64 bytes (128 bytes per clock in total). If any bank repeats within
5669 // those 128 bytes, the write takes 2 clocks to complete.
5670 // Assume that every X-axis thread (across tg_dim[0]) writes the
5671 // corresponding outer block of the layout. The goal is to pick a stride
5672 // between outer blocks that avoids duplicated banks.
5673 layout_t pad_slm_layout(
5674 const layout_t &layout, const grid_info_t &load_grid) const {
5675 auto tg_dim0 = load_grid.dim(0);
5676 auto tg_dim1 = load_grid.dim(1);
5677 int type_size = layout.type().size();
5678
5679 ir_assert(layout.elems() % tg_dim0 == 0) << layout;
5680 dim_t inner_block = layout.elems() / tg_dim0;
5681
5682 ir_assert((inner_block * type_size) % tg_dim1 == 0) << layout;
5683 dim_t per_thr_bytes = (inner_block * type_size) / tg_dim1;
5684
5685 std::vector<dim_t> multi_blocks = {inner_block, tg_dim0};
5686 auto l = layout.split_into_multi_blocks(multi_blocks);
5687
5688 auto padded_blocks = l.blocks();
5689 dim_t stride = -1;
5690 dim_t remaining_elems = inner_block;
5691 bool past_inner_block = false;
5692 for (auto &b : padded_blocks) {
5693 if (past_inner_block) {
5694 if (stride == -1) {
5695 dim_t stride_bytes = find_min_stride_without_conflicts(
5696 per_thr_bytes, dim_t(b.stride) * type_size);
5697 ir_assert(stride_bytes % type_size == 0);
5698 stride = stride_bytes / type_size;
5699 }
5700 b.stride = stride;
5701 stride = b.stride * b.block;
5702 continue;
5703 }
5704 ir_assert(remaining_elems % b.block == 0);
5705 remaining_elems /= b.block;
5706 if (remaining_elems == 1) past_inner_block = true;
5707 }
5708 return layout_t(
5709 layout.type(), layout.ndims(), layout.offset(), padded_blocks);
5710 }
5711
5712 dim_t find_min_stride_without_conflicts(
5713 dim_t inner_bytes, dim_t dense_stride_bytes) const {
5714 int write_step = 64;
5715 int stride_step = 16;
5716 dim_t stride_beg = dense_stride_bytes;
5717 dim_t stride_end = 2 * dense_stride_bytes;
5718 const int slm_banks = 65;
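// Search candidate strides from the dense stride up to 2x, in 16-byte
// steps. A candidate stride s is accepted when a 64-byte write starting at
// off0 and the corresponding write at off0 + s hit pairwise distinct SLM
// banks (bank index = dword offset modulo 65).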
5719 for (dim_t s = stride_beg; s < stride_end; s += stride_step) {
5720 bool ok = true;
5721 for (dim_t off0 = 0; off0 < inner_bytes; off0 += write_step) {
5722 // Check banks for a single SLM write.
5723 bool found[slm_banks] = {false};
5724 for (dim_t off = off0; off < off0 + write_step;
5725 off += sizeof(uint32_t)) {
5726 int bank0 = (off / sizeof(uint32_t)) % slm_banks;
5727 int bank1 = ((off + s) / sizeof(uint32_t)) % slm_banks;
5728 if (found[bank0]) {
5729 ok = false;
5730 break;
5731 }
5732 found[bank0] = true;
5733 if (found[bank1]) {
5734 ok = false;
5735 break;
5736 }
5737 found[bank1] = true;
5738 }
5739 if (ok) return s;
5740 }
5741 }
5742
5743 ir_warning()
5744 << "Couldn't find stride without conflicts for SLM padding."
5745 << std::endl;
5746
5747 return dense_stride_bytes;
5748 }
5749
5750 const conv_config_t &cfg_;
5751 ir_context_t &ir_ctx_;
5752 constraint_set_t &cset_;
5753 post_op_context_t post_op_ctx_;
5754 b_reduce_context_t b_reduce_ctx_;
5755
5756 g2s_context_t g2s_ctx_;
5757
5758 gemm_schedule_t gemm_schedule_;
5759
5760 expr_t ap_buf_;
5761 expr_t bp_buf_;
5762 expr_t cp_buf_;
5763
5764 std::vector<stmt_t> compute_allocs_;
5765 std::vector<stmt_t> out_allocs_;
5766 int ab_slm_size_ = 0;
5767
5768 stmt_t g2s_load_stmt_;
5769 stmt_t g2s_store_stmt_;
5770 stmt_t prefetch_stmt_;
5771 stmt_t load_mul_stmt_;
5772
5773 stmt_t c_zero_out_stmt_;
5774 stmt_t c_store_stmt_;
5775
5776 stmt_t b_reduced_zero_out_stmt_;
5777 stmt_t b_reduced_store_stmt_;
5778 };
5779
5780 void kernel_builder_t::build() {
5781 ir_context_t ir_ctx;
5782 constraint_set_t init_cset;
5783
5784 int grid_ndims = 3;
5785 kernel_grid_ = grid_info_t(grid_ndims);
5786 tg_grid_ = grid_info_t(grid_ndims);
5787 for (int i = 0; i < grid_ndims; i++) {
5788 local_id_[i]
5789 = var_t::make(type_t::u16(), "local_id" + std::to_string(i));
5790 kernel_grid_.dim(i) = cfg_.kernel_grid_dim[i];
5791 kernel_grid_.idx(i)
5792 = var_t::make(type_t::s32(), "grid_idx" + std::to_string(i));
5793 tg_grid_.dim(i) = cfg_.tg_grid_dim[i];
5794 tg_grid_.idx(i)
5795 = var_t::make(type_t::s32(), "tg_idx" + std::to_string(i));
5796
5797 init_cset.add_constraint(kernel_grid_.idx(i) >= 0);
5798 init_cset.add_constraint(kernel_grid_.idx(i) < cfg_.kernel_grid_dim[i]);
5799 init_cset.add_constraint(tg_grid_.idx(i) >= 0);
5800 init_cset.add_constraint(tg_grid_.idx(i) < cfg_.tg_grid_dim[i]);
5801 }
5802
5803 gemm_schedule_t gemm_schedule(init_cset, kernel_grid_, tg_grid_);
5804
5805 std::vector<stmt_t> init_stmts;
5806 for (int i = 0; i < grid_ndims; i++) {
5807 auto value = local_id_[i];
5808 if (i == 0) value /= cfg_.simd_size;
5809 init_stmts.push_back(let_t::make(tg_grid_.idx(i), value));
5810 }
5811
5812 // Initialize memory buffers.
5813 std::vector<stmt_t> inner_lets;
5814
5815 view_t a_view;
5816 view_t b_view;
5817 view_t c_view;
5818 view_t bp_reduced_view;
5819
5820 expr_t ap_buf;
5821 expr_t bp_buf;
5822 expr_t cp_buf;
5823 expr_t b_reduced_mem_buf;
5824 expr_t b_reduction_condition;
5825
5826 if (cfg_.is_fwd) {
5827 init_fwd(gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
5828 } else if (cfg_.is_bwd_d) {
5829 init_bwd_d(
5830 gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
5831 } else if (cfg_.is_bwd_w) {
5832 init_bwd_w(gemm_schedule, a_view, b_view, c_view, bp_reduced_view,
5833 ap_buf, bp_buf, cp_buf, b_reduced_mem_buf,
5834 b_reduction_condition);
5835 } else {
5836 ir_error_not_expected();
5837 }
5838
5839 gemm_schedule.finalize();
5840
5841 post_op_context_t post_op_ctx(
5842 pd_, cfg_, gemm_schedule.c_view(), kernel_arg_info_);
5843 compute_builder_t cb(cfg_, ir_ctx, init_cset);
5844
5845 cb.set_gemm_schedule(gemm_schedule);
5846 cb.set_ap_buf(ap_buf);
5847 cb.set_bp_buf(bp_buf);
5848 cb.set_cp_buf(cp_buf);
5849 cb.set_b_reduced_mem_buf(b_reduced_mem_buf);
5850 cb.set_b_reduced_view(bp_reduced_view);
5851 cb.set_post_op_context(post_op_ctx);
5852 cb.set_reduce_condition(b_reduction_condition);
5853
5854 cb.build();
5855
5856 std::vector<stmt_t> allocs;
5857 for (int i = 0; i < kernel_arg_info_.nargs(); i++) {
5858 auto &var = kernel_arg_info_.arg_var(i);
5859 if (!var.type().is_ptr()) continue;
5860 allocs.push_back(alloc_t::make(var, 0, alloc_kind_t::global));
5861 }
5862
5863 // Create IR statements.
5864 stmt_t loop_stmt = cb.iter_stmt();
5865 loop_stmt = gemm_schedule.create_loop_nest(loop_stmt);
5866 loop_stmt = stmt_group_t::make(stmt_label_t::compute_loop(), loop_stmt);
5867 loop_stmt = cb.inject_compute_alloc_stmts(loop_stmt);
5868
5869 auto c_store_stmt
5870 = stmt_group_t::make(stmt_label_t::c_store(), cb.c_store_stmt());
5871 stmt_ = loop_stmt;
5872 stmt_ = stmt_seq_t::make(cb.zero_out_stmt(), stmt_);
5873 stmt_ = stmt_seq_t::make(stmt_, cb.b_reduced_store_stmt());
5874 stmt_ = stmt_seq_t::make(stmt_, c_store_stmt);
5875
5876 stmt_ = cb.inject_out_alloc_stmts(stmt_);
5877 stmt_ = cb.inject_let_stmts(stmt_);
5878
5879 stmt_ = gemm_schedule.create_bind_stmt(stmt_);
5880 stmt_ = inject_let_stmts(stmt_, init_stmts);
5881 stmt_ = inject_alloc_stmts(stmt_, allocs);
5882
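// Lowering passes, applied in order: external variable binding, SLM buffer
// merging/buffering, send lowering, wide store splitting, allocation
// lifting, common subexpression elimination, expression hoisting, optional
// loop strength reduction and unrolling, and final peephole/let/alloc
// cleanup.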
5883 stmt_ = inject_external_var_let(stmt_);
5884 stmt_ = merge_slm_buffers(stmt_);
5885 if (!cfg_.do_loop_unroll && (cfg_.use_a_slm || cfg_.use_b_slm)) {
5886 stmt_ = inject_simple_slm_buffering(
5887 cfg_.hw, stmt_, cfg_, ir_ctx, cb.ab_slm_size());
5888 }
5889 stmt_ = lift_buffer_offsets_in_send(stmt_);
5890 stmt_ = simplify_pass(stmt_, init_cset);
5891 stmt_ = inject_send(stmt_, ir_ctx, init_cset);
5892 stmt_ = split_wide_stores(cfg_.hw, stmt_);
5893 stmt_ = lift_alloc(stmt_, cfg_);
5894 stmt_ = eliminate_common_subexprs(stmt_, ir_ctx);
5895 stmt_ = hoist_exprs(stmt_, ir_ctx);
5896 if (cfg_.do_loop_unroll) stmt_ = loop_strength_reduce(stmt_);
5897 stmt_ = optimize_alloc_let(stmt_);
5898 if (cfg_.do_loop_unroll) {
5899 stmt_ = update_loops_for_unrolling(stmt_, cfg_);
5900 stmt_ = inject_unrolling(stmt_, cfg_, ir_ctx, cb.ab_slm_size());
5901 }
5902 stmt_ = fixup_if_conditions(stmt_, cfg_);
5903 stmt_ = unroll_loops(stmt_, ir_ctx);
5904 stmt_ = simplify_pass(stmt_, init_cset);
5905 stmt_ = optimize_alloc_let(stmt_);
5906 stmt_ = optimize_peephole(stmt_);
5907 stmt_ = stmt_group_t::make(stmt_label_t::kernel(), stmt_);
5908
5909 ir_trace() << "Kernel body:\n" << stmt_ << std::endl;
5910 }
5911
5912 namespace {
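// For the forward direction the accessed input index is
// i = o * s - p + k * (1 + d), so i_min = -p (at o = k = 0) and
// i_max = (O - 1) * s - p + (K - 1) * (1 + d). A boundary check is needed
// when [i_min, i_max] is not fully inside [0, I). Worked example:
// O = I = 56, K = 3, p = 1, s = 1, d = 0 gives i_min = -1 and i_max = 56,
// so both ends need masking.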
5913 bool need_src_or_dst_check(
5914 bool is_fwd, int o, int i, int k, int p, int s, int d) {
5915 if (is_fwd) {
5916 int i_min = -p;
5917 int i_max = (o - 1) * s - p + (k - 1) * (1 + d);
5918 return (i_min < 0) || (i_max >= i);
5919 }
5920 // Backward.
5921 int os_min = p - (k - 1) * (1 + d);
5922 int os_max = (o - 1) + p;
5923 return (os_min < 0) || (os_max >= i * s);
5924 }
5925
5926 } // namespace
5927
5928 void kernel_builder_t::init_fwd(gemm_schedule_t &gemm_schedule,
5929 view_t &src_view, view_t &wei_view, view_t &dst_view, expr_t &src_buf,
5930 expr_t &wei_buf, expr_t &dst_buf) {
5931 // Unify layouts.
5932 auto src_layout = cfg_.src_layout;
5933 auto wei_layout = cfg_.wei_layout;
5934 auto dst_layout = cfg_.dst_layout;
5935 normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
5936 cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/true);
5937
5938 // Initialize views.
5939 auto mb = var_t::make(type_t::s32(), "mb");
5940 auto ic = var_t::make(type_t::s32(), "ic");
5941 auto oc = var_t::make(type_t::s32(), "oc");
5942 auto od = var_t::make(type_t::s32(), "od");
5943 auto oh = var_t::make(type_t::s32(), "oh");
5944 auto ow = var_t::make(type_t::s32(), "ow");
5945 auto kd = var_t::make(type_t::s32(), "kd");
5946 auto kh = var_t::make(type_t::s32(), "kh");
5947 auto kw = var_t::make(type_t::s32(), "kw");
5948 auto g = var_t::make(type_t::s32(), "g");
5949
5950 // Initialize masks.
5951 expr_t id_mask, ih_mask, iw_mask;
5952 expr_t od_mask, oh_mask, ow_mask;
5953 expr_t src_mb_mask, dst_mb_mask;
5954 expr_t wei_oc_mask, dst_oc_mask;
5955 expr_t src_g_mask, wei_g_mask, dst_g_mask;
5956 expr_t kw_mask;
5957
5958 bool check_ow = (cfg_.ow % cfg_.ow_tg_blk != 0);
5959 bool check_iw = check_ow
5960 || need_src_or_dst_check(cfg_.is_fwd, cfg_.ow, cfg_.iw, cfg_.kw,
5961 cfg_.pw, cfg_.sw, cfg_.dw);
5962 bool check_ih = need_src_or_dst_check(
5963 cfg_.is_fwd, cfg_.oh, cfg_.ih, cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
5964 bool check_id = need_src_or_dst_check(
5965 cfg_.is_fwd, cfg_.od, cfg_.id, cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
5966 bool check_kw = (cfg_.kw % cfg_.kw_blk != 0);
5967
5968 int src_g = int(src_layout.dim(1));
5969 int src_g_inner_blk = ir_utils::max_pow2_divisor(src_g);
5970 src_g_inner_blk = std::min(src_g_inner_blk, cfg_.g_thr_blk);
5971
5972 int wei_g = int(wei_layout.dim(0));
5973 int wei_g_inner_blk = ir_utils::max_pow2_divisor(wei_g);
5974 wei_g_inner_blk = std::min(wei_g_inner_blk, cfg_.g_thr_blk);
5975
5976 int wei_oc = int(wei_layout.dim(1));
5977 int wei_oc_inner_blk = ir_utils::max_pow2_divisor(wei_oc);
5978 wei_oc_inner_blk = std::min(wei_oc_inner_blk, cfg_.oc_thr_blk);
5979
5980 int dst_g = int(dst_layout.dim(1));
5981 int dst_g_inner_blk = ir_utils::max_pow2_divisor(dst_g);
5982 dst_g_inner_blk = std::min(dst_g_inner_blk, cfg_.g_thr_blk);
5983
5984 int dst_oc = int(dst_layout.dim(2));
5985 int dst_oc_inner_blk = ir_utils::max_pow2_divisor(dst_oc);
5986 dst_oc_inner_blk = std::min(dst_oc_inner_blk, cfg_.oc_thr_blk);
5987
5988 bool check_src_g = (src_g % cfg_.g_tg_blk != 0);
5989 bool check_wei_g = (wei_g % cfg_.g_tg_blk != 0);
5990 bool check_wei_oc = (wei_oc % cfg_.oc_tg_blk != 0);
5991 bool check_dst_g = (dst_g % cfg_.g_tg_blk != 0);
5992 bool check_dst_oc = (dst_oc % cfg_.oc_tg_blk != 0);
5993
5994 int src_mb = int(src_layout.dim(0));
5995 int dst_mb = int(dst_layout.dim(0));
5996
5997 bool check_src_mb = (src_mb % cfg_.mb_tg_blk != 0);
5998 bool check_dst_mb = (dst_mb % cfg_.mb_tg_blk != 0);
5999
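// The out-of-bounds masks below compare at the granularity of the largest
// power-of-2 inner block of each dimension (capped by the thread block),
// so the mask value is uniform within such a block; presumably this keeps
// the masks cheap to evaluate for blocked layouts.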
6000 auto &x = view_t::placeholder_var();
6001 if (check_id) id_mask = (x >= 0) & (x < cfg_.id);
6002 if (check_ih) ih_mask = (x >= 0) & (x < cfg_.ih);
6003 if (check_iw) iw_mask = (x >= 0) & (x < cfg_.iw);
6004 if (check_ow) ow_mask = (x >= 0) & (x < cfg_.ow);
6005 if (check_src_g)
6006 src_g_mask = (x / src_g_inner_blk < src_g / src_g_inner_blk);
6007 if (check_wei_g)
6008 wei_g_mask = (x / wei_g_inner_blk < wei_g / wei_g_inner_blk);
6009 if (check_wei_oc)
6010 wei_oc_mask = (x / wei_oc_inner_blk < wei_oc / wei_oc_inner_blk);
6011 if (check_dst_g)
6012 dst_g_mask = (x / dst_g_inner_blk < dst_g / dst_g_inner_blk);
6013 if (check_dst_oc)
6014 dst_oc_mask = (x / dst_oc_inner_blk < dst_oc / dst_oc_inner_blk);
6015 if (check_kw) kw_mask = (x < cfg_.kw);
6016 if (check_src_mb) src_mb_mask = (x < src_mb);
6017 if (check_dst_mb) dst_mb_mask = (x < dst_mb);
6018
6019 // Source.
6020 src_view = view_t({mb, g, ic, od, oh, ow, kd, kh, kw}, 6);
6021 src_view.set_vdim(mb, cfg_.mb);
6022 src_view.set_vdim(g, cfg_.g);
6023 src_view.set_vdim(ic, cfg_.ic);
6024 src_view.set_vdim(od, cfg_.od);
6025 src_view.set_vdim(oh, cfg_.oh);
6026 src_view.set_vdim(ow, cfg_.ow);
6027 src_view.set_vdim(kd, cfg_.kd);
6028 src_view.set_vdim(kh, cfg_.kh);
6029 src_view.set_vdim(kw, cfg_.kw);
6030 src_view.set_tdim(0, mb, src_mb_mask);
6031 src_view.set_tdim(1, g, src_g_mask);
6032 src_view.set_tdim(2, ic);
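// Forward source spatial coordinates: id = od * sd - pd + kd * (1 + dd)
// (and similarly for ih/iw), masked against the real input extents.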
6033 src_view.set_tdim(3, od * cfg_.sd - cfg_.pd + kd * (1 + cfg_.dd), id_mask);
6034 src_view.set_tdim(4, oh * cfg_.sh - cfg_.ph + kh * (1 + cfg_.dh), ih_mask);
6035 src_view.set_tdim(5, ow * cfg_.sw - cfg_.pw + kw * (1 + cfg_.dw), iw_mask);
6036 src_view.set_tlayout(src_layout);
6037
6038 // Weights.
6039 wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
6040 wei_view.set_vdim(g, cfg_.g);
6041 wei_view.set_vdim(oc, cfg_.oc);
6042 wei_view.set_vdim(ic, cfg_.ic);
6043 wei_view.set_vdim(kd, cfg_.kd);
6044 wei_view.set_vdim(kh, cfg_.kh);
6045 wei_view.set_vdim(kw, cfg_.kw);
6046 wei_view.set_tdim(0, g, wei_g_mask);
6047 wei_view.set_tdim(1, oc, wei_oc_mask);
6048 wei_view.set_tdim(2, ic);
6049 wei_view.set_tdim(3, kd);
6050 wei_view.set_tdim(4, kh);
6051 wei_view.set_tdim(5, kw, kw_mask);
6052 wei_view.set_tlayout(wei_layout);
6053
6054 // Destination.
6055 dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
6056 dst_view.set_vdim(mb, cfg_.mb);
6057 dst_view.set_vdim(g, cfg_.g);
6058 dst_view.set_vdim(oc, cfg_.oc);
6059 dst_view.set_vdim(od, cfg_.od);
6060 dst_view.set_vdim(oh, cfg_.oh);
6061 dst_view.set_vdim(ow, cfg_.ow);
6062 dst_view.set_tdim(0, mb, dst_mb_mask);
6063 dst_view.set_tdim(1, g, dst_g_mask);
6064 dst_view.set_tdim(2, oc, dst_oc_mask);
6065 dst_view.set_tdim(3, od, od_mask);
6066 dst_view.set_tdim(4, oh, oh_mask);
6067 dst_view.set_tdim(5, ow, ow_mask);
6068 dst_view.set_tlayout(dst_layout);
6069
6070 // Initialize GEMM schedule.
6071 gemm_schedule.set_a_view(src_view);
6072 gemm_schedule.set_b_view(wei_view);
6073 gemm_schedule.set_c_view(dst_view);
6074 gemm_schedule.set_b_vars({g});
6075 gemm_schedule.set_m_vars({mb, od, oh, ow});
6076 gemm_schedule.set_n_vars({oc});
6077 gemm_schedule.set_k_vars({ic, kd, kh, kw});
6078
6079 expr_t g_tg_blk_idx, g_inner;
6080 expr_t oc_tg_blk_idx, oc_thr_blk_idx, oc_inner;
6081 expr_t mb_tg_blk_idx, mb_thr_blk_idx, mb_inner;
6082 expr_t ow_tg_blk_idx, ow_thr_blk_idx, ow_inner;
6083 expr_t kw_outer, kw_inner;
6084 expr_t ic_thr_blk_idx, ic_outer, ic_inner;
6085
6086 gemm_schedule.split(g, cfg_.g_tg_blk, g_tg_blk_idx, g_inner);
6087 gemm_schedule.split(oc, cfg_.oc_tg_blk, cfg_.oc_thr_blk, oc_tg_blk_idx,
6088 oc_thr_blk_idx, oc_inner);
6089 gemm_schedule.split(mb, cfg_.mb_tg_blk, cfg_.mb_thr_blk, mb_tg_blk_idx,
6090 mb_thr_blk_idx, mb_inner);
6091 gemm_schedule.split(ow, cfg_.ow_tg_blk, cfg_.ow_thr_blk, ow_tg_blk_idx,
6092 ow_thr_blk_idx, ow_inner);
6093 gemm_schedule.split(ic, cfg_.ic_blk * cfg_.ic_thr_dim, cfg_.ic_blk,
6094 ic_outer, ic_thr_blk_idx, ic_inner);
6095 gemm_schedule.split(kw, cfg_.kw_blk, kw_outer, kw_inner);
6096
6097 auto g_odhw_idx = gemm_schedule.fuse({g_tg_blk_idx, od, oh, ow_tg_blk_idx});
6098 auto mb_ow_thr_blk_idx = gemm_schedule.fuse(mb_thr_blk_idx, ow_thr_blk_idx);
6099
6100 gemm_schedule.bind(oc_tg_blk_idx, kernel_grid_.idx(0));
6101 gemm_schedule.bind(g_odhw_idx, kernel_grid_.idx(1));
6102 gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));
6103 gemm_schedule.bind(oc_thr_blk_idx, tg_grid_.idx(0));
6104 gemm_schedule.bind(mb_ow_thr_blk_idx, tg_grid_.idx(1));
6105 gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(2));
6106
6107 gemm_schedule.tensorize(g_inner);
6108 gemm_schedule.tensorize(oc_inner);
6109 gemm_schedule.tensorize(mb_inner);
6110 gemm_schedule.tensorize(ow_inner);
6111 gemm_schedule.tensorize(kw_inner);
6112 gemm_schedule.tensorize(ic_inner);
6113
6114 gemm_schedule.reorder({ic_outer, kd, kh, kw_outer, oc_thr_blk_idx,
6115 mb_ow_thr_blk_idx, ic_thr_blk_idx});
6116
6117 src_buf = kernel_arg_info_.find_arg("src");
6118 wei_buf = kernel_arg_info_.find_arg("wei");
6119 dst_buf = kernel_arg_info_.find_arg("dst");
6120 }
6121
6122 void kernel_builder_t::init_bwd_d(gemm_schedule_t &gemm_schedule,
6123 view_t &dst_view, view_t &wei_view, view_t &src_view, expr_t &dst_buf,
6124 expr_t &wei_buf, expr_t &src_buf) {
6125 // Unify layouts.
6126 auto src_layout = cfg_.src_layout;
6127 auto wei_layout = cfg_.wei_layout;
6128 auto dst_layout = cfg_.dst_layout;
6129 normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
6130 cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/false);
6131
6132 // Initialize views.
6133 auto mb = var_t::make(type_t::s32(), "mb");
6134 auto ic = var_t::make(type_t::s32(), "ic");
6135 auto oc = var_t::make(type_t::s32(), "oc");
6136 auto id = var_t::make(type_t::s32(), "id");
6137 auto ih = var_t::make(type_t::s32(), "ih");
6138 auto iw = var_t::make(type_t::s32(), "iw");
6139 auto kd = var_t::make(type_t::s32(), "kd");
6140 auto kh = var_t::make(type_t::s32(), "kh");
6141 auto kw = var_t::make(type_t::s32(), "kw");
6142
6143 // Initialize masks.
6144 expr_t id_mask, ih_mask, iw_mask;
6145 expr_t od_mask(true), oh_mask(true), ow_mask(true);
6146 expr_t src_mb_mask, dst_mb_mask;
6147 expr_t wei_oc_mask, dst_oc_mask;
6148 expr_t wei_ic_mask, src_ic_mask;
6149
6150 bool check_iw = (cfg_.iw % cfg_.iw_tg_blk != 0);
6151 bool check_ow = check_iw
6152 || need_src_or_dst_check(cfg_.is_fwd, cfg_.ow, cfg_.iw, cfg_.kw,
6153 cfg_.pw, cfg_.sw, cfg_.dw);
6154 bool check_oh = need_src_or_dst_check(
6155 cfg_.is_fwd, cfg_.oh, cfg_.ih, cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
6156 bool check_od = need_src_or_dst_check(
6157 cfg_.is_fwd, cfg_.od, cfg_.id, cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
6158
6159 int wei_ic = int(cfg_.wei_layout.dim(cfg_.with_groups ? 2 : 1));
6160 int src_ic = int(cfg_.src_layout.dim(1));
6161
6162 int wei_ic_inner_blk = ir_utils::max_pow2_divisor(wei_ic);
6163 int src_ic_inner_blk = ir_utils::max_pow2_divisor(src_ic);
6164 wei_ic_inner_blk = std::min(wei_ic_inner_blk, cfg_.ic_thr_blk);
6165 src_ic_inner_blk = std::min(src_ic_inner_blk, cfg_.ic_thr_blk);
6166
6167 bool check_wei_ic = (wei_ic % cfg_.ic_tg_blk != 0);
6168 bool check_src_ic = (src_ic % cfg_.ic_tg_blk != 0);
6169
6170 int src_mb = int(cfg_.src_layout.dim(0));
6171 int dst_mb = int(cfg_.src_layout.dim(0));
6172
6173 bool check_src_mb = (src_mb % cfg_.mb_tg_blk != 0);
6174 bool check_dst_mb = (dst_mb % cfg_.mb_tg_blk != 0);
6175
6176 auto &x = view_t::placeholder_var();
6177 if (check_od) od_mask = (x >= 0) & (x < cfg_.od);
6178 if (check_oh) oh_mask = (x >= 0) & (x < cfg_.oh);
6179 if (check_ow) ow_mask = (x >= 0) & (x < cfg_.ow);
6180 if (check_iw) iw_mask = (x >= 0) & (x < cfg_.iw);
6181 if (check_wei_ic)
6182 wei_ic_mask = (x / wei_ic_inner_blk < wei_ic / wei_ic_inner_blk);
6183 if (check_src_ic)
6184 src_ic_mask = (x / src_ic_inner_blk < src_ic / src_ic_inner_blk);
6185 if (check_src_mb) src_mb_mask = (x < src_mb);
6186 if (check_dst_mb) dst_mb_mask = (x < dst_mb);
6187
6188 // Destination.
6189 dst_view = view_t({mb, oc, id, ih, iw, kd, kh, kw}, 5);
6190 dst_view.set_vdim(mb, cfg_.mb);
6191 dst_view.set_vdim(oc, cfg_.oc);
6192 dst_view.set_vdim(id, cfg_.id);
6193 dst_view.set_vdim(ih, cfg_.ih);
6194 dst_view.set_vdim(iw, cfg_.iw);
6195 dst_view.set_vdim(kd, cfg_.kd);
6196 dst_view.set_vdim(kh, cfg_.kh);
6197 dst_view.set_vdim(kw, cfg_.kw);
6198 dst_view.set_tdim(0, mb, src_mb_mask);
6199 dst_view.set_tdim(1, oc);
6200
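// Backward-data inverts the forward relation:
// od = (id + pd - kd * (1 + dd)) / sd (similarly for oh/ow); the mask
// additionally requires exact divisibility by the stride so that only
// valid output positions contribute.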
6201 auto od = id - kd * (1 + cfg_.dd) + cfg_.pd;
6202 auto oh = ih - kh * (1 + cfg_.dh) + cfg_.ph;
6203 auto ow = iw - kw * (1 + cfg_.dw) + cfg_.pw;
6204 dst_view.set_tdim(2, od / cfg_.sd, od_mask & (od % cfg_.sd == 0));
6205 dst_view.set_tdim(3, oh / cfg_.sh, oh_mask & (oh % cfg_.sh == 0));
6206 dst_view.set_tdim(4, ow / cfg_.sw, ow_mask & (ow % cfg_.sw == 0));
6207
6208 dst_view.set_tlayout(dst_layout);
6209
6210 // Weights.
6211 wei_view = view_t({oc, ic, kd, kh, kw}, 5);
6212 wei_view.set_vdim(ic, cfg_.ic);
6213 wei_view.set_vdim(oc, cfg_.oc);
6214 wei_view.set_vdim(kd, cfg_.kd);
6215 wei_view.set_vdim(kh, cfg_.kh);
6216 wei_view.set_vdim(kw, cfg_.kw);
6217 wei_view.set_tdim(0, oc);
6218 wei_view.set_tdim(1, ic, wei_ic_mask);
6219 wei_view.set_tdim(2, kd);
6220 wei_view.set_tdim(3, kh);
6221 wei_view.set_tdim(4, kw);
6222 wei_view.set_tlayout(wei_layout);
6223
6224 // Source.
6225 src_view = view_t({mb, ic, id, ih, iw}, 5);
6226 src_view.set_vdim(mb, cfg_.mb);
6227 src_view.set_vdim(ic, cfg_.ic);
6228 src_view.set_vdim(id, cfg_.id);
6229 src_view.set_vdim(ih, cfg_.ih);
6230 src_view.set_vdim(iw, cfg_.iw);
6231 src_view.set_tdim(0, mb, dst_mb_mask);
6232 src_view.set_tdim(1, ic, src_ic_mask);
6233 src_view.set_tdim(2, id, id_mask);
6234 src_view.set_tdim(3, ih, ih_mask);
6235 src_view.set_tdim(4, iw, iw_mask);
6236 src_view.set_tlayout(src_layout);
6237
6238 // Initialize GEMM schedule.
6239 gemm_schedule.set_a_view(dst_view);
6240 gemm_schedule.set_b_view(wei_view);
6241 gemm_schedule.set_c_view(src_view);
6242 gemm_schedule.set_m_vars({mb, id, ih, iw});
6243 gemm_schedule.set_n_vars({ic});
6244 gemm_schedule.set_k_vars({oc, kd, kh, kw});
6245
6246 expr_t ic_tg_blk_idx, ic_thr_blk_idx, ic_inner;
6247 expr_t mb_tg_blk_idx, mb_inner;
6248 expr_t iw_tg_blk_idx, iw_thr_blk_idx, iw_inner;
6249 expr_t oc_blk_idx, oc_inner;
6250
6251 gemm_schedule.split(ic, cfg_.ic_tg_blk, cfg_.ic_thr_blk, ic_tg_blk_idx,
6252 ic_thr_blk_idx, ic_inner);
6253 gemm_schedule.split(mb, cfg_.mb_tg_blk, mb_tg_blk_idx, mb_inner);
6254 gemm_schedule.split(iw, cfg_.iw_tg_blk, cfg_.iw_thr_blk, iw_tg_blk_idx,
6255 iw_thr_blk_idx, iw_inner);
6256 gemm_schedule.split(oc, cfg_.oc_blk, oc_blk_idx, oc_inner);
6257
6258 auto idhw_idx = gemm_schedule.fuse(id, ih, iw_tg_blk_idx);
6259 gemm_schedule.bind(ic_tg_blk_idx, kernel_grid_.idx(0));
6260 gemm_schedule.bind(idhw_idx, kernel_grid_.idx(1));
6261 gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));
6262 gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(0));
6263 gemm_schedule.bind(iw_thr_blk_idx, tg_grid_.idx(1));
6264
6265 gemm_schedule.tensorize(ic_inner);
6266 gemm_schedule.tensorize(mb_inner);
6267 gemm_schedule.tensorize(iw_inner);
6268 gemm_schedule.tensorize(oc_inner);
6269
6270 gemm_schedule.reorder({oc_blk_idx, kd, kh, kw});
6271
6272 src_buf = kernel_arg_info_.find_arg("src");
6273 wei_buf = kernel_arg_info_.find_arg("wei");
6274 dst_buf = kernel_arg_info_.find_arg("dst");
6275 }
6276
6277 void kernel_builder_t::init_bwd_w(gemm_schedule_t &gemm_schedule,
        view_t &src_view, view_t &dst_view, view_t &wei_view, view_t &bia_view,
        expr_t &src_buf, expr_t &dst_buf, expr_t &wei_buf, expr_t &bia_buf,
        expr_t &bia_reduction_condition) {
    // Unify layouts.
    auto src_layout = cfg_.src_layout;
    auto wei_layout = cfg_.wei_layout;
    auto dst_layout = cfg_.dst_layout;
    normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
            cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/false);

    // Initialize thread group views.
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto od = var_t::make(type_t::s32(), "od");
    auto oh = var_t::make(type_t::s32(), "oh");
    auto ow = var_t::make(type_t::s32(), "ow");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");

    // Initialize masks.
    expr_t id_mask(true), ih_mask(true), iw_mask(true);
    expr_t od_mask, oh_mask, ow_mask;
    expr_t src_mb_mask, src_ic_mask;
    expr_t dst_mb_mask, dst_oc_mask;
    expr_t wei_oc_mask, wei_ic_mask;
    expr_t kw_mask;

    bool check_ow = (cfg_.ow % cfg_.ow_tg_blk != 0);
    bool check_oh = (cfg_.oh % cfg_.oh_tg_blk != 0);
    bool check_od = (cfg_.od % cfg_.od_tg_blk != 0);
    bool check_kw = (cfg_.kw % cfg_.kw_blk != 0);
    bool check_iw = check_kw
            || need_src_or_dst_check(/*is_fwd=*/true, cfg_.ow, cfg_.iw, cfg_.kw,
                    cfg_.pw, cfg_.sw, cfg_.dw);
    bool check_ih = need_src_or_dst_check(/*is_fwd=*/true, cfg_.oh, cfg_.ih,
            cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
    bool check_id = need_src_or_dst_check(/*is_fwd=*/true, cfg_.od, cfg_.id,
            cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
    bool check_iw_min = check_iw;
    bool check_ih_min = check_ih;
    bool check_id_min = check_id;
    bool check_iw_max = (check_iw || check_ow);
    bool check_ih_max = (check_ih || check_oh);
    bool check_id_max = (check_id || check_od);
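    // The *_min / *_max flags decide whether the input spatial masks built
    // further below need a lower bound (x >= 0) and/or an upper bound
    // (x < size). The upper bound is also required when the corresponding
    // output size is not a multiple of its thread-group block
    // (check_od/oh/ow), since the output tail can then map to out-of-range
    // input positions.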

    int src_ic = int(cfg_.src_layout.dim(1));
    int dst_oc = int(cfg_.dst_layout.dim(1));
    int wei_oc = int(cfg_.wei_layout.dim(cfg_.with_groups ? 1 : 0));
    int wei_ic = int(cfg_.wei_layout.dim(cfg_.with_groups ? 2 : 1));

    int src_ic_inner_blk = ir_utils::max_pow2_divisor(src_ic);
    int dst_oc_inner_blk = ir_utils::max_pow2_divisor(dst_oc);
    int wei_oc_inner_blk = ir_utils::max_pow2_divisor(wei_oc);
    int wei_ic_inner_blk = ir_utils::max_pow2_divisor(wei_ic);
    src_ic_inner_blk = std::min(src_ic_inner_blk, cfg_.ic_thr_blk);
    dst_oc_inner_blk = std::min(dst_oc_inner_blk, cfg_.oc_thr_blk);
    wei_oc_inner_blk = std::min(wei_oc_inner_blk, cfg_.oc_thr_blk);
    wei_ic_inner_blk = std::min(wei_ic_inner_blk, cfg_.ic_thr_blk);
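    // Channel bound checks are done at the granularity of a power-of-two
    // inner block (capped by the per-thread block): the masks below compare
    // x / inner_blk < dim / inner_blk instead of x < dim, so one compare per
    // block covers all elements of that block. For example (hypothetical
    // sizes), src_ic = 24 with an 8-wide inner block gives x / 8 < 3, which
    // is exact because 24 is a multiple of 8.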

    bool check_src_ic = (src_ic % cfg_.ic_tg_blk != 0);
    bool check_dst_oc = (dst_oc % cfg_.oc_tg_blk != 0);
    bool check_wei_oc = (wei_oc % cfg_.oc_tg_blk != 0);
    bool check_wei_ic = (wei_ic % cfg_.ic_tg_blk != 0);

    auto &x = view_t::placeholder_var();
    if (check_id_min) id_mask &= (x >= 0);
    if (check_ih_min) ih_mask &= (x >= 0);
    if (check_iw_min) iw_mask &= (x >= 0);
    if (check_id_max) id_mask &= (x < cfg_.id);
    if (check_ih_max) ih_mask &= (x < cfg_.ih);
    if (check_iw_max) iw_mask &= (x < cfg_.iw);
    if (check_od) od_mask = (x < cfg_.od);
    if (check_oh) oh_mask = (x < cfg_.oh);
    if (check_ow) ow_mask = (x < cfg_.ow);
    if (check_src_ic)
        src_ic_mask = (x / src_ic_inner_blk < src_ic / src_ic_inner_blk);
    if (check_dst_oc)
        dst_oc_mask = (x / dst_oc_inner_blk < dst_oc / dst_oc_inner_blk);
    if (check_wei_oc)
        wei_oc_mask = (x / wei_oc_inner_blk < wei_oc / wei_oc_inner_blk);
    if (check_wei_ic)
        wei_ic_mask = (x / wei_ic_inner_blk < wei_ic / wei_ic_inner_blk);
    if (check_kw) kw_mask = (x < cfg_.kw);

    // Source.
    src_view = view_t({mb, ic, od, oh, ow, kw}, 5);
    src_view.set_vdim(mb, cfg_.mb);
    src_view.set_vdim(ic, cfg_.ic);
    src_view.set_vdim(od, cfg_.od);
    src_view.set_vdim(oh, cfg_.oh);
    src_view.set_vdim(ow, cfg_.ow);
    src_view.set_vdim(kw, cfg_.kw);
    src_view.set_tdim(0, mb, src_mb_mask);
    src_view.set_tdim(1, ic, src_ic_mask);
    src_view.set_tdim(2, od * cfg_.sd - cfg_.pd + kd * (1 + cfg_.dd), id_mask);
    src_view.set_tdim(3, oh * cfg_.sh - cfg_.ph + kh * (1 + cfg_.dh), ih_mask);
    src_view.set_tdim(4, ow * cfg_.sw - cfg_.pw + kw * (1 + cfg_.dw), iw_mask);
    src_view.set_tlayout(src_layout);
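    // The source is addressed through the forward index relation
    // id = od * sd - pd + kd * (1 + dd) (likewise for ih/iw), so the same
    // boundary masks as in the forward case apply. kw is a view variable here
    // and is also part of the GEMM M dimension below, while kd and kh are
    // handled as whole dimensions and are fused into the kernel grid further
    // below.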

    // Weights.
    wei_view = view_t({oc, ic, kd, kh, kw}, 5);
    wei_view.set_vdim(oc, cfg_.oc);
    wei_view.set_vdim(ic, cfg_.ic);
    wei_view.set_vdim(kd, cfg_.kd);
    wei_view.set_vdim(kh, cfg_.kh);
    wei_view.set_vdim(kw, cfg_.kw);
    wei_view.set_tdim(0, oc, wei_oc_mask);
    wei_view.set_tdim(1, ic, wei_ic_mask);
    wei_view.set_tdim(2, kd);
    wei_view.set_tdim(3, kh);
    wei_view.set_tdim(4, kw, kw_mask);
    wei_view.set_tlayout(wei_layout);

    // Destination.
    dst_view = view_t({mb, oc, od, oh, ow}, 5);
    dst_view.set_vdim(mb, cfg_.mb);
    dst_view.set_vdim(oc, cfg_.oc);
    dst_view.set_vdim(od, cfg_.od);
    dst_view.set_vdim(oh, cfg_.oh);
    dst_view.set_vdim(ow, cfg_.ow);
    dst_view.set_tdim(0, mb, dst_mb_mask);
    dst_view.set_tdim(1, oc, dst_oc_mask);
    dst_view.set_tdim(2, od, od_mask);
    dst_view.set_tdim(3, oh, oh_mask);
    dst_view.set_tdim(4, ow, ow_mask);
    dst_view.set_tlayout(dst_layout);

    // Bias.
    if (cfg_.with_bias) {
        expr_t bia_oc_mask;
        if (cfg_.oc % cfg_.oc_tg_blk != 0) bia_oc_mask = (x < cfg_.oc);
        bia_view = view_t({oc}, 1);
        bia_view.set_vdim(oc, cfg_.oc, 0);
        bia_view.set_tdim(0, oc, bia_oc_mask);
        bia_view.set_tlayout(cfg_.bia_layout);
    }

    // Initialize GEMM schedule.
    gemm_schedule.set_a_view(src_view);
    gemm_schedule.set_b_view(dst_view);
    gemm_schedule.set_c_view(wei_view);
    gemm_schedule.set_m_vars({ic, kw});
    gemm_schedule.set_n_vars({oc});
    gemm_schedule.set_k_vars({mb, od, oh, ow});
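    // Roughly, backward-by-weights is expressed as a GEMM with A = src,
    // B = dst and C = wei: M spans (ic, kw), N spans oc, and the reduction K
    // spans (mb, od, oh, ow). kd and kh are not GEMM dimensions here; they
    // are distributed across thread groups below.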

    expr_t mb_tg_blk_idx, mb_thr_blk_idx, mb_inner;
    expr_t oc_tg_blk_idx, oc_thr_blk_idx, oc_inner;
    expr_t ic_tg_blk_idx, ic_thr_blk_idx, ic_inner;
    expr_t od_tg_blk_idx, od_inner;
    expr_t oh_tg_blk_idx, oh_inner;
    expr_t ow_tg_blk_idx, ow_thr_blk_idx, ow_inner;
    expr_t kw_tg_blk_idx, kw_inner;

    gemm_schedule.split(mb, cfg_.mb_tg_blk, cfg_.mb_blk, mb_tg_blk_idx,
            mb_thr_blk_idx, mb_inner);
    gemm_schedule.split(ic, cfg_.ic_tg_blk, cfg_.ic_thr_blk, ic_tg_blk_idx,
            ic_thr_blk_idx, ic_inner);
    gemm_schedule.split(oc, cfg_.oc_tg_blk, cfg_.oc_thr_blk, oc_tg_blk_idx,
            oc_thr_blk_idx, oc_inner);
    gemm_schedule.split(od, cfg_.od_tg_blk, od_tg_blk_idx, od_inner);
    gemm_schedule.split(oh, cfg_.oh_tg_blk, oh_tg_blk_idx, oh_inner);
    gemm_schedule.split(ow, cfg_.ow_tg_blk, cfg_.ow_thr_blk, ow_tg_blk_idx,
            ow_thr_blk_idx, ow_inner);
    gemm_schedule.split(kw, cfg_.kw_tg_blk, kw_tg_blk_idx, kw_inner);

    auto odhw_tg_blk_kdhw_ic_tg_blk_idx
            = gemm_schedule.fuse({od_tg_blk_idx, oh_tg_blk_idx, ow_tg_blk_idx,
                    kd, kh, kw_tg_blk_idx, ic_tg_blk_idx});

    gemm_schedule.bind(oc_tg_blk_idx, kernel_grid_.idx(0));
    gemm_schedule.bind(odhw_tg_blk_kdhw_ic_tg_blk_idx, kernel_grid_.idx(1));
    gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));

    gemm_schedule.bind(oc_thr_blk_idx, tg_grid_.idx(0));
    gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(1));
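    // Work distribution: the kernel grid is indexed by the oc block, a single
    // fused index over the (od, oh, ow) blocks, kd, kh, the kw block and the
    // ic block, and by the mb block; within a thread group, threads are
    // indexed by the oc and ic thread blocks. Each thread group therefore
    // owns a fixed (kd, kh) position and a fixed kw/ic block of the weights.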

    gemm_schedule.reorder({od_inner, oh_inner, ow_inner, mb_thr_blk_idx});

    gemm_schedule.unroll(mb_thr_blk_idx, cfg_.mb_unroll);
    gemm_schedule.unroll(ow_thr_blk_idx, cfg_.ow_unroll);
    gemm_schedule.tensorize(oc_inner);
    gemm_schedule.tensorize(ic_inner);
    gemm_schedule.tensorize(mb_inner);
    gemm_schedule.tensorize(ow_inner);
    gemm_schedule.tensorize(kw_inner);

    src_buf = kernel_arg_info_.find_arg("src");
    wei_buf = kernel_arg_info_.find_arg("wei");
    dst_buf = kernel_arg_info_.find_arg("dst");

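    // Bias reduction note: the kd, kh, kw-block and ic-block dimensions (and
    // the extra threads along tg dim 1 when B is not shared through SLM) do
    // not contribute distinct terms to the bias reduction and would otherwise
    // produce duplicate partial sums, so the condition below selects a single
    // representative along each of them.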
    if (cfg_.with_bias) {
        bia_buf = kernel_arg_info_.find_arg("bia");
        bia_reduction_condition = expr_t(true);
        if (cfg_.kd > 1) bia_reduction_condition &= (kd == 0);
        if (cfg_.kh > 1) bia_reduction_condition &= (kh == 0);
        if (cfg_.kw > 1) bia_reduction_condition &= (kw_tg_blk_idx == 0);
        if (cfg_.ic_tg_dim > 1) bia_reduction_condition &= (ic_tg_blk_idx == 0);
        if (!cfg_.use_b_slm && tg_grid_.dim(1) > 1) {
            bia_reduction_condition &= (tg_grid_.idx(1) == 0);
        }
    }
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl