1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/jit/conv/kernel_builder.hpp"
18 
19 #include <algorithm>
20 #include <array>
21 #include <iostream>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 #include <unordered_map>
26 
27 #include "gpu/jit/conv/config.hpp"
28 #include "gpu/jit/conv/fma_support.hpp"
29 #include "gpu/jit/conv/gemm_schedule.hpp"
30 #include "gpu/jit/conv/ir.hpp"
31 #include "gpu/jit/conv/message_support.hpp"
32 #include "gpu/jit/conv/post_op_support.hpp"
33 #include "gpu/jit/conv/reduce_support.hpp"
34 #include "gpu/jit/conv/reorder_support.hpp"
35 #include "gpu/jit/conv/tensor.hpp"
36 
37 namespace dnnl {
38 namespace impl {
39 namespace gpu {
40 namespace jit {
41 
42 class permutation_injector_t : public ir_mutator_t {
43 public:
permutation_injector_t(const grf_permutator_t & grf_perm)44     permutation_injector_t(const grf_permutator_t &grf_perm)
45         : grf_perm_(new grf_permutator_t(grf_perm)) {}
46 
_mutate(const func_call_t & obj)47     object_t _mutate(const func_call_t &obj) override {
48         if (!is_func_call<reorder_t>(&obj)) return ir_mutator_t::_mutate(obj);
49 
50         auto &func = obj.func.as<reorder_t>();
51         auto new_func
52                 = reorder_t::make(func.src_layout, func.dst_layout, grf_perm_);
53 
54         return new_func.call(obj.args);
55     }
56 
57 private:
58     std::shared_ptr<grf_permutator_t> grf_perm_;
59 };
60 
61 class dpasw_injector_t {
62 public:
dpasw_injector_t(ngen::HW hw,const stmt_t & load_mul_stmt,const expr_t & c_buf,const stmt_t & c_store_stmt,alloc_updater_t & alloc_updater,const expr_t & tg_idx0)63     dpasw_injector_t(ngen::HW hw, const stmt_t &load_mul_stmt,
64             const expr_t &c_buf, const stmt_t &c_store_stmt,
65             alloc_updater_t &alloc_updater, const expr_t &tg_idx0)
66         : hw_(hw)
67         , load_mul_stmt_(load_mul_stmt)
68         , c_buf_(c_buf)
69         , c_store_stmt_(c_store_stmt)
70         , alloc_updater_(alloc_updater)
71         , tg_idx0_(tg_idx0) {}
72 
load_mul_stmt() const73     const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
74 
c_store_stmt() const75     const stmt_t &c_store_stmt() const { return c_store_stmt_; }
76 
inject()77     void inject() {
78         expr_t src2_base;
79         extract_dpas_calls(src2_base);
80 
81         grf_permutator_t grf_perm(hw_, c_buf_);
82 
83         bool was_injected = false;
84         int dpas_count = int(dpas_infos_.size());
85         for (int i = 0; i < dpas_count;) {
86             if (i + 1 < dpas_count) {
87                 auto &a = dpas_infos_[i];
88                 auto &b = dpas_infos_[i + 1];
89                 if (try_convert_to_dpasw(a, b, grf_perm)) {
90                     was_injected = true;
91                     i += 2;
92                     continue;
93                 }
94             }
95             if (try_convert_to_dpasw(dpas_infos_[i], grf_perm)) {
96                 was_injected = true;
97             }
98             ++i;
99         }
100         // Nothing to update, no dpas -> dpasw transformation.
101         if (!was_injected) return;
102 
103         int src2_size = 0;
104         object_map_t<stmt_t, int> send2off;
105         std::function<int(const stmt_t &)> get_src2_off;
106         get_src2_off = [&](const stmt_t &s) {
107             auto &si = find_send_info(s);
108             if (!si.base_call.is_empty()) return get_src2_off(si.base_call);
109             if (!si.prev_send.is_empty()) return get_src2_off(si.prev_send);
110 
111             auto it = send2off.find(s);
112             if (it != send2off.end()) return it->second;
113 
114             auto ret = send2off.insert({s, src2_size});
115             if (!ret.second) return ret.first->second;
116 
117             int new_size = si.new_reg_buf_size();
118             src2_size += new_size;
119             return ret.first->second;
120         };
121         for (auto &si : send_infos_) {
122             if (!si.reg_buf_base().is_equal(src2_base)) continue;
123 
124             int src2_off = get_src2_off(si.call);
125             auto src2_sub = src2_base[src2_off];
126             auto new_call = si.new_call;
127             if (!new_call.is_empty()) {
128                 new_call = substitute(
129                         new_call, send_t::arg_reg_buf(new_call), src2_sub, 1);
130             }
131 
132             load_mul_stmt_ = substitute(load_mul_stmt_, si.call, new_call, 1);
133             for (auto &d : si.dpas_consumers) {
134                 auto &di = find_dpas_info(d);
135                 ir_assert(si.promote_to_dpasw == di.promote_to_dpasw)
136                         << "Both send and dpas must be updated.";
137                 if (di.update_applied) {
138                     ir_error_not_expected() << "Can it happen?";
139                     continue;
140                 }
141                 auto new_call = di.new_call;
142                 new_call = substitute(new_call, dpas_t::arg_src2(new_call),
143                         src2_sub[di.src2_relative_off], 1);
144                 load_mul_stmt_
145                         = substitute(load_mul_stmt_, di.call, new_call, 1);
146                 di.update_applied = true;
147             }
148         }
149 
150         // Apply permutation to C store.
151         c_store_stmt_ = apply_permutation_to_reorder(c_store_stmt_, grf_perm);
152 
153         // Update src2 size after applying send updates.
154         alloc_updater_.resize(src2_base, src2_size);
155     }
156 
157 private:
158     struct send_info_t {
159         send_info_t() = default;
160 
send_info_tdnnl::impl::gpu::jit::dpasw_injector_t::send_info_t161         send_info_t(const stmt_t &call) : call(call), new_call(call) {}
162 
senddnnl::impl::gpu::jit::dpasw_injector_t::send_info_t163         const send_t &send() const {
164             return call.as<func_call_t>().func.as<send_t>();
165         }
166 
new_senddnnl::impl::gpu::jit::dpasw_injector_t::send_info_t167         const send_t &new_send() const {
168             ir_assert(!new_call.is_same(call));
169             return new_call.as<func_call_t>().func.as<send_t>();
170         }
171 
argsdnnl::impl::gpu::jit::dpasw_injector_t::send_info_t172         const std::vector<expr_t> &args() const {
173             return call.as<func_call_t>().args;
174         }
175 
reg_bufdnnl::impl::gpu::jit::dpasw_injector_t::send_info_t176         const expr_t &reg_buf() const { return send_t::arg_reg_buf(call); }
177 
reg_buf_basednnl::impl::gpu::jit::dpasw_injector_t::send_info_t178         const expr_t &reg_buf_base() const {
179             return reg_buf().as<ptr_t>().base;
180         }
181 
reg_buf_sizednnl::impl::gpu::jit::dpasw_injector_t::send_info_t182         int reg_buf_size() const { return send().register_size(); }
183 
new_reg_buf_sizednnl::impl::gpu::jit::dpasw_injector_t::send_info_t184         int new_reg_buf_size() const {
185             if (new_call.is_same(call)) return 0;
186             return new_send().register_size();
187         }
188 
set_new_calldnnl::impl::gpu::jit::dpasw_injector_t::send_info_t189         void set_new_call(const stmt_t &s, const stmt_t &base = stmt_t()) {
190             if (!promote_to_dpasw) {
191                 promote_to_dpasw = true;
192                 new_call = s;
193                 base_call = base;
194                 return;
195             }
196             ir_assert(new_call.is_equal(s));
197             ir_assert(base_call.is_equal(base));
198         }
199 
set_prev_senddnnl::impl::gpu::jit::dpasw_injector_t::send_info_t200         void set_prev_send(const stmt_t &s) {
201             int prev_size
202                     = s.as<func_call_t>().func.as<send_t>().register_size();
203             if (reg_buf_size() != prev_size) return;
204             prev_send = s;
205         }
206 
207         stmt_t call;
208         std::vector<stmt_t> dpas_consumers;
209 
210         bool promote_to_dpasw = false;
211         stmt_t new_call;
212         stmt_t base_call;
213         stmt_t prev_send;
214     };
215 
216     struct dpas_info_t {
217         dpas_info_t() = default;
218 
dpas_info_tdnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t219         dpas_info_t(const stmt_t &call) : call(call), new_call(call) {}
220 
dpasdnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t221         const dpas_t &dpas() const {
222             return call.as<func_call_t>().func.as<dpas_t>();
223         }
224 
argsdnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t225         const std::vector<expr_t> &args() const {
226             return call.as<func_call_t>().args;
227         }
228 
src1_bufdnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t229         const expr_t &src1_buf() const { return dpas_t::arg_src1(call); }
230 
src2_bufdnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t231         const expr_t &src2_buf() const { return dpas_t::arg_src2(call); }
232 
src2_sizednnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t233         int src2_size() const { return dpas().src2_size(); }
234 
set_new_calldnnl::impl::gpu::jit::dpasw_injector_t::dpas_info_t235         void set_new_call(const stmt_t &s, int src2_relative_off) {
236             if (!promote_to_dpasw) {
237                 promote_to_dpasw = true;
238                 this->src2_relative_off = src2_relative_off;
239                 new_call = s;
240                 return;
241             }
242             ir_assert(this->src2_relative_off == src2_relative_off);
243             ir_assert(new_call.is_equal(s));
244         }
245 
246         stmt_t call;
247         stmt_t send_producer;
248 
249         bool promote_to_dpasw = false;
250         bool update_applied = false;
251         int src2_relative_off = 0;
252         stmt_t new_call;
253     };
254 
find_send_info(const stmt_t & s)255     send_info_t &find_send_info(const stmt_t &s) {
256         for (auto &si : send_infos_)
257             if (si.call.is_same(s)) return si;
258         ir_error_not_expected();
259         return send_infos_.front();
260     }
261 
find_dpas_info(const stmt_t & s)262     dpas_info_t &find_dpas_info(const stmt_t &s) {
263         for (auto &si : dpas_infos_)
264             if (si.call.is_same(s)) return si;
265         ir_error_not_expected();
266         return dpas_infos_.front();
267     }
is_send(const stmt_t & s,send_info_t & info)268     static bool is_send(const stmt_t &s, send_info_t &info) {
269         if (!is_func_call<send_t>(s)) return false;
270         info = send_info_t(s);
271         return true;
272     }
273 
is_dpas(const stmt_t & s,dpas_info_t & info)274     static bool is_dpas(const stmt_t &s, dpas_info_t &info) {
275         if (!is_func_call<dpas_t>(s)) return false;
276         info = dpas_info_t(s);
277         return true;
278     }
279 
extract_dpas_calls(expr_t & src2_base)280     void extract_dpas_calls(expr_t &src2_base) {
281         object_eq_map_t<expr_t, stmt_t> buf2send;
282 
283         auto set_src2_base = [&](const expr_t &ptr) {
284             auto &ptr_base = ptr.as<ptr_t>().base;
285             if (src2_base.is_empty()) {
286                 src2_base = ptr_base;
287                 return;
288             }
289             // This may need a fix in the future.
290             ir_assert(src2_base.is_same(ptr_base));
291         };
292 
293         // Iterate through dpas and send calls.
294         auto stmt_vec = flatten_statements(load_mul_stmt_);
295         for (auto &s : stmt_vec) {
296             send_info_t send_info;
297             if (is_send(s, send_info)) {
298                 auto &buf = send_info.reg_buf();
299                 stmt_t prev_send;
300                 auto it = buf2send.find(buf);
301                 if (it != buf2send.end()) prev_send = it->second;
302                 buf2send[buf] = s;
303                 send_infos_.push_back(send_info);
304                 if (!prev_send.is_empty()) {
305                     send_infos_.back().set_prev_send(prev_send);
306                 }
307                 continue;
308             }
309             dpas_info_t dpas_info;
310             if (is_dpas(s, dpas_info)) {
311                 set_src2_base(dpas_info.src2_buf());
312                 auto &buf = dpas_info.src2_buf();
313                 auto it = buf2send.find(buf);
314                 if (it == buf2send.end()) continue;
315                 auto &send_info = find_send_info(it->second);
316                 // Ensure read size matches DPAS src2 size.
317                 // FIXME: This may not be always the case.
318                 ir_assert(send_info.reg_buf_size() == dpas_info.src2_size());
319                 dpas_info.send_producer = send_info.call;
320                 send_info.dpas_consumers.push_back(s);
321                 dpas_infos_.push_back(dpas_info);
322             }
323         }
324     }
325 
326     // Checks for the following pattern:
327     //    dpas.sxr(a_dst, a_src0, src1, src2)
328     //    dpas.sxr(b_dst, b_src0, src1, src2 + s * r * 4)
can_convert_to_dpasw(const dpas_info_t & a,const dpas_info_t & b)329     static bool can_convert_to_dpasw(
330             const dpas_info_t &a, const dpas_info_t &b) {
331         if (!a.dpas().is_equal(b.dpas())) return false;
332         if (!a.src1_buf().is_equal(b.src1_buf())) return false;
333 
334         auto src2_off0 = to_cpp<int>(a.src2_buf().as<ptr_t>().off);
335         auto src2_off1 = to_cpp<int>(b.src2_buf().as<ptr_t>().off);
336 
337         if (src2_off1 - src2_off0 != a.src2_size()) return false;
338 
339         return true;
340     }
341 
try_convert_to_dpasw(dpas_info_t & a,dpas_info_t & b,grf_permutator_t & grf_perm)342     bool try_convert_to_dpasw(
343             dpas_info_t &a, dpas_info_t &b, grf_permutator_t &grf_perm) {
344         if (hw_ >= ngen::HW::XeHPC) return false;
345 
346         // Check if DPAS -> DPASW transformation is possible.
347         if (!can_convert_to_dpasw(a, b)) return false;
348 
349         // Perform the transformation:
350         // Before:
351         //   send(slm, a_off, src2[0])
352         //   send(slm, b_off, src2[s * r * 4])
353         //   dpas.sxr(a_dst, a_src0, src1, src2[0])
354         //   dpas.sxr(b_dst, b_src0, src1, src2[s * r * 4])
355         // After:
356         //   send(slm, a_off + (tg_idx0 % 2) * (b_off - a_off), src2)
357         //   dpasw.sxr(p_a_dst, p_a_src0, src1, src2[0])
358         //   dpasw.sxr(p_b_dst, p_b_src0, src1, src2[s * r * 4 / 2])
359         // Where:
360         //   p_a_dst[:] = a_dst[0:rcount / 2] + b_dst[0:rcount / 2]
361         //   p_b_dst[:] = a_dst[rcount / 2:rcount] + b_dst[rcount / 2:rcount]
362         ir_assert(a.dpas().is_equal(b.dpas()));
363         auto _dpasw = dpas_t::make_dpasw(a.dpas());
364         auto &dpasw = _dpasw.as<dpas_t>();
365 
366         auto a_args = a.args();
367         auto b_args = b.args();
368         dpas_t::arg_src2(b_args) -= dpasw.src2_size();
369 
370         a.set_new_call(dpasw.call(a.args()), 0);
371         b.set_new_call(dpasw.call(b_args), dpasw.src2_size());
372 
373         // Record permutation for registers to apply it for the destination
374         // store later.
375         const auto grf_size = ngen::GRF::bytes(hw_);
376         const auto rcount = a.dpas().rcount;
377         for (int j = 0; j < rcount; j++) {
378             int k = j % (rcount / 2);
379             auto a_old = dpas_t::arg_dst(a_args) + grf_size * j;
380             auto b_old = dpas_t::arg_dst(b_args) + grf_size * j;
381             expr_t grf_new;
382             if (j < rcount / 2) {
383                 grf_new = dpas_t::arg_dst(a_args)[grf_size * k];
384             } else {
385                 grf_new = dpas_t::arg_dst(b_args)[grf_size * k];
386             }
387             grf_perm.set_permute(a_old, grf_new);
388             grf_perm.set_permute(b_old, grf_new + grf_size * rcount / 2);
389         }
390 
391         auto &a_send = find_send_info(a.send_producer);
392         auto &b_send = find_send_info(b.send_producer);
393 
394         auto &a_mem_off = send_t::arg_mem_off(a_send.call);
395         auto &b_mem_off = send_t::arg_mem_off(b_send.call);
396         auto ab_addr_diff = simplify(b_mem_off - a_mem_off);
397         ir_assert(is_const(ab_addr_diff));
398 
399         auto new_send_args = a_send.args();
400         send_t::arg_mem_off(new_send_args)
401                 += (tg_idx0_ % 2) * to_cpp<int64_t>(ab_addr_diff);
402 
403         a_send.set_new_call(a_send.send().call(new_send_args));
404         b_send.set_new_call(stmt_t(), a_send.call);
405 
406         return true;
407     }
408 
can_convert_to_dpasw(const dpas_info_t & a_dpas,const send_info_t & a_send,const expr_t & tg_idx0)409     static bool can_convert_to_dpasw(const dpas_info_t &a_dpas,
410             const send_info_t &a_send, const expr_t &tg_idx0) {
411         if (contains_object(a_send.call, tg_idx0)) return false;
412         return a_dpas.dpas().rcount % 2 == 0;
413     }
414 
create_half_send(const send_t & send)415     static func_t create_half_send(const send_t &send) {
416         ir_assert(send.data_elems % 2 == 0) << "Can't create half-send.";
417         auto _s = send.with_data_elems(send.data_elems / 2);
418         auto &s = _s.as<send_t>();
419         ir_assert(s.is_supported())
420                 << "Can't find send reading half of the original send.";
421         MAYBE_UNUSED(s);
422         return _s;
423     }
424 
try_convert_to_dpasw(dpas_info_t & a,grf_permutator_t & grf_perm)425     bool try_convert_to_dpasw(dpas_info_t &a, grf_permutator_t &grf_perm) {
426         if (hw_ >= ngen::HW::XeHPC) return false;
427         if (!can_convert_to_dpasw(a, find_send_info(a.send_producer), tg_idx0_))
428             return false;
429 
430         // Perform the transformation:
431         // Before:
432         //   send(slm, a_off, src2[0])
433         //   dpas.sxr(a_dst, a_src0, src1, src2[0])
434         // After:
435         //   send(slm, a_off + (tg_idx0 % 2) * (s * r * 4 / 2), src2)
436         //   dpasw.sxr(a_dst, a_src0, src1, src2[0])
437 
438         auto _dpasw = dpas_t::make_dpasw(a.dpas());
439         auto &dpasw = _dpasw.as<dpas_t>();
440 
441         a.set_new_call(dpasw.call(a.args()), 0);
442 
443         // Real permutation is not required but it needs to be set anyway.
444         const auto grf_size = ngen::GRF::bytes(hw_);
445         const auto rcount = a.dpas().rcount;
446         for (int j = 0; j < rcount; j++) {
447             auto grf = dpas_t::arg_dst(a.args()) + grf_size * j;
448             grf_perm.set_permute(grf, grf);
449         }
450 
451         auto &a_send = find_send_info(a.send_producer);
452         auto new_send_args = a_send.args();
453         send_t::arg_mem_off(new_send_args)
454                 += (tg_idx0_ % 2) * to_cpp<int64_t>(a.src2_size() / 2);
455         a_send.set_new_call(
456                 create_half_send(a_send.send()).call(new_send_args));
457 
458         return true;
459     }
460 
apply_permutation_to_reorder(const stmt_t & stmt,const grf_permutator_t & grf_perm)461     static stmt_t apply_permutation_to_reorder(
462             const stmt_t &stmt, const grf_permutator_t &grf_perm) {
463         return permutation_injector_t(grf_perm).mutate(stmt);
464     }
465 
466     ngen::HW hw_;
467     stmt_t load_mul_stmt_;
468     expr_t c_buf_;
469     stmt_t c_store_stmt_;
470     alloc_updater_t &alloc_updater_;
471     expr_t tg_idx0_;
472 
473     std::vector<dpas_info_t> dpas_infos_;
474     std::vector<send_info_t> send_infos_;
475 };
476 
477 // Transforms DPAS to DPASW.
inject_dpasw(ngen::HW hw,stmt_t & load_mul_stmt,const expr_t & c_buf,stmt_t & c_store_stmt,alloc_updater_t & alloc_updater,const expr_t & tg_idx0)478 void inject_dpasw(ngen::HW hw, stmt_t &load_mul_stmt, const expr_t &c_buf,
479         stmt_t &c_store_stmt, alloc_updater_t &alloc_updater,
480         const expr_t &tg_idx0) {
481     dpasw_injector_t injector(
482             hw, load_mul_stmt, c_buf, c_store_stmt, alloc_updater, tg_idx0);
483     injector.inject();
484 
485     load_mul_stmt = injector.load_mul_stmt();
486     c_store_stmt = injector.c_store_stmt();
487 }
488 
489 // Adds {Atomic} modifier to DPAS/DPASW instructions when applicable.
inject_atomic(const stmt_t & stmt)490 stmt_t inject_atomic(const stmt_t &stmt) {
491     stmt_t ret = stmt;
492     auto stmt_vec = flatten_statements(stmt);
493     for (size_t i = 0; i < stmt_vec.size(); i++) {
494         bool ok = true;
495         ok &= is_func_call<dpas_t>(stmt_vec[i]);
496         ok &= (i + 1 < stmt_vec.size()
497                 && is_func_call<dpas_t>(stmt_vec[i + 1]));
498         if (ok) {
499             auto &cur_src1 = dpas_t::arg_src1(stmt_vec[i]);
500             auto &next_src1 = dpas_t::arg_src1(stmt_vec[i + 1]);
501             // Compare src1, apply {Atomic} if they are equal.
502             if (cur_src1.is_equal(next_src1)) {
503                 auto &s = stmt_vec[i];
504                 auto atomic_attr = instruction_modifier_attr_t::make(
505                         ngen_proxy::InstructionModifier().with_atomic());
506                 ret = substitute(ret, s, atomic_attr.apply_to(s));
507             }
508         }
509     }
510     return ret;
511 }
512 
513 // Trace for debugging purposes.
trace_pass(const char * pass_name,const stmt_t & stmt)514 void trace_pass(const char *pass_name, const stmt_t &stmt) {
515     ir_trace() << "=== After " << pass_name << std::endl;
516     ir_trace() << stmt << std::endl;
517 }
518 
519 class external_var_visitor_t : public scope_visitor_t {
520 public:
_visit(const var_t & obj)521     void _visit(const var_t &obj) {
522         if (!is_expr_defined(obj)) external_vars.insert(obj);
523     }
524 
525     object_eq_set_t<expr_t> external_vars;
526 };
527 
inject_external_var_let(const stmt_t & _stmt)528 stmt_t inject_external_var_let(const stmt_t &_stmt) {
529     auto stmt = _stmt;
530     external_var_visitor_t v;
531     v.visit(stmt);
532 
533     for (auto &var : v.external_vars)
534         stmt = let_t::make(var, {}, stmt);
535 
536     trace_pass("inject_external_var_let", stmt);
537     return stmt;
538 }
539 
540 class slm_buffer_merger_t : public ir_mutator_t {
541 public:
slm_buffer_merger_t()542     slm_buffer_merger_t() {
543         slm_base_ = make_buffer("slm");
544         slm_off_.push_back(0);
545     }
546 
slm_base() const547     const expr_t &slm_base() const { return slm_base_; }
548 
slm_size() const549     int slm_size() const { return slm_size_; }
550 
_mutate(const alloc_t & obj)551     object_t _mutate(const alloc_t &obj) override {
552         if (obj.kind != alloc_kind_t::slm) return ir_mutator_t::_mutate(obj);
553 
554         auto new_buf = push(obj);
555         auto new_obj = ir_mutator_t::_mutate(obj);
556         pop();
557 
558         auto &alloc = new_obj.as<alloc_t>();
559         new_obj = substitute(alloc.body, alloc.buf, new_buf);
560 
561         return new_obj;
562     }
563 
564 private:
push(const alloc_t & obj)565     expr_t push(const alloc_t &obj) {
566         int cur_off = slm_off_.back();
567         expr_t new_buf = slm_base_ + cur_off;
568         slm_off_.push_back(cur_off + obj.size);
569         slm_size_ = std::max(slm_size_, cur_off + obj.size);
570         return new_buf;
571     }
572 
pop()573     void pop() { slm_off_.pop_back(); }
574 
575     expr_t slm_base_;
576     std::vector<int> slm_off_;
577     int slm_size_ = 0;
578 };
579 
580 // Merges all SLM buffers into a single one.
merge_slm_buffers(const stmt_t & _stmt)581 stmt_t merge_slm_buffers(const stmt_t &_stmt) {
582     stmt_t stmt = _stmt;
583     slm_buffer_merger_t merger;
584     stmt = merger.mutate(stmt);
585     stmt = alloc_t::make(
586             merger.slm_base(), merger.slm_size(), alloc_kind_t::slm, {}, stmt);
587     trace_pass("merge_slm_buffers", stmt);
588     return stmt;
589 }
590 
591 class buffer_offset_lifter_t : public ir_mutator_t {
592 public:
_mutate(const func_call_t & obj)593     object_t _mutate(const func_call_t &obj) {
594         if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
595 
596         auto &mem_buf = send_t::arg_mem_buf(obj);
597         if (!mem_buf.is<ptr_t>()) return ir_mutator_t::_mutate(obj);
598 
599         auto &base = mem_buf.as<ptr_t>().base;
600         auto &off = mem_buf.as<ptr_t>().off;
601 
602         std::vector<expr_t> new_args = obj.args;
603         send_t::arg_mem_buf(new_args) = base;
604         send_t::arg_mem_off(new_args) += off;
605         return obj.func.call(new_args, obj.attr);
606     }
607 };
608 
lift_buffer_offsets_in_send(const stmt_t & s)609 stmt_t lift_buffer_offsets_in_send(const stmt_t &s) {
610     buffer_offset_lifter_t lifter;
611     auto ret = lifter.mutate(s);
612     trace_pass("lift_buffer_offsets_in_send", ret);
613     return ret;
614 }
615 
simplify_pass(const stmt_t & s,const constraint_set_t & cset)616 stmt_t simplify_pass(const stmt_t &s, const constraint_set_t &cset) {
617     auto ret = simplify(s, cset);
618     trace_pass("simplify_pass", ret);
619     return ret;
620 }
621 
622 class send_injector_t : public ir_mutator_t {
623 public:
send_injector_t(ir_context_t & ir_ctx,const constraint_set_t & cset)624     send_injector_t(ir_context_t &ir_ctx, const constraint_set_t &cset)
625         : ir_ctx_(ir_ctx), cset_(cset) {}
626 
_mutate(const func_call_t & obj)627     object_t _mutate(const func_call_t &obj) {
628         auto *send = obj.func.as_ptr<send_t>();
629         if (!send) return ir_mutator_t::_mutate(obj);
630 
631         auto &mem_buf = send_t::arg_mem_buf(obj);
632         auto &mem_off = send_t::arg_mem_off(obj);
633         auto &reg_buf = send_t::arg_reg_buf(obj);
634         auto &mask = send_t::arg_mask(obj);
635 
636         ir_assert(is_var(mem_buf)) << mem_buf;
637 
638         auto header_buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "h");
639         auto off_store = simplify_store(
640                 send->create_offset_store(header_buf, mem_buf, mem_off));
641 
642         auto new_call = func_call_t::make(
643                 obj.func, {mem_buf, header_buf, reg_buf, mask}, obj.attr);
644         auto body = stmt_seq_t::make(off_store, new_call);
645 
646         // Allocate header.
647         return alloc_t::make(
648                 header_buf, send->header_size(), alloc_kind_t::grf, {}, body);
649     }
650 
651 private:
simplify_store(const stmt_t & _store) const652     stmt_t simplify_store(const stmt_t &_store) const {
653         auto &store = _store.as<store_t>();
654 
655         auto value = store.value;
656         value = simplify(value, cset_);
657 
658         // Convert to N-ary form and back to expand multiplications. This
659         // helps to find more common subexpressions during the pass.
660         value = nary_op_canonicalize(value);
661         value = nary_op_back_transform(value);
662 
663         return store_t::make(store.buf, store.off, value, store.stride);
664     }
665 
666     ir_context_t &ir_ctx_;
667     const constraint_set_t &cset_;
668 };
669 
inject_send(const stmt_t & s,ir_context_t & ir_ctx,const constraint_set_t & cset)670 stmt_t inject_send(
671         const stmt_t &s, ir_context_t &ir_ctx, const constraint_set_t &cset) {
672     auto ret = send_injector_t(ir_ctx, cset).mutate(s);
673     trace_pass("inject_send", ret);
674     return ret;
675 }
676 
677 class alloc_lifter_t : public ir_mutator_t {
678 public:
alloc_lifter_t(const stmt_t & root,bool reuse_headers)679     alloc_lifter_t(const stmt_t &root, bool reuse_headers)
680         : reuse_headers_(reuse_headers) {
681         if (!reuse_headers_) return;
682         auto calls = find_objects<func_call_t>(root);
683         for (auto &c : calls) {
684             if (!is_func_call<send_t>(c)) continue;
685             auto header_buf = send_t::arg_mem_off(c);
686             ir_assert(is_var(header_buf)) << header_buf;
687             header_bufs_.insert(header_buf);
688         }
689     }
690 
_mutate(const alloc_t & obj)691     object_t _mutate(const alloc_t &obj) override {
692         if (!do_lift(obj)) return ir_mutator_t::_mutate(obj);
693         // Remove alloc and insert it before the compute loop.
694         allocs_.push_back(&obj);
695         return obj.body;
696     }
697 
_mutate(const stmt_group_t & obj)698     object_t _mutate(const stmt_group_t &obj) override {
699         bool is_compute_loop = (obj.label == stmt_label_t::compute_loop());
700         if (is_compute_loop) in_compute_loop_ = true;
701         auto new_obj = ir_mutator_t::_mutate(obj);
702         if (is_compute_loop) {
703             in_compute_loop_ = false;
704             // Outermost loop.
705             for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
706                 auto &a = it->as<alloc_t>();
707                 new_obj = alloc_t::make(a.buf, a.size, a.kind, a.attr, new_obj);
708             }
709             allocs_.resize(0);
710         }
711         return new_obj;
712     }
713 
714 private:
do_lift(const alloc_t & obj) const715     bool do_lift(const alloc_t &obj) const {
716         if (!in_compute_loop_) return false;
717         if (reuse_headers_) {
718             bool is_header_alloc = (header_bufs_.count(obj.buf) != 0);
719             return !is_header_alloc;
720         }
721         return true;
722     }
723 
724     bool reuse_headers_;
725     object_set_t<expr_t> header_bufs_;
726 
727     bool in_compute_loop_ = false;
728     std::vector<stmt_t> allocs_;
729 };
730 
731 // Lifts alloc statements out of loops.
lift_alloc(const stmt_t & s,const conv_config_t & cfg)732 stmt_t lift_alloc(const stmt_t &s, const conv_config_t &cfg) {
733     auto ret = alloc_lifter_t(s, cfg.reuse_headers).mutate(s);
734     trace_pass("lift_alloc", ret);
735     return ret;
736 }
737 
738 // Common subexpression elimination support.
739 
740 // Represents an expression-candidate to eliminate.
741 class cse_expr_t {
742 public:
cse_expr_t(const expr_t & expr,const ir_path_t & path,int refs=1,const expr_t & cse_var={})743     cse_expr_t(const expr_t &expr, const ir_path_t &path, int refs = 1,
744             const expr_t &cse_var = {})
745         : expr(expr), path(path), refs(refs), cse_var(cse_var) {
746         ir_trace() << "cse_pass: add expression: " << expr << std::endl;
747     }
748 
add_usage(const ir_path_t & other_path,bool do_increment=true)749     void add_usage(const ir_path_t &other_path, bool do_increment = true) {
750         if (do_increment) refs++;
751         path.merge(other_path);
752         ir_trace() << "cse_pass: add usage: " << expr
753                    << ", total refs: " << refs << std::endl;
754     }
755 
756     // Expression to eliminate via let.
757     expr_t expr;
758     // Path to the innermost IR node where the expression can be defined.
759     ir_path_t path;
760     // Number of references to the expression.
761     int refs;
762     // Variable assigned to the expression (if decided to eliminate).
763     expr_t cse_var;
764 };
765 
766 // Stores information about all expressions subject to CSEing.
767 class cse_context_t {
768 public:
cse_context_t(ir_context_t & ir_ctx)769     cse_context_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
770 
ir_ctx()771     ir_context_t &ir_ctx() { return ir_ctx_; }
772 
has(const expr_t & e) const773     bool has(const expr_t &e) const { return cse_exprs_.count(e) != 0; }
774 
find_cse_expr(const expr_t & e)775     cse_expr_t &find_cse_expr(const expr_t &e) {
776         ir_assert(has(e)) << e;
777         return cse_exprs_.at(e);
778     }
779 
find_cse_expr(const expr_t & e) const780     const cse_expr_t &find_cse_expr(const expr_t &e) const {
781         ir_assert(has(e)) << e;
782         return cse_exprs_.at(e);
783     }
784 
has_var(const expr_t & e) const785     bool has_var(const expr_t &e) const {
786         return !find_cse_expr(e).cse_var.is_empty();
787     }
788 
get_refs(const expr_t & e) const789     int get_refs(const expr_t &e) const {
790         if (!has(e)) return 0;
791         return find_cse_expr(e).refs;
792     }
793 
register_expr(const expr_t & e,const ir_path_t & path)794     void register_expr(const expr_t &e, const ir_path_t &path) {
795         if (e.type().is_bool()) return; // Ignore booleans.
796         auto ret = cse_exprs_.insert({e, cse_expr_t(e, path)});
797         ir_assert(ret.second) << e;
798         MAYBE_UNUSED(ret);
799     }
800 
register_expr(const cse_expr_t & cse_expr)801     void register_expr(const cse_expr_t &cse_expr) {
802         auto ret = cse_exprs_.insert({cse_expr.expr, cse_expr});
803         ir_assert(ret.second);
804         MAYBE_UNUSED(ret);
805     }
806 
get_or_assign_var(const expr_t & e)807     expr_t get_or_assign_var(const expr_t &e) {
808         auto &cse_expr = find_cse_expr(e);
809         if (cse_expr.cse_var.is_empty()) {
810             cse_expr.cse_var = ir_ctx_.create_tmp_var(e.type());
811             ir_trace() << "cse_pass: assigning var: " << e << " -> "
812                        << cse_expr.cse_var << std::endl;
813         }
814         return cse_expr.cse_var;
815     }
816 
get_var(const expr_t & e) const817     const expr_t &get_var(const expr_t &e) const {
818         return find_cse_expr(e).cse_var;
819     }
820 
get_path(const expr_t & e) const821     const ir_path_t &get_path(const expr_t &e) const {
822         return find_cse_expr(e).path;
823     }
824 
add_usage(const expr_t & e,const ir_path_t & path,bool do_increment=true)825     void add_usage(
826             const expr_t &e, const ir_path_t &path, bool do_increment = true) {
827         if (e.type().is_bool()) return; // Ignore booleans.
828         return find_cse_expr(e).add_usage(path, do_increment);
829     }
830 
update_expr(const expr_t & old_expr,const expr_t & new_expr)831     void update_expr(const expr_t &old_expr, const expr_t &new_expr) {
832         auto it = cse_exprs_.find(old_expr);
833         ir_assert(it != cse_exprs_.end()) << old_expr;
834         auto &old_cse_expr = it->second;
835         auto new_cse_expr = cse_expr_t(new_expr, old_cse_expr.path,
836                 old_cse_expr.refs, old_cse_expr.cse_var);
837         cse_exprs_.erase(it);
838         auto ret = cse_exprs_.insert({new_expr, new_cse_expr});
839         ir_assert(ret.second);
840         MAYBE_UNUSED(ret);
841     }
842 
843     template <typename F>
for_each(const F & f) const844     void for_each(const F &f) const {
845         for (auto &kv : cse_exprs_)
846             f(kv.first);
847     }
848 
849 private:
850     ir_context_t &ir_ctx_;
851     object_eq_map_t<expr_t, cse_expr_t> cse_exprs_;
852 };
853 
854 // Collects statistics about expressions for common subexpression elimination.
855 class cse_visitor_t : public ir_visitor_t {
856 public:
cse_visitor_t(cse_context_t & ctx)857     cse_visitor_t(cse_context_t &ctx) : ctx_(ctx) {}
858 
_visit(const binary_op_t & obj)859     void _visit(const binary_op_t &obj) override { visit_expr(obj); }
_visit(const shuffle_t & obj)860     void _visit(const shuffle_t &obj) override {
861         if (is_const_broadcast(obj)) return;
862         visit_expr(obj);
863     }
_visit(const unary_op_t & obj)864     void _visit(const unary_op_t &obj) override { visit_expr(obj); }
865 
866 #define HANDLE_IR_OBJECT(type) \
867     void _visit(const type &obj) override { visit_stmt(obj); }
868 
869     HANDLE_STMT_IR_OBJECTS()
870 
871 #undef HANDLE_IR_OBJECT
872 
873 private:
874     template <typename T>
visit_expr(const T & obj)875     void visit_expr(const T &obj) {
876         // Exclude loads as they may have side effects.
877         if (count_objects<load_t>(obj) > 0) {
878             ir_visitor_t::_visit(obj);
879             return;
880         }
881 
882         if (propagate_path_) {
883             if (ctx_.has(obj))
884                 ctx_.add_usage(obj, root_path_, /*do_increment=*/false);
885             ir_visitor_t::_visit(obj);
886             return;
887         }
888         if (ctx_.has(obj)) {
889             ctx_.add_usage(obj, root_path_);
890             propagate_path_ = true;
891             ir_visitor_t::_visit(obj);
892             propagate_path_ = false;
893             return;
894         }
895         ir_visitor_t::_visit(obj);
896         ctx_.register_expr(obj, root_path_);
897     }
898 
899     template <typename T>
visit_stmt(const T & obj)900     void visit_stmt(const T &obj) {
901         if (std::is_same<T, for_t>::value) {
902             visit_for((const object_impl_t &)obj);
903             return;
904         }
905         if (std::is_same<T, let_t>::value) {
906             visit_let((const object_impl_t &)obj);
907             return;
908         }
909         root_path_.push(&obj);
910         ir_visitor_t::_visit(obj);
911         root_path_.pop();
912     }
913 
visit_for(const object_impl_t & _obj)914     void visit_for(const object_impl_t &_obj) {
915         auto &obj = (const for_t &)_obj;
916 
917         visit(obj.var);
918         visit(obj.init);
919         visit(obj.bound);
920         root_path_.push(&obj);
921         visit(obj.body);
922         root_path_.pop();
923     }
924 
visit_let(const object_impl_t & _obj)925     void visit_let(const object_impl_t &_obj) {
926         auto &obj = (const let_t &)_obj;
927 
928         visit(obj.var);
929         visit(obj.value);
930         root_path_.push(&obj);
931         visit(obj.body);
932         root_path_.pop();
933     }
934 
935     cse_context_t &ctx_;
936     ir_path_t root_path_;
937 
938     bool propagate_path_ = false;
939 };
940 
941 // Verifies all IR paths are correct (for debugging purposes).
942 class cse_verifier_t : public scope_visitor_t {
943 public:
cse_verifier_t(cse_context_t & ctx)944     cse_verifier_t(cse_context_t &ctx) : ctx_(ctx) {}
945 
~cse_verifier_t()946     ~cse_verifier_t() override { ir_assert(to_check_.empty()); }
947 
_visit(const binary_op_t & obj)948     void _visit(const binary_op_t &obj) override { visit_expr(obj); }
_visit(const shuffle_t & obj)949     void _visit(const shuffle_t &obj) override { return visit_expr(obj); }
_visit(const unary_op_t & obj)950     void _visit(const unary_op_t &obj) override { visit_expr(obj); }
951 
952 #define HANDLE_IR_OBJECT(type) \
953     void _visit(const type &obj) override { visit_stmt(obj); }
954 
HANDLE_STMT_IR_OBJECTS()955     HANDLE_STMT_IR_OBJECTS()
956 
957 #undef HANDLE_IR_OBJECT
958 
959     void verify(const stmt_t &s) {
960         // Phase 0: collect IR paths for expressions.
961         phase_ = 0;
962         visit(s);
963 
964         // Phase 1: verify all expressions are defined at their path.
965         phase_ = 1;
966         visit(s);
967     }
968 
969 private:
970     template <typename T>
visit_expr(const T & obj)971     void visit_expr(const T &obj) {
972         // Expressions are not used during phase 1.
973         if (phase_ == 1) return;
974         if (ctx_.has(obj)) {
975             auto &path = ctx_.get_path(obj);
976             to_check_[path.back()].push_back(obj);
977         }
978         scope_visitor_t::_visit(obj);
979     }
980 
981     template <typename T>
visit_stmt(const T & obj)982     void visit_stmt(const T &obj) {
983         scope_visitor_t::_visit(obj);
984 
985         // Statements are not used during phase 0.
986         if (phase_ == 0) return;
987 
988         // Phase 1: check that all attached expressions are defined at this
989         // statement.
990         auto it = to_check_.find(obj);
991         if (it != to_check_.end()) {
992             for (auto &e : it->second) {
993                 ir_assert(is_expr_defined(e))
994                         << "Expression contains undefined variables: " << e;
995                 MAYBE_UNUSED(e);
996             }
997             to_check_.erase(it);
998         }
999     }
1000 
1001     cse_context_t &ctx_;
1002 
1003     int phase_ = 0;
1004     object_map_t<stmt_t, std::vector<expr_t>> to_check_;
1005 };
1006 
1007 // Generates let statements for expressions being eliminated.
1008 class cse_let_generator_t : public ir_visitor_t {
1009 public:
cse_let_generator_t(const cse_context_t & ctx,const stmt_t & stmt)1010     cse_let_generator_t(const cse_context_t &ctx, const stmt_t &stmt)
1011         : ctx_(ctx), stmt_(stmt) {}
1012 
_visit(const binary_op_t & obj)1013     void _visit(const binary_op_t &obj) override { visit_expr(obj); }
_visit(const shuffle_t & obj)1014     void _visit(const shuffle_t &obj) override { visit_expr(obj); }
_visit(const unary_op_t & obj)1015     void _visit(const unary_op_t &obj) override { visit_expr(obj); }
_visit(const var_t & obj)1016     void _visit(const var_t &obj) override {
1017         auto it = all_vars_.find(obj);
1018         if (it == all_vars_.end()) return;
1019         if (seen_vars_.count(obj) == 0) generate_for_expr(it->second);
1020     }
1021 
generate()1022     stmt_t generate() {
1023         ctx_.for_each([&](const expr_t &e) {
1024             auto &cse_var = ctx_.get_var(e);
1025             auto ret = all_vars_.insert({cse_var, e});
1026             ir_assert(ret.second);
1027             MAYBE_UNUSED(ret);
1028         });
1029         ctx_.for_each([&](const expr_t &e) { generate_for_expr(e); });
1030         for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
1031             auto &let = it->as<let_t>();
1032             stmt_ = let_t::make(let.var, let.value, stmt_);
1033         }
1034         return stmt_;
1035     }
1036 
1037 private:
generate_for_expr(const expr_t & e)1038     void generate_for_expr(const expr_t &e) {
1039         auto &cse_var = ctx_.get_var(e);
1040         if (seen_vars_.count(cse_var) == 1) return;
1041         visit(e);
1042     }
1043 
1044     template <typename T>
visit_expr(const T & obj)1045     void visit_expr(const T &obj) {
1046         ir_visitor_t::_visit(obj);
1047         if (ctx_.has(obj) && ctx_.has_var(obj)) {
1048             auto &var = ctx_.get_var(obj);
1049             auto ret = seen_vars_.insert(var);
1050             if (ret.second) lets_.push_back(let_t::make(var, obj));
1051         }
1052     }
1053 
1054     const cse_context_t &ctx_;
1055     stmt_t stmt_;
1056 
1057     object_map_t<expr_t, expr_t> all_vars_; // Var -> expression.
1058     object_set_t<expr_t> seen_vars_;
1059 
1060     std::vector<stmt_t> lets_;
1061 };
1062 
1063 // Eliminiates expressions from the statement.
1064 class cse_mutator_t : public ir_mutator_t {
1065 public:
cse_mutator_t(cse_context_t & ctx)1066     cse_mutator_t(cse_context_t &ctx) : ctx_(ctx) {}
1067 
_mutate(const binary_op_t & obj)1068     object_t _mutate(const binary_op_t &obj) override {
1069         return mutate_expr(obj);
1070     }
_mutate(const shuffle_t & obj)1071     object_t _mutate(const shuffle_t &obj) override { return mutate_expr(obj); }
_mutate(const unary_op_t & obj)1072     object_t _mutate(const unary_op_t &obj) override {
1073         return mutate_expr(obj);
1074     }
1075 
1076 #define HANDLE_IR_OBJECT(type) \
1077     object_t _mutate(const type &obj) override { return mutate_stmt(obj); }
1078 
1079     HANDLE_STMT_IR_OBJECTS()
1080 
1081 #undef HANDLE_IR_OBJECT
1082 
1083 private:
1084     template <typename T>
mutate_expr(const T & obj)1085     object_t mutate_expr(const T &obj) {
1086         auto new_obj = ir_mutator_t::_mutate(obj);
1087         if (ctx_.has(obj) && !new_obj.is_equal(obj)) {
1088             ctx_.update_expr(obj, new_obj);
1089         }
1090         if (ctx_.get_refs(new_obj) > 1) {
1091             bool has_var = ctx_.has_var(new_obj);
1092             auto var = ctx_.get_or_assign_var(new_obj);
1093             auto &path = ctx_.get_path(new_obj);
1094             if (!has_var) to_update_[path.back()].push_back(new_obj);
1095             return std::move(var);
1096         }
1097         return new_obj;
1098     }
1099 
1100     template <typename T>
mutate_stmt(const T & obj)1101     object_t mutate_stmt(const T &obj) {
1102         auto new_obj = ir_mutator_t::_mutate(obj);
1103         auto it = to_update_.find(obj);
1104         if (it == to_update_.end()) return new_obj;
1105 
1106         cse_context_t local_ctx(ctx_.ir_ctx());
1107         for (auto &e : it->second) {
1108             local_ctx.register_expr(ctx_.find_cse_expr(e));
1109         }
1110         to_update_.erase(it);
1111 
1112         auto body = get_stmt_body(new_obj);
1113         cse_let_generator_t g(local_ctx, body);
1114         body = g.generate();
1115         new_obj = replace_stmt_body(new_obj, body);
1116         return new_obj;
1117     }
1118 
1119     cse_context_t &ctx_;
1120     object_map_t<stmt_t, std::vector<expr_t>> to_update_;
1121 };
1122 
eliminate_common_subexprs(const stmt_t & _stmt,ir_context_t & ir_ctx)1123 stmt_t eliminate_common_subexprs(const stmt_t &_stmt, ir_context_t &ir_ctx) {
1124     auto stmt = _stmt;
1125 
1126     cse_context_t ctx(ir_ctx);
1127 
1128     // Collect statistics.
1129     cse_visitor_t visitor(ctx);
1130     visitor.visit(stmt);
1131 
1132 #ifndef NDEBUG
1133     // Verify that collected IR paths are correct (cse_expr_t objects are
1134     // defined at those paths).
1135     cse_verifier_t verifier(ctx);
1136     verifier.verify(stmt);
1137 #endif
1138 
1139     // Eliminate subexpressions.
1140     cse_mutator_t mutator(ctx);
1141     stmt = mutator.mutate(stmt);
1142 
1143     trace_pass("eliminate_common_subexprs", stmt);
1144     return stmt;
1145 }
1146 
1147 class hoist_exprs_mutator_t : public ir_mutator_t {
1148 public:
hoist_exprs_mutator_t(ir_context_t & ir_ctx)1149     hoist_exprs_mutator_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
1150 
~hoist_exprs_mutator_t()1151     ~hoist_exprs_mutator_t() override { ir_assert(let_vars_.empty()); }
1152 
_mutate(const func_call_t & obj)1153     object_t _mutate(const func_call_t &obj) override {
1154         if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
1155 
1156         std::vector<expr_t> new_args;
1157         for (auto &e : obj.args) {
1158             new_args.push_back(hoist_expr(e));
1159         }
1160 
1161         if (ir_utils::is_equal(new_args, obj.args)) return obj;
1162 
1163         return func_call_t::make(obj.func, new_args, obj.attr);
1164     }
1165 
_mutate(const stmt_group_t & obj)1166     object_t _mutate(const stmt_group_t &obj) override {
1167         if (obj.body.is<for_t>()) {
1168             loops_.emplace_back(obj.body.as<for_t>().var);
1169             const for_t *for_obj = obj.body.as_ptr<for_t>();
1170             auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
1171             if (body.is_same(obj.body)) return obj;
1172             auto new_obj = stmt_group_t::make(obj.label, body);
1173             return injects_lets_and_pop_loop(new_obj);
1174         }
1175         return ir_mutator_t::_mutate(obj);
1176     }
1177 
_mutate(const store_t & obj)1178     object_t _mutate(const store_t &obj) override {
1179         auto value = hoist_expr(obj.value);
1180         if (value.is_equal(obj.value)) return obj;
1181         return store_t::make(obj.buf, obj.off, value, obj.stride);
1182     }
1183 
_mutate(const for_t & obj)1184     object_t _mutate(const for_t &obj) override {
1185         loops_.emplace_back(obj.var);
1186         auto new_obj = ir_mutator_t::_mutate(obj);
1187         return injects_lets_and_pop_loop(new_obj);
1188     }
1189 
_mutate(const let_t & obj)1190     object_t _mutate(const let_t &obj) override {
1191         bool fully_hoisted = false;
1192         auto new_value = hoist_expr(obj.value, obj.var, &fully_hoisted);
1193         if (fully_hoisted) return mutate(obj.body);
1194         register_let(obj.var, new_value);
1195         auto new_obj = let_t::make(
1196                 obj.var, new_value, ir_mutator_t::mutate(obj.body));
1197         unregister_let(obj.var);
1198         return std::move(new_obj);
1199     }
1200 
1201 private:
1202     struct loop_info_t {
loop_info_tdnnl::impl::gpu::jit::hoist_exprs_mutator_t::loop_info_t1203         loop_info_t(const expr_t &var) : var(var) {}
1204 
1205         expr_t var;
1206         int var_count = 0;
1207         std::vector<stmt_t> lets;
1208     };
1209 
hoist_expr(const expr_t & expr,const expr_t & expr_var={},bool * fully_hoisted=nullptr)1210     expr_t hoist_expr(const expr_t &expr, const expr_t &expr_var = {},
1211             bool *fully_hoisted = nullptr) {
1212         if (expr.is_empty()) return expr;
1213         if (expr.type().is_ptr()) return expr;
1214         if (expr.type().is_bool()) return expr;
1215         if (is_const(expr) || is_shuffle_const(expr) || is_var(expr))
1216             return expr;
1217 
1218         auto hoisted_expr = hoist_expr_with_add(expr, expr_var, fully_hoisted);
1219         if (!hoisted_expr.is_equal(expr)) return hoisted_expr;
1220 
1221         // hoist_expr_with_add() doesn't handle cast so try to hoist it manually.
1222         auto *cast = expr.as_ptr<cast_t>();
1223         if (!cast) return hoisted_expr;
1224 
1225         auto hoisted_cast_expr = hoist_expr(cast->expr);
1226         if (!hoisted_cast_expr.is_equal(cast->expr)) {
1227             hoisted_expr = cast_t::make(
1228                     cast->type, hoisted_cast_expr, cast->saturate);
1229         }
1230         return hoisted_expr;
1231     }
1232 
hoist_expr_with_add(const expr_t & expr,const expr_t & expr_var={},bool * fully_hoisted=nullptr)1233     expr_t hoist_expr_with_add(const expr_t &expr, const expr_t &expr_var = {},
1234             bool *fully_hoisted = nullptr) {
1235         auto cur_expr = nary_op_canonicalize(expr);
1236 
__anon70af794b0502(const expr_t &e) 1237         auto is_nary_add = [](const expr_t &e) {
1238             auto *nary = e.as_ptr<nary_op_t>();
1239             return nary && (nary->op_kind == op_kind_t::_add);
1240         };
1241 
1242         for (size_t i = 0; i < loops_.size(); i++) {
1243             std::vector<expr_t> invariant_args;
1244             std::vector<expr_t> other_args;
1245             std::vector<expr_t> nary_args;
1246             if (is_nary_add(cur_expr)) {
1247                 nary_args = cvt_expr_to_nary_op_args(cur_expr);
1248             } else {
1249                 nary_args.push_back(cur_expr);
1250             }
1251             for (auto &_a : nary_args) {
1252                 auto a = nary_op_back_transform(_a);
1253                 bool is_inv_arg = true;
1254                 for (size_t j = i; j < loops_.size(); j++) {
1255                     if (!is_invariant(a, loops_[j].var)) is_inv_arg = false;
1256                 }
1257                 if (is_inv_arg) {
1258                     invariant_args.push_back(_a);
1259                 } else {
1260                     other_args.push_back(_a);
1261                 }
1262             }
1263             // Nothing to hoist for this loop, continue.
1264             if (invariant_args.empty()) continue;
1265             if (invariant_args.size() == 1 && is_var(invariant_args[0]))
1266                 continue;
1267 
1268             // Introduce new variable for the invariant sub-expression.
1269             auto inv_expr = nary_op_back_transform(
1270                     make_nary_op(op_kind_t::_add, invariant_args));
1271             expr_t inv_var;
1272             if (!expr_var.is_empty() && other_args.empty()) {
1273                 // If nothing to hoist further, reuse the old variable and
1274                 // return.
1275                 inv_var = expr_var;
1276             } else {
1277                 inv_var = ir_ctx_.create_tmp_var(inv_expr.type());
1278             }
1279             auto let = let_t::make(inv_var, inv_expr);
1280             register_let(inv_var, inv_expr);
1281             loops_[i].lets.push_back(let);
1282 
1283             if (other_args.empty()) {
1284                 if (fully_hoisted) *fully_hoisted = true;
1285                 return inv_var;
1286             }
1287 
1288             other_args.push_back(inv_var);
1289             cur_expr = make_nary_op(op_kind_t::_add, other_args);
1290         }
1291         return nary_op_back_transform(cur_expr);
1292     }
1293 
injects_lets_and_pop_loop(const stmt_t & _s)1294     stmt_t injects_lets_and_pop_loop(const stmt_t &_s) {
1295         stmt_t s = _s;
1296         // Inject let statements if any.
1297         auto &lets = loops_.back().lets;
1298         for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
1299             auto &let = it->as<let_t>();
1300             s = let_t::make(let.var, let.value, s);
1301             unregister_let(let.var);
1302         }
1303         loops_.pop_back();
1304         return s;
1305     }
1306 
register_let(const expr_t & var,const expr_t & value)1307     void register_let(const expr_t &var, const expr_t &value) {
1308         let_vars_.insert({var, value});
1309     }
1310 
unregister_let(const expr_t & var)1311     void unregister_let(const expr_t &var) { let_vars_.erase(var); }
1312 
is_invariant(const expr_t & e,const expr_t & var) const1313     bool is_invariant(const expr_t &e, const expr_t &var) const {
1314         if (contains_object(e, var)) return false;
1315         if (!find_objects<load_t>(e).empty()) return false;
1316 
1317         // Check value if this is a let variable.
1318         auto it = let_vars_.find(e);
1319         if (it != let_vars_.end()) return is_invariant(it->second, var);
1320 
1321         if (is_var(e)) return true;
1322 
1323         // Check transitive dependencies.
1324         auto vars = find_unique_objects<var_t>(e);
1325         for (auto &v : vars) {
1326             if (!is_invariant(v, var)) return false;
1327         }
1328         return true;
1329     }
1330 
1331     ir_context_t &ir_ctx_;
1332     std::vector<loop_info_t> loops_;
1333 
1334     object_map_t<expr_t, expr_t> let_vars_;
1335 };
1336 
1337 // Moves invariant expressions out of loops.
hoist_exprs(const stmt_t & s,ir_context_t & ir_ctx)1338 stmt_t hoist_exprs(const stmt_t &s, ir_context_t &ir_ctx) {
1339     auto ret = hoist_exprs_mutator_t(ir_ctx).mutate(s);
1340     trace_pass("hoist_exprs", ret);
1341     return ret;
1342 }
1343 
1344 class loop_strength_reducer_t : public ir_mutator_t {
1345 public:
loop_strength_reducer_t()1346     loop_strength_reducer_t() {
1347         // Create top-level dummy loop.
1348         loops_.emplace_back();
1349     }
1350 
~loop_strength_reducer_t()1351     ~loop_strength_reducer_t() override {
1352         // Sanity check, all stores must be applied.
1353         ir_assert(post_inc_stores.empty());
1354     }
1355 
_mutate(const for_t & obj)1356     object_t _mutate(const for_t &obj) override {
1357         loops_.emplace_back(obj);
1358         auto new_obj = ir_mutator_t::_mutate(obj);
1359         return inject_stores_and_pop_loop(new_obj);
1360     }
1361 
_mutate(const let_t & obj)1362     object_t _mutate(const let_t &obj) override {
1363         int loop_level = int(loops_.size()) - 1;
1364         auto ret = lets_.insert(
1365                 {obj.var, let_info_t(obj.var, obj.value, loop_level)});
1366         ir_assert(ret.second);
1367         MAYBE_UNUSED(ret);
1368         auto new_obj = ir_mutator_t::_mutate(obj);
1369         lets_.erase(obj.var);
1370         return new_obj;
1371     }
1372 
_mutate(const stmt_group_t & obj)1373     object_t _mutate(const stmt_group_t &obj) override {
1374         if (obj.body.is<for_t>()) {
1375             loops_.emplace_back(obj.body);
1376             const for_t *for_obj = obj.body.as_ptr<for_t>();
1377             auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
1378             if (body.is_same(obj.body)) return obj;
1379             auto new_obj = stmt_group_t::make(obj.label, body);
1380             return inject_stores_and_pop_loop(new_obj);
1381         }
1382         return ir_mutator_t::_mutate(obj);
1383     }
1384 
1385     // Pattern to handle:
1386     //     for (...) {
1387     //         store(buf_ptr, ...) <- Write (producer).
1388     //         // ...
1389     //         stmt_t(..., buf_ptr, ...) <- Read (consumer).
1390     //     }
_mutate(const store_t & obj)1391     object_t _mutate(const store_t &obj) override {
1392         if (loops_.size() == 1) return ir_mutator_t::_mutate(obj);
1393 
1394         // Try to reduce strength, moving the store up.
1395         int init_store_level = -1;
1396         stmt_t init_store_stmt = obj;
1397         post_inc_store_info_t post_inc_store(obj);
1398         for (int level = int(loops_.size()) - 1; level >= 1; level--) {
1399             auto &loop_info = loops_[level];
1400             int refs = count_object(loop_info.loop, obj.buf);
1401             // Producer and consumer - must be 2 references.
1402             if (refs != 2) break;
1403 
1404             // Try to insert the store before level-th loop.
1405             auto &store = init_store_stmt.as<store_t>();
1406             auto &store_value = store.value;
1407             auto &loop_var = loop_info.loop_var();
1408 
1409             auto cur_value = substitute_let(store_value, level);
1410             auto next_value = substitute(cur_value, loop_var, loop_var + 1);
1411             auto inc = simplify(next_value - cur_value);
1412 
1413             // Cannot eliminate loop variable, break.
1414             if (contains_object(inc, loop_var)) break;
1415 
1416             // Success, replace store by post-increment store.
1417             init_store_level = level;
1418 
1419             auto new_store_value
1420                     = substitute(cur_value, loop_var, loop_info.loop_init());
1421             init_store_stmt = store_t::make(store.buf, store.off,
1422                     simplify(new_store_value), store.stride);
1423 
1424             post_inc_store.update(loop_info, inc);
1425         }
1426 
1427         // Can't do anything, return as is.
1428         if (init_store_level == -1) return ir_mutator_t::_mutate(obj);
1429 
1430         // Move this store up, remove from here.
1431         loops_[init_store_level].init_stores.push_back(init_store_stmt);
1432         if (!post_inc_store.is_empty()) {
1433             auto ret = post_inc_stores.insert({obj.buf, post_inc_store});
1434             ir_assert(ret.second);
1435             MAYBE_UNUSED(ret);
1436         }
1437         return stmt_t();
1438     }
1439 
_mutate(const func_call_t & obj)1440     object_t _mutate(const func_call_t &obj) override {
1441         for (auto &kv : post_inc_stores) {
1442             int refs = count_object(obj, kv.first);
1443             if (refs == 1) {
1444                 auto ret = stmt_seq_t::make(obj, kv.second.stmt());
1445                 post_inc_stores.erase(kv.first);
1446                 return std::move(ret);
1447             }
1448         }
1449         return ir_mutator_t::_mutate(obj);
1450     }
1451 
1452 private:
1453     struct loop_info_t {
loop_info_tdnnl::impl::gpu::jit::loop_strength_reducer_t::loop_info_t1454         loop_info_t(const stmt_t &loop = {}) : loop(loop) {}
1455 
loop_vardnnl::impl::gpu::jit::loop_strength_reducer_t::loop_info_t1456         const expr_t &loop_var() const { return loop.as<for_t>().var; }
1457 
loop_initdnnl::impl::gpu::jit::loop_strength_reducer_t::loop_info_t1458         const expr_t &loop_init() const { return loop.as<for_t>().init; }
1459 
loop_bounddnnl::impl::gpu::jit::loop_strength_reducer_t::loop_info_t1460         const expr_t &loop_bound() const { return loop.as<for_t>().bound; }
1461 
loop_extentdnnl::impl::gpu::jit::loop_strength_reducer_t::loop_info_t1462         expr_t loop_extent() const { return loop_bound() - loop_init(); }
1463 
1464         // Loop being analyzed.
1465         stmt_t loop;
1466         // Stores to insert before the loop.
1467         std::vector<stmt_t> init_stores;
1468 
1469         std::vector<stmt_t> lets;
1470     };
1471 
1472     struct let_info_t {
let_info_tdnnl::impl::gpu::jit::loop_strength_reducer_t::let_info_t1473         let_info_t(const expr_t &var, const expr_t &value, int loop_level)
1474             : var(var), value(value), loop_level(loop_level) {}
1475 
1476         expr_t var;
1477         expr_t value;
1478         int loop_level;
1479     };
1480 
1481     struct post_inc_store_info_t {
post_inc_store_info_tdnnl::impl::gpu::jit::loop_strength_reducer_t::post_inc_store_info_t1482         post_inc_store_info_t(const store_t &obj)
1483             : store(&obj), inc(0), last_iter_cond(true), compensation(0) {}
1484 
stmtdnnl::impl::gpu::jit::loop_strength_reducer_t::post_inc_store_info_t1485         stmt_t stmt() const {
1486             auto load
1487                     = load_t::make(store->value.type(), store->buf, store->off);
1488             return store_t::make(store->buf, store->off, load + inc);
1489         }
1490 
is_emptydnnl::impl::gpu::jit::loop_strength_reducer_t::post_inc_store_info_t1491         bool is_empty() const { return is_zero(inc); }
1492 
updatednnl::impl::gpu::jit::loop_strength_reducer_t::post_inc_store_info_t1493         void update(const loop_info_t &loop, const expr_t &loop_inc) {
1494             inc = simplify(iif_t::make(
1495                     last_iter_cond, inc - compensation + loop_inc, inc));
1496             if (last_iter_cond.is_equal(expr_t(true))) {
1497                 last_iter_cond = (loop.loop_var() == loop.loop_bound() - 1);
1498             } else {
1499                 last_iter_cond = last_iter_cond
1500                         & (loop.loop_var() == loop.loop_bound() - 1);
1501             }
1502             compensation = simplify(loop.loop_extent() * loop_inc);
1503         }
1504 
1505         const store_t *store;
1506         expr_t inc;
1507 
1508         expr_t last_iter_cond;
1509         expr_t compensation;
1510     };
1511 
1512     // Recursively substitutes all variable from let statements located under
1513     // the given loop level.
substitute_let(const expr_t & _e,int loop_level) const1514     expr_t substitute_let(const expr_t &_e, int loop_level) const {
1515         auto e = _e;
1516         for (;;) {
1517             bool found = false;
1518             auto vars = find_unique_objects<var_t>(e);
1519             for (auto &v : vars) {
1520                 auto it = lets_.find(v);
1521                 if (it == lets_.end()) continue;
1522                 auto &let_info = it->second;
1523                 // Do not substitute top-level let variables.
1524                 if (let_info.loop_level < loop_level) continue;
1525                 found = true;
1526                 e = substitute(e, v, let_info.value);
1527             }
1528             if (!found) break;
1529         }
1530         return e;
1531     }
1532 
1533     // Injects initial store statements if any.
inject_stores_and_pop_loop(const stmt_t & _s)1534     object_t inject_stores_and_pop_loop(const stmt_t &_s) {
1535         stmt_t s = _s;
1536         auto &stores = loops_.back().init_stores;
1537         for (auto it = stores.rbegin(); it != stores.rend(); ++it) {
1538             s = stmt_seq_t::make(*it, s);
1539         }
1540         loops_.pop_back();
1541         // The top-level dummy loop shouldn't be removed.
1542         ir_assert(loops_.size() >= 1);
1543         return std::move(s);
1544     }
1545 
1546     // Loops, ordered from outermost to innermost. The first loop is dummy, to
1547     // represent let statements in the top-level scope.
1548     std::vector<loop_info_t> loops_;
1549 
1550     // Buffers whose references are to be updated.
1551     object_map_t<expr_t, post_inc_store_info_t> post_inc_stores;
1552 
1553     // Let statements available at the current IR node.
1554     object_map_t<expr_t, let_info_t> lets_;
1555 };
1556 
1557 // Detects and converts expensive expression operations inside a loop to less
1558 // expensive operations. Example:
1559 // Before:
1560 //     for (int j = 0; j < N; j++) {
1561 //         int off = off_i + j * K;
1562 //         a[off] = j;
1563 //     }
1564 // After:
1565 //     int off = off_i;
1566 //     for (int j = 0; j < N; j++) {
1567 //         a[off] = j;
1568 //         off += K;
1569 //     }
loop_strength_reduce(const stmt_t & s)1570 stmt_t loop_strength_reduce(const stmt_t &s) {
1571     auto ret = loop_strength_reducer_t().mutate(s);
1572     trace_pass("loop_strength_reduce", ret);
1573     return ret;
1574 }
1575 
1576 class alloc_let_optimizer_t : public ir_mutator_t {
1577 public:
1578     // Also track alloc_t and for_t to validate all variable usages.
_mutate(const alloc_t & obj)1579     object_t _mutate(const alloc_t &obj) override {
1580         return mutate_scope(obj, obj.buf);
1581     }
1582 
_mutate(const for_t & obj)1583     object_t _mutate(const for_t &obj) override {
1584         level_++;
1585         auto new_obj = mutate_scope(obj, obj.var);
1586         level_--;
1587         return new_obj;
1588     }
1589 
_mutate(const let_t & obj)1590     object_t _mutate(const let_t &obj) override {
1591         return mutate_scope(obj, obj.var);
1592     }
1593 
_mutate(const store_t & obj)1594     object_t _mutate(const store_t &obj) override {
1595         auto &base = (obj.buf.is<var_t>() ? obj.buf : obj.buf.as<ptr_t>().base);
1596         // Do not count store references. If there are only stores to a buffer
1597         // and no other usages, the buffer can be safely removed.
1598         skip_var_ = base;
1599         auto new_obj = ir_mutator_t::_mutate(obj);
1600         skip_var_ = expr_t();
1601         return new_obj;
1602     }
1603 
_mutate(const var_t & obj)1604     object_t _mutate(const var_t &obj) override {
1605         ir_assert(refs_.count(obj) == 1)
1606                 << "Variable is not defined: " << expr_t(&obj);
1607         if (!skip_var_.is_same(obj)) refs_[&obj].update(increment_, level_);
1608         return ir_mutator_t::_mutate(obj);
1609     }
1610 
1611 private:
1612     struct ref_info_t {
ref_info_tdnnl::impl::gpu::jit::alloc_let_optimizer_t::ref_info_t1613         ref_info_t(int level = 0)
1614             : refs(0), min_level(level), max_level(level) {}
1615 
updatednnl::impl::gpu::jit::alloc_let_optimizer_t::ref_info_t1616         void update(int increment, int level) {
1617             refs += increment;
1618             max_level = std::max(max_level, level);
1619         }
1620 
is_same_leveldnnl::impl::gpu::jit::alloc_let_optimizer_t::ref_info_t1621         bool is_same_level() const { return min_level == max_level; }
1622 
1623         int refs;
1624         int min_level;
1625         int max_level;
1626     };
1627 
1628     template <typename T>
mutate_scope(const T & obj,const expr_t & var)1629     object_t mutate_scope(const T &obj, const expr_t &var) {
1630         auto ret = refs_.insert({var, ref_info_t(level_)});
1631         ir_assert(ret.second) << stmt_t(obj);
1632         MAYBE_UNUSED(ret);
1633 
1634         auto new_obj = ir_mutator_t::_mutate(obj);
1635         auto &ref_info = refs_[var];
1636 
1637         if (std::is_same<T, let_t>()) {
1638             new_obj = mutate_let(new_obj.template as<let_t>(), ref_info);
1639         } else if (std::is_same<T, alloc_t>()) {
1640             new_obj = mutate_alloc(new_obj.template as<alloc_t>(), ref_info);
1641         }
1642 
1643         refs_.erase(var);
1644         return new_obj;
1645     }
1646 
mutate_let(const let_t & obj,const ref_info_t & ref_info)1647     object_t mutate_let(const let_t &obj, const ref_info_t &ref_info) {
1648         ir_assert(ref_info.refs >= 1);
1649         if (ref_info.refs == 1) {
1650             // Variable is not used.
1651             remove_refs(obj);
1652             return obj.body;
1653         }
1654         // Check following conditions to substitute let value:
1655         // - 2 references: one from producer, one from consumer - means single usage
1656         // - Consumer and producer are on the same level (same loop)
1657         // - Variable is not external
1658         if (ref_info.refs == 2 && ref_info.is_same_level()
1659                 && !obj.value.is_empty()) {
1660             return substitute(obj.body, obj.var, obj.value);
1661         }
1662         return obj;
1663     }
1664 
mutate_alloc(const alloc_t & obj,const ref_info_t & ref_info)1665     object_t mutate_alloc(const alloc_t &obj, const ref_info_t &ref_info) {
1666         ir_assert(ref_info.refs >= 1);
1667         // Buffer is not used, single reference from alloc_t itself. Remove
1668         // stores to the buffer if any.
1669         if (ref_info.refs == 1) return remove_stores(obj.body, obj.buf);
1670         return obj;
1671     }
1672 
remove_refs(const let_t & obj)1673     void remove_refs(const let_t &obj) {
1674         increment_ = -1;
1675         mutate(obj.value);
1676         increment_ = 1;
1677     }
1678 
1679     // Removes all nested stores to the buffer.
remove_stores(const stmt_t & stmt,const expr_t & buf)1680     stmt_t remove_stores(const stmt_t &stmt, const expr_t &buf) {
1681         auto ret = stmt;
1682         auto stores = find_objects<store_t>(stmt);
1683         for (auto &_s : stores) {
1684             auto &s = _s.as<store_t>();
1685             auto &base = (s.buf.is<var_t>() ? s.buf : s.buf.as<ptr_t>().base);
1686             if (base.is_same(buf)) ret = substitute(ret, _s, stmt_t());
1687         }
1688         return ret;
1689     }
1690 
1691     int increment_ = 1;
1692     int level_ = 0;
1693 
1694     expr_t skip_var_;
1695     object_map_t<expr_t, ref_info_t> refs_;
1696 };
1697 
optimize_alloc_let(const stmt_t & s)1698 stmt_t optimize_alloc_let(const stmt_t &s) {
1699     auto ret = alloc_let_optimizer_t().mutate(s);
1700     trace_pass("optimize_alloc_let", ret);
1701     return ret;
1702 }
1703 
1704 class unrolling_updater_t : public ir_mutator_t {
1705 public:
_mutate(const let_t & obj)1706     object_t _mutate(const let_t &obj) override {
1707         if (level_ == 0) {
1708             // Skip top-level let statements.
1709             return ir_mutator_t::_mutate(obj);
1710         }
1711         lets_.push_back(&obj);
1712         auto new_body = mutate(obj.body);
1713         if (!lets_.back()) {
1714             // Let was moved to the innermost loop.
1715             lets_.pop_back();
1716             return new_body;
1717         }
1718         lets_.pop_back();
1719         if (new_body.is_same(obj.body)) return obj;
1720         return let_t::make(obj.var, obj.value, new_body);
1721     }
1722 
_mutate(const for_t & obj)1723     object_t _mutate(const for_t &obj) override {
1724         level_++;
1725         found_loop_ = false;
1726         auto new_obj = ir_mutator_t::_mutate(obj);
1727         level_--;
1728         if (!found_loop_) {
1729             // Innermost loop, inject let statements.
1730             auto body = get_stmt_body(new_obj);
1731             for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
1732                 body = let_t::make((*it)->var, (*it)->value, body);
1733                 *it = nullptr;
1734             }
1735             new_obj = replace_stmt_body(new_obj, body);
1736         }
1737         found_loop_ = true;
1738         return new_obj;
1739     }
1740 
1741 private:
1742     bool found_loop_ = false;
1743     int level_ = 0;
1744     std::vector<const let_t *> lets_;
1745 };
1746 
1747 // Eliminates let statements from the outer loops to be able to unroll loop
1748 // nest for SLM buffering or prefetch injection. Example:
1749 // Before:
1750 //     for (int i = 0; i < I; i++) {
1751 //         int tmp = TMP;
1752 //         for (int j = 0; j < J; j++) {
1753 //            ...
1754 //         }
1755 //     }
1756 // After:
1757 //     for (int i = 0; i < I; i++) {
1758 //         for (int j = 0; j < J; j++) {
1759 //             int tmp = TMP;
1760 //             ...
1761 //         }
1762 //     }
update_loops_for_unrolling(const stmt_t & s,const conv_config_t & cfg)1763 stmt_t update_loops_for_unrolling(const stmt_t &s, const conv_config_t &cfg) {
1764     auto ret = s;
1765     if (cfg.do_loop_unroll) ret = unrolling_updater_t().mutate(s);
1766     trace_pass("update_loops_for_unrolling", ret);
1767     return ret;
1768 }
1769 
1770 // Helper structure for for_t.
1771 struct loop_info_t {
1772     loop_info_t() = default;
1773 
loop_info_tdnnl::impl::gpu::jit::loop_info_t1774     loop_info_t(const stmt_t &s) {
1775         ir_assert(s.is<for_t>()) << s;
1776         auto &loop = s.as<for_t>();
1777         stmt = s;
1778         var = loop.var;
1779         init_ = loop.init;
1780         bound_ = loop.bound;
1781 
1782         auto e_size = simplify(bound_ - init_);
1783         ir_assert(is_const(e_size));
1784         size_ = to_cpp<int>(e_size);
1785     }
1786 
initdnnl::impl::gpu::jit::loop_info_t1787     int init() const {
1788         ir_assert(is_const(init_));
1789         return to_cpp<int>(init_);
1790     }
1791 
bounddnnl::impl::gpu::jit::loop_info_t1792     int bound() const {
1793         ir_assert(is_const(bound_));
1794         return to_cpp<int>(bound_);
1795     }
1796 
sizednnl::impl::gpu::jit::loop_info_t1797     int size() const { return size_; }
1798 
1799     stmt_t stmt;
1800     expr_t var;
1801 
1802 private:
1803     expr_t init_;
1804     expr_t bound_;
1805     int size_;
1806 };
1807 
1808 // Iterates through multiple nested loops with fixed bounds. Used to unroll
1809 // such nested loops.
1810 class multi_loop_iterator_t {
1811 public:
1812     // Ordered from innermost to outermost.
multi_loop_iterator_t(const std::vector<loop_info_t> & loops)1813     multi_loop_iterator_t(const std::vector<loop_info_t> &loops)
1814         : loops_(loops) {
1815         for (auto &l : loops)
1816             var_values_.push_back(l.init());
1817     }
1818 
var_value(const expr_t & var) const1819     int var_value(const expr_t &var) const {
1820         for (size_t i = 0; i < loops_.size(); i++) {
1821             if (loops_[i].var.is_same(var)) return var_values_[i];
1822         }
1823         ir_error_not_expected();
1824         return 0;
1825     }
1826 
advance(int n=1)1827     void advance(int n = 1) {
1828         if (loops_.empty()) return;
1829         for (int i_n = 0; i_n < n; i_n++) {
1830             for (size_t i = 0; i < loops_.size(); i++) {
1831                 auto &l = loops_[i];
1832                 if (++var_values_[i] < l.bound()) break;
1833                 var_values_[i] = l.init();
1834             }
1835             ir_assert(var_values_.back() < loops_.back().bound());
1836         }
1837     }
1838 
is_outer_loop_end() const1839     bool is_outer_loop_end() const {
1840         if (loops_.empty()) return true;
1841         for (size_t i = 0; i < loops_.size() - 1; i++) {
1842             auto &l = loops_[i];
1843             if (var_values_[i] != l.bound() - 1) return false;
1844         }
1845         return true;
1846     }
1847 
str() const1848     std::string str() const {
1849         std::ostringstream oss;
1850         oss << "multi_loop_iterator_t(";
1851         for (size_t i = 0; i < loops_.size(); i++) {
1852             oss << (i != 0 ? ", " : "");
1853             oss << loops_[i].var << " = " << var_values_[i];
1854         }
1855         oss << ")";
1856         return oss.str();
1857     }
1858 
1859     IR_DEFINE_DUMP()
1860 
1861 private:
1862     std::vector<loop_info_t> loops_;
1863     std::vector<int> var_values_;
1864 };
1865 
1866 // Extracts different parts of the compute iteration and verifies the loop nest
1867 // is properly formed and can be further injected with SLM buffering.
1868 class compute_step_visitor_t : public ir_visitor_t {
1869 public:
find_stmt_group(const stmt_label_t & label) const1870     stmt_t find_stmt_group(const stmt_label_t &label) const {
1871         auto groups = find_stmt_groups(label);
1872         if (groups.empty()) return stmt_t();
1873         ir_assert(groups.size() == 1);
1874         return groups[0];
1875     }
1876 
find_stmt_groups(const stmt_label_t & label) const1877     std::vector<stmt_t> find_stmt_groups(const stmt_label_t &label) const {
1878         std::vector<stmt_t> ret;
1879         for (auto &_g : stmt_groups_) {
1880             auto &g = _g.as<stmt_group_t>();
1881             if (g.label == label) ret.push_back(_g);
1882         }
1883         return ret;
1884     }
1885 
inner_let_stmts() const1886     const std::vector<stmt_t> &inner_let_stmts() const {
1887         return inner_let_stmts_;
1888     }
1889 
1890 #define HANDLE_IR_OBJECT(type) \
1891     void _visit(const type &obj) override { visit_stmt(obj); }
1892 
HANDLE_STMT_IR_OBJECTS()1893     HANDLE_STMT_IR_OBJECTS()
1894 
1895 #undef HANDLE_IR_OBJECT
1896 
1897     template <typename T>
1898     void visit_stmt(const T &obj) {
1899         auto obj_type_id = T::_type_id();
1900         bool is_for = (obj_type_id == for_t::_type_id());
1901         bool is_stmt_group = (obj_type_id == stmt_group_t::_type_id());
1902         bool is_let = (obj_type_id == let_t::_type_id());
1903         bool is_stmt_seq = (obj_type_id == stmt_seq_t::_type_id());
1904 
1905         // Loop may contain:
1906         // - Another loop
1907         // - Container statement (stmt_seq_t or stmt_group_t)
1908         // - Let statement (in the innermost loop only)
1909         // - Barrier
1910         if (loop_level_ > 0) {
1911             bool ok = false;
1912             if (is_for || is_let || is_stmt_group || is_stmt_seq) {
1913                 ok = true;
1914             } else if (obj_type_id == func_call_t::_type_id()) {
1915                 auto &call = obj.template as<func_call_t>();
1916                 ok = call.func.is_equal(funcs::barrier_func());
1917             }
1918 
1919             if (!ok) {
1920                 ir_error_not_expected()
1921                         << "Found unexpected statement inside loop.\n"
1922                         << stmt_t(obj);
1923             }
1924         }
1925 
1926         bool is_compute_loop = false;
1927         if (is_stmt_group) {
1928             auto label = obj.template as<stmt_group_t>().label;
1929             stmt_groups_.push_back(obj);
1930             if (utils::one_of(label, stmt_label_t::g2s_load(),
1931                         stmt_label_t::g2s_store(), stmt_label_t::g2r_load(),
1932                         stmt_label_t::s2r_load(), stmt_label_t::prefetch(),
1933                         stmt_label_t::mul())) {
1934                 // Leaf labels, do not visit them.
1935                 return;
1936             }
1937             if (label == stmt_label_t::compute_loop()) {
1938                 is_compute_loop = true;
1939                 in_compute_loop_ = true;
1940             }
1941         }
1942 
1943         if (is_for) loop_level_++;
1944         found_loop_ = false;
1945         ir_visitor_t::_visit(obj);
1946         if (in_compute_loop_ && is_let) {
1947             if (found_loop_)
1948                 ir_error_not_expected()
1949                         << "Let is allowed in the innermost loop only.";
1950 
1951             inner_let_stmts_.push_back(replace_stmt_body(obj, stmt_t()));
1952         }
1953         if (is_for) {
1954             loop_level_--;
1955             found_loop_ = true;
1956         }
1957 
1958         if (is_compute_loop) in_compute_loop_ = false;
1959     }
1960 
1961 private:
1962     bool found_loop_ = false;
1963     bool in_compute_loop_ = false;
1964     int loop_level_ = 0;
1965 
1966     std::vector<stmt_t> stmt_groups_;
1967     std::vector<stmt_t> inner_let_stmts_;
1968 };
1969 
1970 // Provides access to different parts of the inner compute iteration.
1971 class compute_step_t {
1972 public:
compute_step_t(const stmt_t & parent)1973     compute_step_t(const stmt_t &parent) {
1974         compute_step_visitor_t v;
1975         v.visit(parent);
1976 
1977         compute_loop_ = v.find_stmt_group(stmt_label_t::compute_loop());
1978         g2s_load_ = v.find_stmt_group(stmt_label_t::g2s_load());
1979         g2s_store_ = v.find_stmt_group(stmt_label_t::g2s_store());
1980         prefetch_ = v.find_stmt_group(stmt_label_t::prefetch());
1981         g2r_load_ = v.find_stmt_groups(stmt_label_t::g2r_load());
1982         s2r_load_ = v.find_stmt_groups(stmt_label_t::s2r_load());
1983         mul_ = v.find_stmt_groups(stmt_label_t::mul());
1984         c_zero_out_ = v.find_stmt_group(stmt_label_t::c_zero_out());
1985         inner_let_stmts_ = v.inner_let_stmts();
1986 
1987         ir_assert(g2r_load_.size() == mul_.size());
1988         ir_assert(s2r_load_.size() == mul_.size());
1989 
1990         // Assign preload/mul tags to let statements.
1991         for (auto &_let : inner_let_stmts_) {
1992             auto &var = _let.as<let_t>().var;
1993             bool is_preload = (count_object(g2s_load_, var) > 0)
1994                     || (count_object(prefetch_, var) > 0);
1995             bool is_mul = count_object(g2r_load_, var) > 0;
1996             if (is_preload) preload_lets_.insert(_let);
1997             if (is_mul) mul_lets_.insert(_let);
1998         }
1999 
2000         // Propagate preload/mul tags up based on dependencies between let
2001         // statements.
2002         std::vector<let_info_t> let_infos;
2003         object_set_t<stmt_t> seen;
2004         std::function<void(const stmt_t &)> propagate;
2005         propagate = [&](const stmt_t &_let) {
2006             if (seen.count(_let) > 0) return;
2007             auto &let = _let.as<let_t>();
2008             for (auto &_child : inner_let_stmts_) {
2009                 auto &child = _child.as<let_t>();
2010                 if (_child.is_same(_let)) continue;
2011                 if (contains_object(child.value, let.var)) {
2012                     // Visit child let statements first.
2013                     propagate(_child);
2014                     // Propagate child preload/mul values to this let statement.
2015                     if (is_preload_let(_child)) preload_lets_.insert(_let);
2016                     if (is_mul_let(_child)) mul_lets_.insert(_let);
2017                 }
2018             }
2019             auto let_info = create_let_info(
2020                     let, is_preload_let(_let), is_mul_let(_let));
2021             let_infos.push_back(let_info);
2022             seen.insert(_let);
2023         };
2024         for (auto &_let : inner_let_stmts_)
2025             propagate(_let);
2026 
2027         // Duplicate lets that are used in both preload and mul contexts.
2028         duplicate_lets(let_infos);
2029     }
2030 
2031     // See ir_core.hpp for the description.
compute_loop() const2032     const stmt_t &compute_loop() const { return compute_loop_; }
g2s_load() const2033     const stmt_t &g2s_load() const { return g2s_load_; }
g2s_store() const2034     const stmt_t &g2s_store() const { return g2s_store_; }
prefetch() const2035     const stmt_t &prefetch() const { return prefetch_; }
g2r_load() const2036     const std::vector<stmt_t> &g2r_load() const { return g2r_load_; }
s2r_load() const2037     const std::vector<stmt_t> &s2r_load() const { return s2r_load_; }
mul() const2038     const std::vector<stmt_t> &mul() const { return mul_; }
c_zero_out() const2039     const stmt_t &c_zero_out() const { return c_zero_out_; }
inner_let_stmts() const2040     const std::vector<stmt_t> &inner_let_stmts() const {
2041         return inner_let_stmts_;
2042     }
2043 
is_preload_let(const stmt_t & s) const2044     bool is_preload_let(const stmt_t &s) const {
2045         return preload_lets_.count(s) > 0;
2046     }
is_mul_let(const stmt_t & s) const2047     bool is_mul_let(const stmt_t &s) const { return mul_lets_.count(s) > 0; }
2048 
2049 private:
2050     struct let_info_t {
let_info_tdnnl::impl::gpu::jit::compute_step_t::let_info_t2051         let_info_t(const expr_t &var) : var(var) {}
2052 
2053         expr_t var;
2054         expr_t preload_var;
2055         expr_t mul_var;
2056 
is_preloaddnnl::impl::gpu::jit::compute_step_t::let_info_t2057         bool is_preload() const { return !preload_var.is_empty(); }
is_muldnnl::impl::gpu::jit::compute_step_t::let_info_t2058         bool is_mul() const { return !mul_var.is_empty(); }
2059 
needs_updatednnl::impl::gpu::jit::compute_step_t::let_info_t2060         bool needs_update() const { return is_preload() && is_mul(); }
2061     };
2062 
create_let_info(const let_t & let,bool is_preload,bool is_mul)2063     let_info_t create_let_info(const let_t &let, bool is_preload, bool is_mul) {
2064         let_info_t info(let.var);
2065         if (is_preload && !is_mul) {
2066             info.preload_var = let.var;
2067         } else if (!is_preload && is_mul) {
2068             info.mul_var = let.var;
2069         } else if (is_preload && is_mul) {
2070             info.preload_var = create_var_with_suffix(let.var, "p");
2071             info.mul_var = create_var_with_suffix(let.var, "m");
2072         }
2073         return info;
2074     }
2075 
duplicate_lets(const std::vector<let_info_t> & let_infos)2076     void duplicate_lets(const std::vector<let_info_t> &let_infos) {
2077         int nlets = int(inner_let_stmts_.size());
2078         ir_assert(int(let_infos.size()) == nlets);
2079 
2080         std::vector<stmt_t> new_lets;
2081         for (int i = nlets - 1; i >= 0; i--) {
2082             auto &info = let_infos[i];
2083             auto &old_let = inner_let_stmts_[i].as<let_t>();
2084             if (!info.needs_update()) {
2085                 auto new_value = update_var(old_let.value, let_infos,
2086                         info.is_preload(), info.is_mul());
2087                 auto new_let = inner_let_stmts_[i];
2088                 if (!new_value.is_same(old_let.value)) {
2089                     new_let = let_t::make(old_let.var, new_value, old_let.body);
2090                 }
2091                 new_lets.push_back(new_let);
2092                 continue;
2093             }
2094 
2095             preload_lets_.erase(&old_let);
2096             mul_lets_.erase(&old_let);
2097 
2098             auto preload_value
2099                     = update_var(old_let.value, let_infos, true, false);
2100             auto preload_let = let_t::make(
2101                     info.preload_var, preload_value, old_let.body);
2102 
2103             auto mul_value = update_var(old_let.value, let_infos, false, true);
2104             auto mul_let = let_t::make(info.mul_var, mul_value, old_let.body);
2105 
2106             preload_lets_.insert(preload_let);
2107             new_lets.push_back(preload_let);
2108 
2109             mul_lets_.insert(mul_let);
2110             new_lets.push_back(mul_let);
2111 
2112             // Update statements.
2113             g2s_load_ = update_var(g2s_load_, let_infos, true, false);
2114             g2s_store_ = update_var(g2s_store_, let_infos, true, false);
2115             prefetch_ = update_var(prefetch_, let_infos, true, false);
2116             g2r_load_ = update_var(g2r_load_, let_infos, false, true);
2117             s2r_load_ = update_var(s2r_load_, let_infos, false, true);
2118             mul_ = update_var(mul_, let_infos, false, true);
2119         }
2120 
2121         std::reverse(new_lets.begin(), new_lets.end());
2122         inner_let_stmts_ = new_lets;
2123     }
2124 
2125     template <typename T>
update_var(const std::vector<T> & vec,const std::vector<let_info_t> & let_infos,bool is_preload,bool is_mul)2126     static std::vector<T> update_var(const std::vector<T> &vec,
2127             const std::vector<let_info_t> &let_infos, bool is_preload,
2128             bool is_mul) {
2129         std::vector<T> ret;
2130         for (auto &v : vec)
2131             ret.push_back(update_var(v, let_infos, is_preload, is_mul));
2132         return ret;
2133     }
2134 
update_var(const object_t & obj,const std::vector<let_info_t> & let_infos,bool is_preload,bool is_mul)2135     static object_t update_var(const object_t &obj,
2136             const std::vector<let_info_t> &let_infos, bool is_preload,
2137             bool is_mul) {
2138         auto ret = obj;
2139         for (auto &info : let_infos) {
2140             if (!info.needs_update()) continue;
2141             if (!contains_object(ret, info.var)) continue;
2142             if (is_preload) {
2143                 ir_assert(info.is_preload());
2144                 ret = substitute(ret, info.var, info.preload_var);
2145             } else if (is_mul) {
2146                 ir_assert(info.is_mul());
2147                 ret = substitute(ret, info.var, info.mul_var);
2148             }
2149         }
2150         return ret;
2151     }
2152 
create_var_with_suffix(const expr_t & _var,const std::string & suffix)2153     static expr_t create_var_with_suffix(
2154             const expr_t &_var, const std::string &suffix) {
2155         auto &var = _var.as<var_t>();
2156         auto new_name = var.name + "_" + suffix;
2157         return var_t::make(var.type, new_name);
2158     }
2159 
2160     stmt_t compute_loop_;
2161     stmt_t g2s_load_;
2162     stmt_t g2s_store_;
2163     stmt_t prefetch_;
2164     std::vector<stmt_t> g2r_load_;
2165     std::vector<stmt_t> s2r_load_;
2166     std::vector<stmt_t> mul_;
2167     stmt_t c_zero_out_;
2168 
2169     std::vector<stmt_t> inner_let_stmts_;
2170 
2171     // Due to loop unrolling the inner let statements may depend on different
2172     // indices of the outer loops. There are two contexts:
2173     // - "preload" loop iteration, e.g. index I
2174     // - "multiplication" loop iteration, e.g. index (I + nbuf)
2175     // Preloads (either via SLM or via prefetches) for the corresponding
2176     // multiplication are executed several iterations before the real
2177     // multiplication. That's why we need to know exactly in which context the
2178     // given let statement is used. It might be that the same variable is used
2179     // from two different contexts. In this case it is duplicated and
2180     // initialized with different values for each case.
2181     object_set_t<stmt_t> preload_lets_;
2182     object_set_t<stmt_t> mul_lets_;
2183 };
2184 
2185 // Helper class to access the outer loop index after pipelining. Pipelining
2186 // in general requires tracking two versions of a loop index:
2187 // - Multiplication version - corresponding to the iteration that is currently
2188 //   used for multiplication
2189 // - Preload version - corresponding to the iteration that is currently used
2190 //   for preload for one of the next multiplications
2191 // The multiplilcation version is a few steps behind the preload version.
2192 class outer_loop_info_t : public loop_info_t {
2193 public:
2194     outer_loop_info_t() = default;
2195 
outer_loop_info_t(const stmt_t & s,ir_context_t & ir_ctx)2196     outer_loop_info_t(const stmt_t &s, ir_context_t &ir_ctx) : loop_info_t(s) {
2197         // Outer loop may not be used for unrolling hence loop iterations must
2198         // not use its index. If this doesn't hold, introduce a GRF buffer to
2199         // represent that variable and apply post-increment updates after each
2200         // outer loop iteration.
2201         if (count_object(s.as<for_t>().body, var) != 0) {
2202             has_var_refs_ = true;
2203             mul_var_buf_ = ir_ctx.create_tmp_var(
2204                     type_t::byte_ptr(), var.as<var_t>().name + "_mul_buf");
2205             preload_var_buf_ = ir_ctx.create_tmp_var(
2206                     type_t::byte_ptr(), var.as<var_t>().name + "_preload_buf");
2207 
2208             auto mul_alloc = alloc_t::make(
2209                     mul_var_buf_, var.type().size(), alloc_kind_t::grf);
2210             auto preload_alloc = alloc_t::make(
2211                     preload_var_buf_, var.type().size(), alloc_kind_t::grf);
2212             allocs_.push_back(mul_alloc);
2213             allocs_.push_back(preload_alloc);
2214 
2215             auto mul_init = store_t::make(mul_var_buf_, 0, init());
2216             auto preload_init = store_t::make(preload_var_buf_, 0, init());
2217             init_stmt_ = mul_init.append(preload_init);
2218 
2219             mul_post_inc_stmt_
2220                     = store_t::make(mul_var_buf_, 0, mul_var_load() + 1);
2221             preload_post_inc_stmt_ = store_t::make(
2222                     preload_var_buf_, 0, preload_var_load() + 1);
2223         }
2224     }
2225 
has_var_refs() const2226     bool has_var_refs() const { return has_var_refs_; }
2227 
mul_var_load() const2228     expr_t mul_var_load() const {
2229         return load_t::make(var.type(), mul_var_buf_, 0);
2230     }
preload_var_load() const2231     expr_t preload_var_load() const {
2232         return load_t::make(var.type(), preload_var_buf_, 0);
2233     }
2234 
inject_alloc_stmts(const stmt_t & stmt) const2235     stmt_t inject_alloc_stmts(const stmt_t &stmt) const {
2236         return jit::inject_alloc_stmts(stmt, allocs_);
2237     }
2238 
init_stmt() const2239     const stmt_t &init_stmt() const { return init_stmt_; }
2240 
mul_post_inc_stmt() const2241     const stmt_t &mul_post_inc_stmt() const { return mul_post_inc_stmt_; }
preload_post_inc_stmt() const2242     const stmt_t &preload_post_inc_stmt() const {
2243         return preload_post_inc_stmt_;
2244     }
2245 
2246 private:
2247     bool has_var_refs_ = false;
2248 
2249     // Helper expressions/statements to partially unroll the loop.
2250     expr_t mul_var_buf_;
2251     expr_t preload_var_buf_;
2252     std::vector<stmt_t> allocs_;
2253     stmt_t init_stmt_;
2254     stmt_t mul_post_inc_stmt_;
2255     stmt_t preload_post_inc_stmt_;
2256 };
2257 
2258 // Helper class to work with loop nest of the compute loop.
2259 class compute_loop_nest_t {
2260 public:
2261     compute_loop_nest_t() = default;
2262 
compute_loop_nest_t(const stmt_t & root,ir_context_t & ir_ctx)2263     compute_loop_nest_t(const stmt_t &root, ir_context_t &ir_ctx)
2264         : root_(root) {
2265         for (auto &l : find_objects<for_t>(root)) {
2266             loops_.emplace_back(l);
2267         }
2268 
2269         if (loops_.empty()) {
2270             outer_loop_size_ = 1;
2271             return;
2272         }
2273 
2274         outer_loop_ = outer_loop_info_t(loops_.back().stmt, ir_ctx);
2275         outer_loop_size_ = outer_loop_.size();
2276     }
2277 
loops() const2278     const std::vector<loop_info_t> &loops() const { return loops_; }
2279 
2280     // Number of iterations of all loops.
size() const2281     int size() const {
2282         int ret = 1;
2283         for (auto &l : loops_)
2284             ret *= l.size();
2285         return ret;
2286     }
2287 
2288     // Number of iterations in the outermost loop (see comments in ctor).
outer_loop_size() const2289     int outer_loop_size() const { return outer_loop_size_; }
2290 
outer_loop_info() const2291     const outer_loop_info_t &outer_loop_info() const { return outer_loop_; }
2292 
2293     template <typename F>
for_each_loop_var(const F & f) const2294     void for_each_loop_var(const F &f) const {
2295         for (auto &l : loops_)
2296             f(l.var);
2297     }
2298 
2299     // Number of iterations of all loops except the outermost.
inner_loops_size() const2300     int inner_loops_size() const { return size() / outer_loop_size(); }
2301 
2302 private:
2303     stmt_t root_;
2304     std::vector<loop_info_t> loops_;
2305 
2306     int outer_loop_size_;
2307     outer_loop_info_t outer_loop_;
2308 };
2309 
2310 struct compute_params_t {
2311     compute_params_t() = default;
2312 
compute_params_tdnnl::impl::gpu::jit::compute_params_t2313     compute_params_t(int slm_bufs, int gmem_bufs, int slm_buf_size,
2314             int prefetch_bufs, int inner_loops_iters)
2315         : slm_bufs(slm_bufs)
2316         , gmem_bufs(gmem_bufs)
2317         , slm_buf_size(slm_buf_size)
2318         , prefetch_bufs(prefetch_bufs) {
2319         use_slm = (slm_buf_size > 0);
2320         use_prefetch = (prefetch_bufs > 0);
2321         ir_assert(!use_slm || !use_prefetch)
2322                 << "Can't have both SLM buffering and prefetch enabled.";
2323         if (use_slm) {
2324             ir_assert(utils::one_of(slm_bufs, 1, 2, 3));
2325             ir_assert(utils::one_of(gmem_bufs, 1, 2));
2326             preload_bufs = slm_bufs;
2327             unroll = math::lcm(slm_bufs * gmem_bufs, inner_loops_iters);
2328         } else if (use_prefetch) {
2329             preload_bufs = prefetch_bufs;
2330             ir_assert(slm_bufs == 0);
2331             ir_assert(gmem_bufs == 0);
2332             unroll = math::lcm(prefetch_bufs, inner_loops_iters);
2333         } else {
2334             preload_bufs = 0;
2335             ir_assert(slm_bufs == 0);
2336             ir_assert(gmem_bufs == 0);
2337             unroll = inner_loops_iters;
2338         }
2339     }
2340 
2341     int slm_bufs;
2342     int gmem_bufs;
2343     int slm_buf_size;
2344     int prefetch_bufs;
2345     int preload_bufs;
2346     int unroll;
2347 
2348     bool use_slm;
2349     bool use_prefetch;
2350 };
2351 
2352 // Helper class to implement SLM buffering.
2353 class compute_iterator_t {
2354 public:
compute_iterator_t(const compute_params_t & params,const compute_loop_nest_t & loop_nest)2355     compute_iterator_t(const compute_params_t &params,
2356             const compute_loop_nest_t &loop_nest)
2357         : params(params)
2358         , preload_loop_it(loop_nest.loops())
2359         , mul_loop_it(loop_nest.loops()) {
2360 
2361         int compute_iters = loop_nest.size();
2362         iters = compute_iters;
2363         ir_assert(iters >= 1) << "Empty loop is not expected.";
2364 
2365         iters += std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
2366         ramp_up_iters
2367                 = std::max(1, preload_bufs() + std::max(0, gmem_bufs() - 1));
2368         ramp_down_iters = std::min(
2369                 std::max(0, preload_bufs() - 1) + std::max(0, gmem_bufs() - 1),
2370                 iters - ramp_up_iters);
2371         body_iters = iters - ramp_up_iters - ramp_down_iters;
2372         body_iters = utils::rnd_dn(body_iters, params.unroll);
2373         ramp_down_iters = iters - ramp_up_iters - body_iters;
2374 
2375         ir_assert(ramp_up_iters + body_iters + ramp_down_iters == iters);
2376 
2377         iter = 0;
2378         linear_id = 0;
2379         riter = iters - 1;
2380     }
2381 
unroll() const2382     int unroll() const { return params.unroll; }
2383 
preload_bufs() const2384     int preload_bufs() const { return params.preload_bufs; }
2385 
slm_bufs() const2386     int slm_bufs() const { return params.slm_bufs; }
2387 
gmem_bufs() const2388     int gmem_bufs() const { return params.gmem_bufs; }
2389 
operator ++()2390     compute_iterator_t &operator++() {
2391         if (do_preload()) preload_loop_it.advance();
2392         if (do_mul()) mul_loop_it.advance();
2393         ++iter;
2394         ++linear_id;
2395         --riter;
2396         return *this;
2397     }
2398 
advance(int n)2399     void advance(int n) {
2400         if (n == 0) return;
2401 
2402         ir_assert(n % params.unroll == 0);
2403         ir_assert(iter + n <= iters);
2404 
2405         if (preload_bufs() > 0) ir_assert(do_preload());
2406         ir_assert(do_mul());
2407 
2408         iter += n;
2409         riter -= n;
2410 
2411         if (preload_bufs() > 0) preload_loop_it.advance(n);
2412         mul_loop_it.advance(n);
2413     }
2414 
do_mul() const2415     bool do_mul() const {
2416         return iter >= std::max(0, preload_bufs() - 1)
2417                 + std::max(0, gmem_bufs() - 1);
2418     }
2419 
is_first_mul() const2420     bool is_first_mul() const {
2421         return iter
2422                 == std::max(0, preload_bufs() - 1)
2423                 + std::max(0, gmem_bufs() - 1);
2424     }
is_last_mul() const2425     bool is_last_mul() const { return riter == 0; }
2426 
do_preload() const2427     bool do_preload() const {
2428         if (preload_bufs() == 0) return false;
2429         return riter >= (preload_bufs() - 1) + std::max(0, gmem_bufs() - 1);
2430     }
2431 
do_prefetch() const2432     bool do_prefetch() const {
2433         if (!params.use_prefetch) return false;
2434         return do_preload();
2435     }
2436 
do_g2s_load() const2437     bool do_g2s_load() const {
2438         if (!params.use_slm) return false;
2439         return do_preload();
2440     }
2441 
do_s2r_load() const2442     bool do_s2r_load() const {
2443         if (!params.use_slm) return false;
2444         ir_assert(gmem_bufs() >= 1);
2445         return iter >= (gmem_bufs() - 1) && riter >= (slm_bufs() - 1);
2446     }
2447 
gmem_write_buf_index() const2448     int gmem_write_buf_index() const {
2449         ir_assert(do_g2s_load());
2450         return iter % gmem_bufs();
2451     }
2452 
gmem_read_buf_index() const2453     int gmem_read_buf_index() const {
2454         ir_assert(do_s2r_load());
2455         return (iter - (gmem_bufs() - 1)) % gmem_bufs();
2456     }
2457 
slm_read_offset_update() const2458     int slm_read_offset_update() const {
2459         ir_assert(params.use_slm);
2460         ir_assert(do_mul());
2461 
2462         int slm_iter = iter - (gmem_bufs() - 1) - (slm_bufs() - 1);
2463         int cur_slm_idx = slm_iter % slm_bufs();
2464         int next_slm_idx = (slm_iter + 1) % slm_bufs();
2465         int ret = next_slm_idx * params.slm_buf_size
2466                 - cur_slm_idx * params.slm_buf_size;
2467         return ret;
2468     }
2469 
slm_write_offset_update() const2470     int slm_write_offset_update() const {
2471         ir_assert(params.use_slm);
2472         ir_assert(do_s2r_load());
2473 
2474         int slm_iter = iter - (gmem_bufs() - 1);
2475         int cur_slm_idx = slm_iter % slm_bufs();
2476         int next_slm_idx = (slm_iter + 1) % slm_bufs();
2477         int ret = next_slm_idx * params.slm_buf_size
2478                 - cur_slm_idx * params.slm_buf_size;
2479         return ret;
2480     }
2481 
2482     compute_params_t params;
2483     multi_loop_iterator_t preload_loop_it;
2484     multi_loop_iterator_t mul_loop_it;
2485 
2486     // ramp_up_iters + body_iters + ramp_down_iters == iters
2487     int iters;
2488     int ramp_up_iters;
2489     int body_iters;
2490     int ramp_down_iters;
2491 
2492     // Invariant: iter + riter = iters - 1
2493     int iter;
2494     int riter;
2495 
2496     int linear_id;
2497 };
2498 
2499 // Basic LRU SBID allocator, tries to use the same SBIDs for the same GRF
2500 // buffers.
2501 class sbid_manager_t {
2502 public:
sbid_manager_t()2503     sbid_manager_t() : tuple_func_(builtin_t::make("tuple")) {}
2504 
get_sbid(const expr_t & buf,int index=0)2505     ngen_proxy::SBID get_sbid(const expr_t &buf, int index = 0) {
2506         auto key = tuple_func_.call({buf, expr_t(index)});
2507 
2508         int free_idx = -1;
2509         for (int i = 0; i < sbid_count; i++) {
2510             auto &e = entries_[i];
2511             if (key.is_equal(e.key)) {
2512                 e.time = cur_time_++;
2513                 return ngen_proxy::SBID(i);
2514             }
2515             if (free_idx == -1 && e.key.is_empty()) free_idx = i;
2516         }
2517 
2518         // Not found but there is a free SBID.
2519         if (free_idx != -1) {
2520             entries_[free_idx] = {key, cur_time_++};
2521             return ngen_proxy::SBID(free_idx);
2522         }
2523 
2524         // Find the oldest SBID and use it.
2525         int old_idx = 0;
2526         int old_time = entries_[0].time;
2527         for (int i = 1; i < sbid_count; i++) {
2528             if (entries_[i].time < old_time) {
2529                 old_idx = i;
2530                 old_time = entries_[i].time;
2531             }
2532         }
2533 
2534         entries_[old_idx] = entry_t({key, cur_time_++});
2535         return ngen_proxy::SBID(old_idx);
2536     }
2537 
2538 private:
2539     struct entry_t {
2540         stmt_t key;
2541         int time;
2542     };
2543 
2544     static const int sbid_count = 16;
2545     std::array<entry_t, sbid_count> entries_;
2546 
2547     func_t tuple_func_;
2548     int cur_time_ = 0;
2549 };
2550 
2551 // Helper to assign SBIDs to IR function calls.
2552 class sbid_assigner_t {
2553 public:
2554     sbid_assigner_t() = default;
2555 
sbid_assigner_t(sbid_manager_t & external_sbid_mgr)2556     sbid_assigner_t(sbid_manager_t &external_sbid_mgr)
2557         : external_sbid_mgr_(&external_sbid_mgr) {}
2558 
assign(const stmt_t & stmt)2559     stmt_t assign(const stmt_t &stmt) {
2560         auto stmt_vec = flatten_statements(stmt);
2561         stmt_t ret = stmt;
2562         int prefetch_idx = 0;
2563         for (auto &_s : stmt_vec) {
2564             if (!_s.is<func_call_t>()) continue;
2565             auto s = _s;
2566             if (is_func_call<send_t>(s)) {
2567                 auto &send = s.as<func_call_t>().func.as<send_t>();
2568                 int idx = (send.is_prefetch ? prefetch_idx++ : 0);
2569                 auto sbid = get_sbid(send_t::arg_reg_buf(s), idx);
2570                 s = update_call_with_sbid(s, sbid);
2571             } else if (is_func_call<dpas_t>(s)) {
2572                 auto &attr = s.as<func_call_t>().attr;
2573                 auto *mod_attr = attr.as_ptr<instruction_modifier_attr_t>();
2574                 if (!mod_attr || !mod_attr->mod.is_atomic) {
2575                     // Last dpas in Atomic chain.
2576                     auto sbid = get_sbid(dpas_t::arg_src1(s));
2577                     s = update_call_with_sbid(s, sbid);
2578                 }
2579             } else if (s.is<func_call_t>()) {
2580                 auto &c = s.as<func_call_t>();
2581                 if (c.func.is_equal(funcs::signal_func())
2582                         || c.func.is_equal(funcs::slm_fence_func())
2583                         || c.func.is_equal(funcs::barrier_func())) {
2584                     // Use 0 as the key for signals and SLM fences.
2585                     auto sbid = get_sbid(expr_t(0));
2586                     s = update_call_with_sbid(s, sbid);
2587                 }
2588             } else {
2589                 ir_error_not_expected() << s;
2590             }
2591             ret = substitute(ret, _s, s);
2592         }
2593         return ret;
2594     }
2595 
2596 private:
get_sbid(const expr_t & ptr,int index=0)2597     ngen_proxy::SBID get_sbid(const expr_t &ptr, int index = 0) {
2598         auto &sbid_mgr
2599                 = (external_sbid_mgr_ ? *external_sbid_mgr_ : local_sbid_mgr_);
2600         return sbid_mgr.get_sbid(ptr, index);
2601     }
2602 
update_call_with_sbid(const stmt_t & s,const ngen_proxy::SBID & sbid)2603     static stmt_t update_call_with_sbid(
2604             const stmt_t &s, const ngen_proxy::SBID &sbid) {
2605         return instruction_modifier_attr_t::make(
2606                 ngen_proxy::InstructionModifier().with_sbid(sbid))
2607                 .apply_to(s);
2608     }
2609 
2610     sbid_manager_t local_sbid_mgr_;
2611     sbid_manager_t *external_sbid_mgr_ = nullptr;
2612 };
2613 
2614 class simple_slm_buffering_injector_t {
2615 public:
simple_slm_buffering_injector_t(ngen::HW hw,const stmt_t & root,const conv_config_t & cfg,ir_context_t & ir_ctx,int ab_slm_size)2616     simple_slm_buffering_injector_t(ngen::HW hw, const stmt_t &root,
2617             const conv_config_t &cfg, ir_context_t &ir_ctx, int ab_slm_size)
2618         : hw_(hw)
2619         , cfg_(cfg)
2620         , ir_ctx_(ir_ctx)
2621         , ab_slm_size_(ab_slm_size)
2622         , root_(root)
2623         , alloc_mgr_(root_)
2624         , step_(root)
2625         , loop_nest_(root, ir_ctx) {}
2626 
inject()2627     stmt_t inject() {
2628         ir_assert(cfg_.gmem_bufs == 1) << "GRF buffering is not supported.";
2629         if (utils::one_of(cfg_.slm_bufs, 0, 1)) return root_;
2630 
2631         ir_assert(cfg_.use_a_slm == cfg_.use_b_slm)
2632                 << "Mixed SLM/GMEM loads are not supported.";
2633 
2634         auto loop = step_.compute_loop();
2635 
2636         // SLM indices are allocated as follows:
2637         // slm_idx[0] -> slm_buf_store
2638         // slm_idx[1] -> slm_buf_compute
2639         // slm_idx[2] -> slm_counter
2640         auto slm_idx_buf
2641                 = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "slm_idx");
2642         int slm_idx_size = type_t::s32().size();
2643 
2644         auto slm_idx_load = [&](int off, int elems) {
2645             return load_t::make(
2646                     type_t::s32(elems), slm_idx_buf, slm_idx_size * off);
2647         };
2648 
2649         // Initialize slm_idx.
2650         int off = 0;
2651         auto store0 = store_t::make(slm_idx_buf, off, 0);
2652         off += slm_idx_size;
2653 
2654         auto store1 = store_t::make(slm_idx_buf, off, 1);
2655         off += slm_idx_size;
2656 
2657         auto store2 = store_t::make(
2658                 slm_idx_buf, off, int_imm_t::make(0, type_t::s32()));
2659 
2660         auto slm_idx_init = store0.append(store1).append(store2);
2661 
2662         auto slm_idx_load2 = slm_idx_load(0, 2);
2663         auto slm_idx_load4 = slm_idx_load(0, 4);
2664         auto slm_idx_store = store_t::make(slm_idx_buf, 0,
2665                 slm_idx_load4 + shuffle_t::make_broadcast(1, 4));
2666 
2667         // Update slm_idx.
2668         auto mask = (slm_idx_load2
2669                 == shuffle_t::make_broadcast(cfg_.slm_bufs, 2));
2670         auto slm_idx_store_fix = store_t::make(slm_idx_buf, 0,
2671                 shuffle_t::make_broadcast(int_imm_t::make(0, type_t::s32()), 2),
2672                 store_t::default_stride, mask);
2673 
2674         auto slm_idx_update = slm_idx_store.append(slm_idx_store_fix);
2675 
2676         loop = slm_idx_init.append(loop);
2677 
2678         auto &g2s_store_orig = step_.g2s_store();
2679         auto &s2r_load = step_.s2r_load();
2680         auto &mul = step_.mul();
2681 
2682         auto g2s_store = g2s_store_orig;
2683 
2684         ir_assert(s2r_load.size() == mul.size());
2685 
2686         stmt_t s2r_mul;
2687         for (int i = 0; i < int(mul.size()); i++) {
2688             s2r_mul = s2r_mul.append(s2r_load[i]);
2689             loop = substitute(loop, s2r_load[i], stmt_t(), 1);
2690             s2r_mul = s2r_mul.append(mul[i]);
2691             loop = substitute(loop, mul[i], stmt_t(), 1);
2692         }
2693 
2694         loop = remove_synchronization(loop);
2695 
2696         s2r_mul = sub_slm_bufs(s2r_mul, slm_idx_load(1, 1));
2697         g2s_store = sub_slm_bufs(g2s_store, slm_idx_load(0, 1));
2698         g2s_store = g2s_store.append(slm_idx_update);
2699 
2700         auto s2r_mul_body = s2r_mul;
2701         auto s2r_mul_tail = s2r_mul;
2702         auto slm_counter = slm_idx_load(2, 1);
2703         auto cond = (slm_counter >= cfg_.slm_bufs - 1);
2704 
2705         if (cfg_.slm_bufs == 2) {
2706             s2r_mul_body = if_t::make(cond, s2r_mul_body);
2707             g2s_store = g2s_store.append(funcs::barrier());
2708         } else {
2709             // In general we have to use SLM fence before signal to flush all
2710             // previous SLM stores. However any SLM load behaves as implicit
2711             // SLM fence for all previous SLM stores. This means we don't need
2712             // explicit SLM fence when we perform SLM load/multiplication
2713             // before signal.
2714             auto fence_signal = funcs::slm_fence().append(funcs::signal());
2715             s2r_mul_body = s2r_mul_body.append(funcs::signal());
2716             s2r_mul_body = if_t::make(cond, s2r_mul_body, fence_signal);
2717             s2r_mul_body = funcs::barrier_wait().append(s2r_mul_body);
2718         }
2719 
2720         loop = substitute(
2721                 loop, g2s_store_orig, s2r_mul_body.append(g2s_store), 1);
2722 
2723         if (cfg_.slm_bufs == 3) {
2724             // Emit initial signal, to match wait-signal pairs in the loop.
2725             loop = funcs::signal().append(loop);
2726         }
2727 
2728         // Complete the remaining iterations.
2729         int rem_iters = cfg_.slm_bufs - 1;
2730         int mul_start = std::max(0, rem_iters - loop_nest_.size());
2731         for (int i = 0; i < rem_iters; i++) {
2732             if (cfg_.slm_bufs == 3) loop = loop.append(funcs::barrier_wait());
2733             if (i >= mul_start) {
2734                 // SLM load/multiplication works as implicit SLM fence.
2735                 loop = loop.append(s2r_mul_tail);
2736             } else {
2737                 loop = loop.append(funcs::slm_fence());
2738             }
2739             loop = loop.append(slm_idx_update);
2740             if (cfg_.slm_bufs == 3 && i + 1 < rem_iters)
2741                 loop = loop.append(funcs::signal());
2742         }
2743 
2744         if (cfg_.assign_sbids) loop = sbid_assigner_t().assign(loop);
2745 
2746         const auto grf_size = ngen::GRF::bytes(hw_);
2747         loop = alloc_t::make(
2748                 slm_idx_buf, grf_size, alloc_kind_t::grf, {}, loop);
2749 
2750         alloc_updater_t alloc_updater;
2751 
2752         auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
2753         ir_assert(slm_buffers.size() == 1);
2754         auto &slm_buf = slm_buffers[0];
2755         int non_ab_slm_size = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
2756         alloc_updater.resize(
2757                 slm_buf, non_ab_slm_size + ab_slm_size_ * cfg_.slm_bufs);
2758 
2759         auto ret = substitute(root_, step_.compute_loop(), loop, 1);
2760         ret = alloc_updater.update(ret);
2761         return ret;
2762     }
2763 
remove_synchronization(const stmt_t & s)2764     static stmt_t remove_synchronization(const stmt_t &s) {
2765         auto ret = s;
2766         for (auto &_c : find_objects<func_call_t>(s)) {
2767             auto &c = _c.as<func_call_t>();
2768             if (c.func.is_equal(funcs::signal_func())
2769                     || c.func.is_equal(funcs::slm_fence_func())
2770                     || c.func.is_equal(funcs::barrier_func())) {
2771                 ret = substitute(ret, _c, stmt_t(), 1);
2772             }
2773         }
2774         return ret;
2775     }
2776 
sub_slm_bufs(const stmt_t & stmt,const expr_t & slm_idx) const2777     stmt_t sub_slm_bufs(const stmt_t &stmt, const expr_t &slm_idx) const {
2778         auto stmt_vec = flatten_statements(stmt);
2779 
2780         stmt_t ret = stmt;
2781         for (auto &s : stmt_vec) {
2782             if (!is_func_call<send_t>(s)) continue;
2783 
2784             auto &send = s.as<func_call_t>().func.as<send_t>();
2785 
2786             // This is not send to SLM, skip.
2787             if (send.address_model != ngen_proxy::AddressModel::ModelSLM)
2788                 continue;
2789 
2790             auto new_args = s.as<func_call_t>().args;
2791             send_t::arg_mem_off(new_args) += ab_slm_size_ * slm_idx;
2792             auto new_send = send.call(new_args);
2793             ret = substitute(ret, s, new_send, 1);
2794         }
2795 
2796         return ret;
2797     }
2798 
2799     ngen::HW hw_;
2800     const conv_config_t &cfg_;
2801     ir_context_t &ir_ctx_;
2802     int ab_slm_size_;
2803 
2804     stmt_t root_;
2805     alloc_manager_t alloc_mgr_;
2806     compute_step_t step_;
2807     compute_loop_nest_t loop_nest_;
2808 };
2809 
2810 // Injects SLM buffering without unrolling based on the config.
inject_simple_slm_buffering(ngen::HW hw,const stmt_t & s,const conv_config_t & cfg,ir_context_t & ir_ctx,int ab_slm_size)2811 stmt_t inject_simple_slm_buffering(ngen::HW hw, const stmt_t &s,
2812         const conv_config_t &cfg, ir_context_t &ir_ctx, int ab_slm_size) {
2813     auto ret = simple_slm_buffering_injector_t(hw, s, cfg, ir_ctx, ab_slm_size)
2814                        .inject();
2815     trace_pass("inject_simple_slm_buffering", ret);
2816     return ret;
2817 }
2818 
2819 class unrolling_injector_t {
2820 public:
unrolling_injector_t(const stmt_t & root,const conv_config_t & cfg,ir_context_t & ir_ctx,int ab_slm_size)2821     unrolling_injector_t(const stmt_t &root, const conv_config_t &cfg,
2822             ir_context_t &ir_ctx, int ab_slm_size)
2823         : cfg_(cfg)
2824         , ir_ctx_(ir_ctx)
2825         , ab_slm_size_(ab_slm_size)
2826         , root_(root)
2827         , alloc_mgr_(root_)
2828         , step_(root)
2829         , loop_nest_(root, ir_ctx) {
2830         int inner_iters = loop_nest_.inner_loops_size();
2831         params_ = compute_params_t(cfg_.slm_bufs, cfg_.gmem_bufs, ab_slm_size,
2832                 cfg_.prefetch_bufs, inner_iters);
2833         if (params_.use_slm) {
2834             for (auto &b :
2835                     find_send_buffers(step_.g2s_load(), /*is_mem=*/false)) {
2836                 g2s_reg_bufs_.emplace_back(b, alloc_mgr_.alloc_size(b));
2837             }
2838         }
2839     }
2840 
inject()2841     stmt_t inject() {
2842         compute_iterator_t it(params_, loop_nest_);
2843         stmt_t body;
2844 
2845         sbid_manager_t sbid_mgr;
2846 
2847         auto &outer_loop_info = loop_nest_.outer_loop_info();
2848 
2849         auto append_outer_post_inc = [&](const stmt_t &_s) {
2850             auto &mul = outer_loop_info.mul_post_inc_stmt();
2851             auto &preload = outer_loop_info.preload_post_inc_stmt();
2852             auto s = _s;
2853             if (it.mul_loop_it.is_outer_loop_end() && it.do_mul()) {
2854                 s = s.append(mul);
2855             }
2856             if (it.preload_loop_it.is_outer_loop_end() && it.do_preload()) {
2857                 s = s.append(preload);
2858             }
2859             return s;
2860         };
2861 
2862         // Ramp-up.
2863         for (int i = 0; i < it.ramp_up_iters; i++) {
2864             body = stmt_seq_t::make(body, create_iteration(it, sbid_mgr));
2865             body = append_outer_post_inc(body);
2866             ++it;
2867         }
2868 
2869         // Body.
2870         if (it.body_iters > 0) {
2871             int extent = it.body_iters / it.unroll();
2872             bool has_loop = (extent > 1);
2873 
2874             stmt_t loop_body;
2875             for (int i = 0; i < it.unroll(); i++) {
2876                 loop_body = loop_body.append(create_iteration(
2877                         it, sbid_mgr, /*in_loop_body=*/has_loop));
2878                 ir_assert(it.do_mul());
2879                 ir_assert(it.do_preload());
2880                 loop_body = append_outer_post_inc(loop_body);
2881                 ++it;
2882             }
2883             if (!has_loop) {
2884                 body = body.append(loop_body);
2885             } else {
2886                 ir_assert(extent > 0);
2887                 auto for_var = ir_ctx_.create_tmp_var(type_t::s32(), "i");
2888                 body = body.append(for_t::make(for_var, 0, extent, loop_body));
2889             }
2890             it.advance(it.body_iters - it.unroll());
2891         }
2892 
2893         // Ramp-down.
2894         for (int i = 0; i < it.ramp_down_iters; i++) {
2895             ir_assert(it.do_mul());
2896             body = body.append(create_iteration(it, sbid_mgr));
2897             body = append_outer_post_inc(body);
2898             ++it;
2899         }
2900 
2901         if (outer_loop_info.has_var_refs()) {
2902             body = outer_loop_info.init_stmt().append(body);
2903             body = outer_loop_info.inject_alloc_stmts(body);
2904         }
2905 
2906         auto ret = substitute(root_, step_.compute_loop(), body, 1);
2907         if (params_.use_slm) {
2908             alloc_updater_t alloc_updater;
2909 
2910             // Update buffer sizes.
2911             for (auto &b : g2s_reg_bufs_) {
2912                 alloc_updater.resize(
2913                         b.buf, alloc_mgr_.alloc_size(b.buf) * cfg_.gmem_bufs);
2914             }
2915 
2916             auto slm_buffers = alloc_mgr_.find_buffers(alloc_kind_t::slm);
2917             if (!slm_buffers.empty()) {
2918                 ir_assert(slm_buffers.size() == 1);
2919 
2920                 auto &slm_buf = slm_buffers[0];
2921                 int non_ab_slm_size
2922                         = alloc_mgr_.alloc_size(slm_buf) - ab_slm_size_;
2923                 alloc_updater.resize(slm_buf,
2924                         non_ab_slm_size + ab_slm_size_ * cfg_.slm_bufs);
2925             }
2926 
2927             ret = alloc_updater.update(ret);
2928         }
2929 
2930         // Remove zero-out statement for C (handled by sub_fma_acc_with_zero).
2931         ret = substitute(ret, step_.c_zero_out(), stmt_t(), 1);
2932 
2933         return ret;
2934     }
2935 
2936 private:
2937     struct buffer_info_t {
buffer_info_tdnnl::impl::gpu::jit::unrolling_injector_t::buffer_info_t2938         buffer_info_t(const expr_t &buf, int size) : buf(buf), size(size) {}
2939 
2940         expr_t buf;
2941         int size;
2942     };
2943 
create_iteration(const compute_iterator_t & it,sbid_manager_t & sbid_mgr,bool in_loop_body=false) const2944     stmt_t create_iteration(const compute_iterator_t &it,
2945             sbid_manager_t &sbid_mgr, bool in_loop_body = false) const {
2946         auto g2s_load = step_.g2s_load();
2947         auto g2s_store = step_.g2s_store();
2948         auto prefetch = step_.prefetch();
2949         auto g2r_load = step_.g2r_load();
2950         auto s2r_load = step_.s2r_load();
2951         auto mul = step_.mul();
2952         auto lets = step_.inner_let_stmts();
2953         auto &outer_loop_info = loop_nest_.outer_loop_info();
2954 
2955         loop_nest_.for_each_loop_var([&](const expr_t &v) {
2956             g2s_load = const_fold(substitute(
2957                     g2s_load, v, expr_t(it.preload_loop_it.var_value(v))));
2958             g2s_store = const_fold(substitute(
2959                     g2s_store, v, expr_t(it.preload_loop_it.var_value(v))));
2960             prefetch = const_fold(substitute(
2961                     prefetch, v, expr_t(it.preload_loop_it.var_value(v))));
2962             expr_t mul_var_value;
2963             expr_t preload_var_value;
2964             if (v.is_same(outer_loop_info.var) && in_loop_body
2965                     && outer_loop_info.has_var_refs()) {
2966                 mul_var_value = outer_loop_info.mul_var_load();
2967                 preload_var_value = outer_loop_info.preload_var_load();
2968             } else {
2969                 mul_var_value = it.mul_loop_it.var_value(v);
2970                 preload_var_value = it.preload_loop_it.var_value(v);
2971             }
2972             for (auto &s : g2r_load) {
2973                 s = const_fold(substitute(s, v, mul_var_value));
2974             }
2975             for (auto &s : s2r_load) {
2976                 s = const_fold(substitute(s, v, preload_var_value));
2977             }
2978             for (int i = 0; i < int(lets.size()); i++) {
2979                 auto &let = lets[i];
2980                 auto &orig_let = step_.inner_let_stmts()[i];
2981                 expr_t var_value;
2982                 bool is_preload_let = step_.is_preload_let(orig_let);
2983                 bool is_mul_let = step_.is_mul_let(orig_let);
2984                 if (is_preload_let && !is_mul_let) {
2985                     var_value = preload_var_value;
2986                 } else if (is_mul_let && !is_preload_let) {
2987                     var_value = mul_var_value;
2988                 } else {
2989                     ir_assert(count_object(let.as<let_t>().value, v) == 0)
2990                             << "Unexpected reference to variable " << v
2991                             << " from " << let;
2992                     continue;
2993                 }
2994                 let = const_fold(substitute(let, v, var_value));
2995             }
2996         });
2997 
2998         if (params_.use_slm) {
2999             g2s_load = sub_gmem_bufs(g2s_load, it, /*is_read=*/false);
3000             g2s_store = sub_gmem_bufs(g2s_store, it, /*is_read=*/true);
3001 
3002             g2s_store = sub_slm_bufs(g2s_store, it, /*is_read=*/false);
3003             for (auto &s : s2r_load) {
3004                 s = sub_slm_bufs(s, it, /*is_read=*/true);
3005             }
3006         }
3007 
3008         if (it.is_first_mul()) {
3009             for (auto &m : mul) {
3010                 m = sub_fma_acc_with_zero(m);
3011             }
3012         }
3013 
3014         stmt_t iter_stmt;
3015         if (it.slm_bufs() == 3 && it.do_mul()) {
3016             iter_stmt = iter_stmt.append(funcs::barrier_wait());
3017         }
3018 
3019         if (it.do_g2s_load()) iter_stmt = iter_stmt.append(g2s_load);
3020 
3021         if (it.slm_bufs() == 3 && it.iter == it.gmem_bufs()) {
3022             iter_stmt = iter_stmt.append(funcs::slm_fence());
3023             iter_stmt = iter_stmt.append(funcs::signal());
3024         }
3025 
3026         if (it.do_s2r_load() && it.slm_bufs() == 1) {
3027             iter_stmt = iter_stmt.append(funcs::barrier());
3028             iter_stmt = iter_stmt.append(g2s_store);
3029             iter_stmt = iter_stmt.append(funcs::barrier());
3030         }
3031 
3032         if (it.do_prefetch()) iter_stmt = iter_stmt.append(prefetch);
3033 
3034         if (it.do_mul()) {
3035             for (size_t i = 0; i < mul.size(); i++) {
3036                 iter_stmt = iter_stmt.append(g2r_load[i]);
3037                 iter_stmt = iter_stmt.append(s2r_load[i]);
3038                 iter_stmt = iter_stmt.append(mul[i]);
3039             }
3040             if (it.slm_bufs() == 3 && !it.is_last_mul()) {
3041                 iter_stmt = iter_stmt.append(funcs::signal());
3042             }
3043         }
3044         if (it.do_s2r_load() && it.slm_bufs() >= 2) {
3045             iter_stmt = iter_stmt.append(g2s_store);
3046             if (it.slm_bufs() == 2) {
3047                 iter_stmt = iter_stmt.append(funcs::barrier());
3048             }
3049         }
3050 
3051         if (cfg_.assign_sbids)
3052             iter_stmt = sbid_assigner_t(sbid_mgr).assign(iter_stmt);
3053 
3054         iter_stmt = inject_local_let(iter_stmt, lets, it.linear_id);
3055 
3056         return iter_stmt;
3057     }
3058 
sub_gmem_bufs(const stmt_t & stmt,const compute_iterator_t & it,bool is_read) const3059     stmt_t sub_gmem_bufs(const stmt_t &stmt, const compute_iterator_t &it,
3060             bool is_read) const {
3061         if (it.slm_bufs() == 0) return stmt;
3062         if (is_read && !it.do_s2r_load()) return stmt;
3063         if (!is_read && !it.do_g2s_load()) return stmt;
3064 
3065         int buf_idx = (is_read ? it.gmem_read_buf_index()
3066                                : it.gmem_write_buf_index());
3067         if (buf_idx == 0) return stmt;
3068 
3069         auto ret = stmt;
3070         for (auto &b : g2s_reg_bufs_) {
3071             ret = substitute(ret, b.buf, b.buf[buf_idx * b.size]);
3072         }
3073         return ret;
3074     }
3075 
sub_slm_bufs(const stmt_t & stmt,const compute_iterator_t & it,bool is_read) const3076     stmt_t sub_slm_bufs(const stmt_t &stmt, const compute_iterator_t &it,
3077             bool is_read) const {
3078         if (it.slm_bufs() <= 1) return stmt;
3079         if (is_read && !it.do_mul()) return stmt;
3080         if (!is_read && !it.do_s2r_load()) return stmt;
3081 
3082         int upd = (is_read ? it.slm_read_offset_update()
3083                            : it.slm_write_offset_update());
3084 
3085         auto stmt_vec = flatten_statements(stmt);
3086 
3087         stmt_t ret = stmt;
3088         for (auto &s : stmt_vec) {
3089             auto *call = s.as_ptr<func_call_t>();
3090             if (!call) continue;
3091             auto *func = call->func.as_ptr<send_t>();
3092             if (!func) continue;
3093 
3094             auto &send = call->func.as<send_t>();
3095             auto &args = call->args;
3096             auto &mem_buf = send_t::arg_mem_buf(args);
3097             auto &header_buf = send_t::arg_mem_off(args);
3098 
3099             // This is not send to SLM, skip.
3100             if (send.address_model != ngen_proxy::AddressModel::ModelSLM)
3101                 continue;
3102 
3103             // May have signed offset.
3104             auto store_obj = send.create_offset_store(
3105                     header_buf, mem_buf, upd, /*is_signed_offset=*/true);
3106             auto &store = store_obj.as<store_t>();
3107             expr_t old_value
3108                     = load_t::make(send.address_type(), store.buf, store.off);
3109             auto post_inc_store = store_t::make(
3110                     store.buf, store.off, old_value + store.value);
3111             ret = substitute(ret, s, stmt_seq_t::make(s, post_inc_store), 1);
3112         }
3113 
3114         return ret;
3115     }
3116 
sub_fma_acc_with_zero(const stmt_t & stmt)3117     static stmt_t sub_fma_acc_with_zero(const stmt_t &stmt) {
3118         auto stmt_vec = flatten_statements(stmt);
3119 
3120         object_eq_set_t<expr_t> seen_dst;
3121         stmt_t ret = stmt;
3122         for (auto &s : stmt_vec) {
3123             if (is_func_call<dpas_t>(s)) {
3124                 auto &call = s.as<func_call_t>();
3125 
3126                 auto &dst = dpas_t::arg_dst(s);
3127                 auto src0 = expr_t(0); // Will be translated to null register.
3128                 auto &src1 = dpas_t::arg_src1(s);
3129                 auto &src2 = dpas_t::arg_src2(s);
3130 
3131                 auto new_call = func_call_t::make(
3132                         call.func, {dst, src0, src1, src2}, call.attr);
3133                 ret = substitute(ret, s, new_call, 1);
3134             } else if (is_func_call<mad_t>(s)) {
3135                 auto &call = s.as<func_call_t>();
3136 
3137                 auto &dst = mad_t::arg_dst(s);
3138                 auto src0 = expr_t(0); // Will be translated to null register.
3139                 auto &src1 = mad_t::arg_src1(s);
3140                 auto &src2 = mad_t::arg_src2(s);
3141 
3142                 if (!seen_dst.insert(dst).second) continue;
3143 
3144                 auto new_call = func_call_t::make(
3145                         call.func, {dst, src0, src1, src2}, call.attr);
3146                 ret = substitute(ret, s, new_call, 1);
3147             }
3148         }
3149         return ret;
3150     }
3151 
3152     // Returns memory buffers if is_mem is true and register buffers otherwise.
find_send_buffers(const stmt_t & s,bool is_mem)3153     static object_set_t<expr_t> find_send_buffers(
3154             const stmt_t &s, bool is_mem) {
3155         object_set_t<expr_t> ret;
3156         auto calls = find_objects<func_call_t>(s);
3157         for (auto &_c : calls) {
3158             auto &c = _c.as<func_call_t>();
3159             if (!c.func.is<send_t>()) continue;
3160             auto &buf = (is_mem ? send_t::arg_mem_buf(_c)
3161                                 : send_t::arg_reg_buf(_c));
3162             ret.insert(buf.as<ptr_t>().base);
3163         }
3164         return ret;
3165     }
3166 
inject_local_let(const stmt_t & _s,const std::vector<stmt_t> & enclosed_lets,int id)3167     static stmt_t inject_local_let(const stmt_t &_s,
3168             const std::vector<stmt_t> &enclosed_lets, int id) {
3169         auto s = _s;
3170 
3171         // Inject let statements from the innermost loop.
3172         for (auto &_let : enclosed_lets) {
3173             auto &let = _let.as<let_t>();
3174             s = let_t::make(let.var, let.value, s);
3175         }
3176 
3177         // Substitute variables to avoid clashing.
3178         auto lets = find_objects<let_t>(s);
3179         for (auto &_let : lets) {
3180             auto &let = _let.as<let_t>();
3181             auto &var = let.var.as<var_t>();
3182             auto local_var = var_t::make(
3183                     var.type, var.name + "_" + std::to_string(id));
3184             s = substitute(s, let.var, local_var);
3185         }
3186         return s;
3187     }
3188 
3189     const conv_config_t &cfg_;
3190     ir_context_t &ir_ctx_;
3191     int ab_slm_size_;
3192 
3193     stmt_t root_;
3194     alloc_manager_t alloc_mgr_;
3195     compute_step_t step_;
3196     compute_loop_nest_t loop_nest_;
3197     compute_params_t params_;
3198 
3199     std::vector<buffer_info_t> g2s_reg_bufs_; // For SLM buffering.
3200 };
3201 
3202 // Injects loop unrolling based on the config. Possible options:
3203 // - Without preload (no SLM buffering, no prefetch)
3204 // - With SLM buffering
3205 // - With prefetch
inject_unrolling(const stmt_t & s,const conv_config_t & cfg,ir_context_t & ir_ctx,int ab_slm_size)3206 stmt_t inject_unrolling(const stmt_t &s, const conv_config_t &cfg,
3207         ir_context_t &ir_ctx, int ab_slm_size) {
3208     auto ret = unrolling_injector_t(s, cfg, ir_ctx, ab_slm_size).inject();
3209     trace_pass("inject_unrolling", ret);
3210     return ret;
3211 }
3212 
3213 class store_splitter_t : public ir_mutator_t {
3214 public:
store_splitter_t(ngen::HW hw)3215     store_splitter_t(ngen::HW hw) : hw_(hw) {}
3216 
_mutate(const store_t & obj)3217     object_t _mutate(const store_t &obj) override {
3218         int elems = obj.value.type().elems();
3219         int elem_size = obj.value.type().scalar().size();
3220         int stride = (obj.has_default_stride() ? 1 : obj.stride / elem_size);
3221         int store_size = elem_size * stride * elems;
3222         const auto grf_size = ngen::GRF::bytes(hw_);
3223         if (store_size <= 2 * grf_size) return ir_mutator_t::_mutate(obj);
3224 
3225         int step = 2 * grf_size / (stride * elem_size);
3226         stmt_t new_stmt;
3227         for (int i = 0; i < elems; i += step) {
3228             int cur_elems = std::min(step, elems - i);
3229             ir_assert(math::is_pow2(cur_elems));
3230             int off = i * stride * elem_size;
3231             auto store = store_t::make(obj.buf, obj.off + off,
3232                     split_expr(obj.value, i, i + cur_elems), obj.stride);
3233             new_stmt = new_stmt.append(store);
3234         }
3235         return std::move(new_stmt);
3236     }
3237 
3238 private:
split_expr(const expr_t & e,int beg,int end)3239     static expr_t split_expr(const expr_t &e, int beg, int end) {
3240         auto *shuffle = e.as_ptr<shuffle_t>();
3241         if (shuffle) return shuffle_t::make(shuffle, beg, end);
3242 
3243         auto *binary = e.as_ptr<binary_op_t>();
3244         if (binary) {
3245             auto a = split_expr(binary->a, beg, end);
3246             auto b = split_expr(binary->b, beg, end);
3247             return binary_op_t::make(binary->op_kind, a, b);
3248         }
3249         ir_error_not_expected();
3250         return expr_t();
3251     }
3252 
3253     ngen::HW hw_;
3254 };
3255 
3256 // Splits wide GRF stores otherwise unsupported in HW.
split_wide_stores(ngen::HW hw,const stmt_t & s)3257 stmt_t split_wide_stores(ngen::HW hw, const stmt_t &s) {
3258     auto ret = store_splitter_t(hw).mutate(s);
3259     trace_pass("split_wide_stores", ret);
3260     return ret;
3261 }
3262 
3263 class peephole_optimizer_t : public ir_mutator_t {
3264 public:
_mutate(const binary_op_t & obj)3265     object_t _mutate(const binary_op_t &obj) override {
3266         auto old_obj = ir_mutator_t::_mutate(obj);
3267         auto new_obj
3268                 = simplify_rewrite_with_ternary(old_obj, /*recursive=*/false);
3269         auto *ternary = new_obj.as_ptr<ternary_op_t>();
3270         if (!ternary) return std::move(new_obj);
3271 
3272         switch (ternary->op_kind) {
3273             case op_kind_t::_add3: {
3274                 bool ok = true;
3275                 // Allowed form: add3(dword/word, dword/word, dword/word).
3276                 ok &= add3_type_ok(ternary->a);
3277                 ok &= add3_type_ok(ternary->b);
3278                 ok &= add3_type_ok(ternary->c);
3279                 ok &= !is_const(ternary->a);
3280                 ok &= !is_const(ternary->b);
3281                 if (!ok) new_obj = old_obj;
3282                 break;
3283             }
3284             case op_kind_t::_mad: {
3285                 auto a_type = real_type(ternary->a);
3286                 auto b_type = real_type(ternary->b);
3287                 auto c_type = real_type(ternary->c);
3288                 bool ok = true;
3289                 // Allowed form: mad(dword, dword, word).
3290                 ok &= utils::one_of(a_type, type_t::s32(), type_t::u32());
3291                 ok &= utils::one_of(b_type, type_t::s32(), type_t::u32());
3292                 ok &= utils::one_of(c_type, type_t::s16(), type_t::u16());
3293                 if (!ok) new_obj = old_obj;
3294                 break;
3295             }
3296             default: ir_error_not_expected();
3297         }
3298         return std::move(new_obj);
3299     }
3300 
3301 private:
real_type(const expr_t & e)3302     static type_t real_type(const expr_t &e) {
3303         auto *imm = e.as_ptr<int_imm_t>();
3304         if (!imm) return e.type();
3305         if (int_imm_t::try_shrink_type<int16_t>(imm->value))
3306             return type_t::s16();
3307         if (int_imm_t::try_shrink_type<int32_t>(imm->value))
3308             return type_t::s32();
3309         return type_t::s64();
3310     }
3311 
add3_type_ok(const expr_t & e)3312     static bool add3_type_ok(const expr_t &e) {
3313         auto t = real_type(e);
3314         if (!t.is_scalar()) return false;
3315         switch (t.kind()) {
3316             case type_kind_t::s32:
3317             case type_kind_t::u32: return !is_const(e);
3318             case type_kind_t::s16:
3319             case type_kind_t::u16: return true;
3320             default: return false;
3321         }
3322     }
3323 };
3324 
optimize_peephole(const stmt_t & s)3325 stmt_t optimize_peephole(const stmt_t &s) {
3326     auto ret = peephole_optimizer_t().mutate(s);
3327     trace_pass("optimize_peephole", ret);
3328     return ret;
3329 }
3330 
3331 class if_condition_fixer_t : public ir_mutator_t {
3332 public:
if_condition_fixer_t(int simd_size)3333     if_condition_fixer_t(int simd_size) : simd_size_(simd_size) {}
3334 
_mutate(const if_t & obj)3335     object_t _mutate(const if_t &obj) override {
3336         auto _new_obj = ir_mutator_t::_mutate(obj);
3337         auto &new_obj = _new_obj.as<if_t>();
3338         auto cond = shuffle_t::make_broadcast(new_obj.cond, simd_size_);
3339         return if_t::make(cond, new_obj.body, new_obj.else_body);
3340     }
3341 
3342 private:
3343     int simd_size_;
3344 };
3345 
3346 // Injects broadcasts for scalar if conditions. Example:
3347 // Before:
3348 //     if (cond) { ... }
3349 // After (for SIMD8):
3350 //     if (bcast8(cond)) { ... }
fixup_if_conditions(const stmt_t & s,const conv_config_t & cfg)3351 stmt_t fixup_if_conditions(const stmt_t &s, const conv_config_t &cfg) {
3352     auto ret = if_condition_fixer_t(cfg.simd_size).mutate(s);
3353     trace_pass("fixup_if_conditions", ret);
3354     return ret;
3355 }
3356 
3357 class loop_unroller_t : public ir_mutator_t {
3358 public:
loop_unroller_t(ir_context_t & ir_ctx)3359     loop_unroller_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
3360 
_mutate(const for_t & obj)3361     object_t _mutate(const for_t &obj) override {
3362         auto new_obj = ir_mutator_t::_mutate(obj);
3363         auto &_for = new_obj.as<for_t>();
3364         // No unrolling.
3365         if (_for.unroll == 1) return new_obj;
3366 
3367         ir_assert(is_const(obj.init))
3368                 << "Can't unroll loop with non-const bound: " << obj.init;
3369         ir_assert(is_const(obj.bound))
3370                 << "Can't unroll loop with non-const bound: " << obj.bound;
3371 
3372         auto init = to_cpp<int>(obj.init);
3373         auto bound = to_cpp<int>(obj.bound);
3374 
3375         ir_assert(_for.unroll == (bound - init))
3376                 << "Only full loop unroll is supported.";
3377 
3378         stmt_t ret;
3379         for (int i = init; i < bound; i++) {
3380             auto iter_stmt
3381                     = substitute(obj.body, obj.var, to_expr(i, obj.var.type()));
3382             iter_stmt = rename_let_alloc(iter_stmt, i - init);
3383             ret = ret.append(iter_stmt);
3384         }
3385         return std::move(ret);
3386     }
3387 
3388 private:
rename_let_alloc(const stmt_t & s,int idx)3389     stmt_t rename_let_alloc(const stmt_t &s, int idx) {
3390         auto lets = find_objects<let_t>(s);
3391         auto ret = s;
3392         for (auto &_let : lets) {
3393             auto &let = _let.as<let_t>();
3394             auto &var = let.var.as<var_t>();
3395             auto new_var = ir_ctx_.create_tmp_var(var.type, var.name);
3396             ret = substitute(ret, let.var, new_var);
3397         }
3398         auto allocs = find_objects<alloc_t>(s);
3399         for (auto &_alloc : allocs) {
3400             auto &alloc = _alloc.as<alloc_t>();
3401             auto &buf = alloc.buf.as<var_t>();
3402             auto new_buf = ir_ctx_.create_tmp_var(buf.type, buf.name);
3403             ret = substitute(ret, alloc.buf, new_buf);
3404         }
3405         return ret;
3406     }
3407 
3408     ir_context_t &ir_ctx_;
3409 };
3410 
3411 // Unrolls loops according to their unroll attribute.
3412 // Before:
3413 //     for (int i = 0; i < 2; i++) [unroll: 2] {
3414 //         body(i);
3415 //     }
3416 // After:
3417 //     body(0);
3418 //     body(1);
unroll_loops(const stmt_t & s,ir_context_t & ir_ctx)3419 stmt_t unroll_loops(const stmt_t &s, ir_context_t &ir_ctx) {
3420     auto ret = loop_unroller_t(ir_ctx).mutate(s);
3421     trace_pass("unroll_loops", ret);
3422     return ret;
3423 }
3424 
create_reorder_stmt(const layout_t & src,const layout_t & dst,const expr_t & src_buf,const expr_t & dst_buf)3425 stmt_t create_reorder_stmt(const layout_t &src, const layout_t &dst,
3426         const expr_t &src_buf, const expr_t &dst_buf) {
3427     ir_assert(src.ndims() == dst.ndims()) << "Layouts are incompatible.";
3428     ir_assert(src.elems() == dst.elems()) << "Layouts are incompatible.";
3429     auto func = reorder_t::make(src, dst);
3430     return func.call({dst_buf, src_buf});
3431 }
3432 
create_reduce_stmt(const layout_t & src,const layout_t & dst,const expr_t & src_buf,const expr_t & dst_buf,const tensor_t & _sub_tile,uint32_t reduction_mask)3433 stmt_t create_reduce_stmt(const layout_t &src, const layout_t &dst,
3434         const expr_t &src_buf, const expr_t &dst_buf, const tensor_t &_sub_tile,
3435         uint32_t reduction_mask) {
3436     auto sub_tile = _sub_tile;
3437     if (sub_tile.is_empty()) sub_tile = tensor_t(src.dims());
3438     ir_assert(src.ndims() == sub_tile.ndims());
3439     int ndims = src.ndims();
3440 
3441     // Align dst layout with src layout according to the mask.
3442     std::vector<int> dst2src(dst.ndims());
3443     int dst_dim_idx = 0;
3444     for (int i = 0; i < ndims; i++) {
3445         if ((reduction_mask & (1 << i)) != 0) {
3446             dst2src[dst_dim_idx] = i;
3447             dst_dim_idx++;
3448         }
3449     }
3450     ir_assert(dst_dim_idx == dst.ndims()) << "Incompatible reduction mask.";
3451 
3452     auto dst_blocks = dst.blocks();
3453     for (auto &b : dst_blocks)
3454         b.dim_idx = dst2src[b.dim_idx];
3455 
3456     // Create final layouts.
3457     auto dst_aligned = layout_t(dst.type(), ndims, dst.offset(), dst_blocks);
3458 
3459     std::vector<dim_t> dst_tile_dims = sub_tile.dims();
3460     std::vector<expr_t> dst_tile_start = sub_tile.start();
3461     for (int i = 0; i < ndims; i++) {
3462         if ((reduction_mask & (1 << i)) == 0) {
3463             dst_tile_dims[i] = 1;
3464             dst_tile_start[i] = expr_t(0);
3465             continue;
3466         }
3467     }
3468     dst_aligned = dst_aligned.map(tensor_t(dst_tile_dims, dst_tile_start));
3469 
3470     auto func = reduce_t::make(src, dst_aligned);
3471     return func.call({dst_buf, src_buf});
3472 }
3473 
create_zero_out_stmt(ngen::HW hw,const expr_t & buf,int size)3474 stmt_t create_zero_out_stmt(ngen::HW hw, const expr_t &buf, int size) {
3475     stmt_t ret;
3476     int step_bytes = 2 * ngen::GRF::bytes(hw);
3477     for (int i = 0; i < size; i += step_bytes) {
3478         int cur_step_bytes = std::min(step_bytes, size - i);
3479         ret = ret.append(store_t::make(buf, i,
3480                 shuffle_t::make_broadcast(
3481                         expr_t(0.0f), cur_step_bytes / sizeof(float))));
3482     }
3483     return ret;
3484 }
3485 
3486 // Generates loads or stores to move data between memory (global or SLM) and
3487 // GRF. Memory layout is a parameter. GRF layout is deduced automatically,
3488 // according to the decomposition into messages.
3489 class access_builder_t {
3490 public:
3491     access_builder_t() = default;
3492 
access_builder_t(ngen::HW hw,ir_context_t & ir_ctx,const constraint_set_t & cset,const view_t & mem_view,const expr_t & mem_buf,const expr_t & reg_buf,bool is_slm,bool is_prefetch,bool is_load,ngen_proxy::AtomicOp atomic_op)3493     access_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3494             const constraint_set_t &cset, const view_t &mem_view,
3495             const expr_t &mem_buf, const expr_t &reg_buf, bool is_slm,
3496             bool is_prefetch, bool is_load, ngen_proxy::AtomicOp atomic_op)
3497         : hw_(hw)
3498         , ir_ctx_(&ir_ctx)
3499         , cset_(&cset)
3500         , mem_view_(mem_view)
3501         , mem_buf_(mem_buf)
3502         , reg_buf_(reg_buf)
3503         , is_slm_(is_slm)
3504         , is_prefetch_(is_prefetch)
3505         , is_load_(is_load)
3506         , atomic_op_(atomic_op) {
3507         build();
3508     }
3509 
is_slm() const3510     bool is_slm() const { return is_slm_; }
3511 
is_prefetch() const3512     bool is_prefetch() const { return is_prefetch_; }
3513 
reg_layout() const3514     const layout_t &reg_layout() const { return reg_layout_; }
3515 
reg_buf_size() const3516     int reg_buf_size() const { return reg_buf_size_; }
3517 
stmt() const3518     const stmt_t &stmt() const { return stmt_; }
3519 
str() const3520     std::string str() const {
3521         const auto grf_size = ngen::GRF::bytes(hw_);
3522         std::ostringstream oss;
3523         oss << "Memory view:          " << mem_view_ << std::endl;
3524         oss << "Register layout:      " << reg_layout_ << std::endl;
3525         oss << "Register buffer:      " << reg_buf_ << std::endl;
3526         oss << "Register buffer size: " << reg_buf_size_ << " ("
3527             << reg_buf_size_ / grf_size << " regs)" << std::endl;
3528         oss << "Statement:            " << std::endl << stmt_;
3529         return oss.str();
3530     }
3531 
3532 private:
build()3533     void build() {
3534         auto send_list = get_send_list(mem_view_.type());
3535 
3536         auto mask_tensor = mem_view_.create_mask_tensor(*cset_);
3537 
3538         // Find the first send candidate matching the layout.
3539         func_t _send;
3540         tensor_t send_tensor;
3541         for (auto &_s_base : send_list) {
3542             auto &s_base = _s_base.as<send_t>();
3543             int type_size = mem_view_.type().size();
3544             int block_bytes_base = s_base.block_size();
3545             if (block_bytes_base % type_size != 0) continue;
3546             int elems_per_block_base = block_bytes_base / type_size;
3547 
3548             dim_t elems_per_block = elems_per_block_base;
3549             dim_t slots = s_base.slots;
3550 
3551             // Check if the view can be decomposed for this send.
3552             auto tensor
3553                     = mem_view_.split_into_dense_tile(elems_per_block, slots);
3554             if (tensor.is_empty()) continue;
3555 
3556             auto _s = s_base.adjust(
3557                     int(elems_per_block * type_size), int(slots));
3558             if (_s.is_empty()) continue;
3559             auto &s = _s.as<send_t>();
3560 
3561             // Check if this send supports the required mask.
3562             if (!has_compatible_mask(s, mem_view_, tensor, mask_tensor))
3563                 continue;
3564 
3565             // TODO: Check alignment requirements.
3566 
3567             // Success, send is found, stop iterating.
3568             _send = _s;
3569             send_tensor = tensor;
3570             break;
3571         }
3572         // Support for prefetch messages is limited. If message is not found,
3573         // skip prefetch generation.
3574         if (_send.is_empty() && is_prefetch()) return;
3575         ir_assert(!_send.is_empty()) << "Can't decompose view into messages.";
3576 
3577         auto &send = _send.as<send_t>();
3578         reg_layout_ = create_register_layout_for_message(
3579                 send, mem_view_, reg_buf_size_);
3580 
3581         mem_view_.for_each_tile(
3582                 send_tensor, [&](const std::vector<dim_t> &start) {
3583                     auto tile = tensor_t(send_tensor.dims(), start);
3584                     auto sub_view = mem_view_.create_sub_view(tile);
3585                     auto sub_mask_tensor = mask_tensor.map(tile);
3586                     auto reg_sub_buf = (is_prefetch()
3587                                     ? expr_t()
3588                                     : reg_buf_[reg_layout_(start)
3589                                             * reg_layout_.type().size()]);
3590                     stmt_ = stmt_seq_t::make(stmt_,
3591                             create_send_stmt(*ir_ctx_, send, mem_buf_,
3592                                     reg_sub_buf, sub_view, sub_mask_tensor));
3593                 });
3594     }
3595 
3596     // Returns a list of send functions that can be used for the access.
get_send_list(const type_t & data_type) const3597     std::vector<func_t> get_send_list(const type_t &data_type) const {
3598         using namespace ngen_proxy;
3599         bool is_atomic = (atomic_op_ != AtomicOp::undef);
3600         Access access_type = (is_load_ ? Access::Read : Access::Write);
3601         // TODO: use stateless access on XeHPC until driver fix
3602         bool use_stateful_msgs = is_atomic && hw_ < ngen::HW::XeHPC;
3603         AddressModel address_model
3604                 = (is_slm() ? AddressModel::ModelSLM
3605                             : use_stateful_msgs ? AddressModel::ModelBTS
3606                                                 : AddressModel::ModelA64);
3607         auto send_list = send_t::get_all(hw_, data_type, access_type,
3608                 address_model, atomic_op_, is_prefetch_);
3609         return send_list;
3610     }
3611 
3612     ngen::HW hw_;
3613     ir_context_t *ir_ctx_;
3614     const constraint_set_t *cset_;
3615 
3616     view_t mem_view_;
3617     expr_t mem_buf_;
3618     layout_t reg_layout_;
3619     expr_t reg_buf_;
3620     int reg_buf_size_;
3621     bool is_slm_;
3622     bool is_prefetch_;
3623     bool is_load_;
3624     stmt_t stmt_;
3625     ngen_proxy::AtomicOp atomic_op_;
3626 };
3627 
3628 class read_builder_t : public access_builder_t {
3629 public:
3630     read_builder_t() = default;
3631 
read_builder_t(ngen::HW hw,ir_context_t & ir_ctx,const constraint_set_t & cset,const view_t & view,const expr_t & mem_buf,const expr_t & reg_buf,bool is_slm,bool is_prefetch=false)3632     read_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3633             const constraint_set_t &cset, const view_t &view,
3634             const expr_t &mem_buf, const expr_t &reg_buf, bool is_slm,
3635             bool is_prefetch = false)
3636         : access_builder_t(hw, ir_ctx, cset, view, mem_buf, reg_buf, is_slm,
3637                 is_prefetch, /*is_load=*/true, ngen_proxy::AtomicOp::undef) {}
3638 };
3639 
3640 class write_builder_t : public access_builder_t {
3641 public:
3642     write_builder_t() = default;
3643 
write_builder_t(ngen::HW hw,ir_context_t & ir_ctx,const constraint_set_t & cset,const view_t & view,const expr_t & mem_buf,const expr_t & reg_buf,bool is_slm,ngen_proxy::AtomicOp atomic_op=ngen_proxy::AtomicOp::undef)3644     write_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3645             const constraint_set_t &cset, const view_t &view,
3646             const expr_t &mem_buf, const expr_t &reg_buf, bool is_slm,
3647             ngen_proxy::AtomicOp atomic_op = ngen_proxy::AtomicOp::undef)
3648         : access_builder_t(hw, ir_ctx, cset, view, mem_buf, reg_buf, is_slm,
3649                 /*is_prefetch=*/false, /*is_load=*/false, atomic_op) {}
3650 };
3651 
3652 // Generates loads to the post-op buffer and applies a single post-op.
3653 // There are two types of post-ops:
3654 // - Eltwise: lhs = F(lhs)
3655 // - Binary:  lhs = F(lhs, rhs)
3656 // Binary requires rhs load which may be either:
3657 // - Pre-loaded and used for all updates (preferred approach)
3658 // - Loaded for every tile
3659 // Right-hand side tensor supports implicit broadcasting: value is broadcasted
3660 // across a size one dimension.
3661 class post_op_builder_t {
3662 public:
post_op_builder_t(ngen::HW hw,ir_context_t & ir_ctx,const constraint_set_t & cset,const post_op_t & post_op,int & available_pre_load_size)3663     post_op_builder_t(ngen::HW hw, ir_context_t &ir_ctx,
3664             const constraint_set_t &cset, const post_op_t &post_op,
3665             int &available_pre_load_size)
3666         : hw_(hw), ir_ctx_(ir_ctx), cset_(cset), post_op_(post_op) {
3667         if (!post_op_.needs_load()) return;
3668 
3669         // Estimate buffer size required to load full rhs, do not do pre-load
3670         // if it requires too much GRF memory.
3671         int estimated_rhs_bytes = 0;
3672 
3673         estimated_rhs_bytes
3674                 += int(post_op.rhs_view().create_dense_vlayout().size());
3675 
3676         if (needs_rhs_convert()) {
3677             estimated_rhs_bytes += int(post_op.rhs_view()
3678                                                .create_dense_vlayout()
3679                                                .retype(type_t::f32())
3680                                                .size());
3681         }
3682 
3683         if (estimated_rhs_bytes <= available_pre_load_size) {
3684             available_pre_load_size -= estimated_rhs_bytes;
3685             do_preload_ = true;
3686         }
3687     }
3688 
3689     // Pre-loads rhs data for the whole update.
build_pre_load()3690     stmt_t build_pre_load() {
3691         if (!do_preload_) return stmt_t();
3692 
3693         auto rhs_load_reg_buf = make_tmp_rhs_buffer();
3694         read_builder_t read(hw_, ir_ctx_, cset_, post_op_.rhs_view(),
3695                 post_op_.rhs_buf(), rhs_load_reg_buf, /*is_slm=*/false);
3696         pre_load_rhs_reg_buf_ = rhs_load_reg_buf;
3697         pre_load_rhs_reg_layout_ = read.reg_layout();
3698         if (!needs_rhs_convert()) rhs_reg_buf_ = rhs_load_reg_buf;
3699         update_rhs_buf_size(rhs_load_reg_buf, read.reg_buf_size());
3700         return read.stmt();
3701     }
3702 
3703     // Converts the pre-loaded rhs data to f32.
build_pre_convert()3704     stmt_t build_pre_convert() {
3705         if (!do_preload_ || !needs_rhs_convert()) return stmt_t();
3706 
3707         auto rhs_f32_reg_buf = make_tmp_rhs_buffer();
3708         auto f32_layout
3709                 = pre_load_rhs_reg_layout_.make_dense().retype(type_t::f32());
3710         update_rhs_buf_size(rhs_f32_reg_buf, int(f32_layout.size()));
3711 
3712         // Reorder to f32.
3713         auto ret = create_reorder_stmt(pre_load_rhs_reg_layout_, f32_layout,
3714                 pre_load_rhs_reg_buf_, rhs_f32_reg_buf);
3715 
3716         // Now rhs is converted to f32.
3717         pre_load_rhs_reg_layout_ = f32_layout;
3718         rhs_reg_buf_ = rhs_f32_reg_buf;
3719 
3720         return ret;
3721     }
3722 
3723     // Loads rhs data for one tile.
build_tile_load(const tensor_t & tile)3724     stmt_t build_tile_load(const tensor_t &tile) {
3725         if (!post_op_.needs_load()) return stmt_t();
3726 
3727         stmt_t stmt;
3728         auto rhs_tile = post_op_.apply_mask(tile);
3729         if (post_op_.needs_load() && !do_preload_) {
3730             // Load and convert now.
3731             auto po = post_op_.create_sub_post_op(rhs_tile);
3732             auto rhs_load_reg_buf = make_tmp_rhs_buffer();
3733             read_builder_t read(hw_, ir_ctx_, cset_, po.rhs_view(),
3734                     po.rhs_buf(), rhs_load_reg_buf,
3735                     /*is_slm=*/false);
3736             stmt = stmt.append(read.stmt());
3737 
3738             update_rhs_buf_size(rhs_load_reg_buf, read.reg_buf_size());
3739 
3740             if (needs_rhs_convert()) {
3741                 auto rhs_f32_reg_buf = make_tmp_rhs_buffer();
3742                 auto f32_layout
3743                         = read.reg_layout().make_dense().retype(type_t::f32());
3744                 update_rhs_buf_size(rhs_f32_reg_buf, int(f32_layout.size()));
3745                 // Reorder to f32.
3746                 stmt = stmt.append(create_reorder_stmt(read.reg_layout(),
3747                         f32_layout, rhs_load_reg_buf, rhs_f32_reg_buf));
3748 
3749                 // Now rhs is converted to f32.
3750                 rhs_reg_layout_ = f32_layout;
3751                 rhs_reg_buf_ = rhs_f32_reg_buf;
3752             } else {
3753                 rhs_reg_layout_ = read.reg_layout();
3754                 rhs_reg_buf_ = rhs_load_reg_buf;
3755             }
3756         } else {
3757             // Already pre-loaded and pre-converted.
3758             rhs_reg_layout_ = pre_load_rhs_reg_layout_.map(rhs_tile);
3759         }
3760         return stmt;
3761     }
3762 
3763     // Applies post-op for a single tile.
build_tile_stmt(const tensor_t & tile,const layout_t & lhs_reg_layout,const expr_t & lhs_buf)3764     stmt_t build_tile_stmt(const tensor_t &tile, const layout_t &lhs_reg_layout,
3765             const expr_t &lhs_buf) {
3766         auto po = post_op_.create_sub_post_op(tile);
3767         if (!po.has_rhs()) {
3768             // Apply eltwise post-op.
3769             int lhs_size = lhs_reg_layout.size();
3770             int lhs_elems = lhs_size / int(sizeof(float));
3771             return po.eltwise().call({expr_t(lhs_elems), lhs_buf});
3772         }
3773 
3774         auto lhs_layout = lhs_reg_layout;
3775         auto rhs_layout = (po.needs_load()
3776                         ? rhs_reg_layout_
3777                         : lhs_layout.map(
3778                                 tensor_t(std::vector<dim_t>(tile.ndims(), 1))));
3779 
3780         int inner_dim_idx = lhs_layout.blocks().front().dim_idx;
3781         bool do_broadcast = po.is_broadcast_dim(inner_dim_idx);
3782         if (!do_broadcast) layout_t::align_layouts(lhs_layout, rhs_layout);
3783 
3784         auto lhs_blocks = lhs_layout.blocks();
3785         auto rhs_blocks = rhs_layout.blocks();
3786 
3787         auto &lhs0 = lhs_blocks[0];
3788 
3789         ir_assert(lhs0.dim_idx == inner_dim_idx);
3790         ir_assert(dim_t(lhs0.stride) == 1);
3791 
3792         if (!do_broadcast) {
3793             auto &rhs0 = rhs_blocks[0];
3794             ir_assert(lhs0.dim_idx == rhs0.dim_idx);
3795             ir_assert(lhs0.block == rhs0.block);
3796             MAYBE_UNUSED(rhs0);
3797         }
3798 
3799         std::vector<dim_t> inner_tile_dims(tile.ndims(), 1);
3800         inner_tile_dims[inner_dim_idx] = lhs0.block;
3801 
3802         auto &lhs_type = lhs_layout.type();
3803         auto &rhs_type = rhs_layout.type();
3804         ir_assert(lhs_type == type_t::f32());
3805         ir_assert(rhs_type == type_t::f32());
3806 
3807         // Handle one inner tile at a time. Inner tile covers a single block
3808         // with a single dimension.
3809         stmt_t stmt;
3810         lhs_layout.for_each_tile(tensor_t(inner_tile_dims),
3811                 [&](const std::vector<dim_t> &lhs_start) {
3812                     auto rhs_start = po.apply_mask(lhs_start, 0);
3813                     int lhs_off0 = lhs_layout(lhs_start) * lhs_type.size();
3814                     int rhs_off0 = rhs_layout(rhs_start) * rhs_type.size();
3815 
3816                     int elems = lhs0.block;
3817                     int step = (elems < 16 ? 8 : 16);
3818                     for (int i = 0; i < elems; i += step) {
3819                         int cur_elems = std::min(step, elems - i);
3820                         ir_assert(math::is_pow2(cur_elems));
3821                         auto lhs_vec_type = lhs_type.with_elems(cur_elems);
3822                         auto rhs_vec_type = rhs_type.with_elems(
3823                                 do_broadcast ? 1 : cur_elems);
3824 
3825                         int lhs_off = lhs_off0 + i * lhs_type.size();
3826                         int rhs_off = rhs_off0;
3827                         if (!do_broadcast) rhs_off += i * rhs_type.size();
3828 
3829                         auto lhs = load_t::make(lhs_vec_type, lhs_buf, lhs_off);
3830                         expr_t rhs;
3831                         if (po.needs_load()) {
3832                             int stride
3833                                     = (do_broadcast ? load_t::default_stride
3834                                                     : int(rhs_blocks[0].stride)
3835                                                             * rhs_type.size());
3836                             rhs = load_t::make(rhs_vec_type, rhs_reg_buf_,
3837                                     rhs_off, stride);
3838                         } else {
3839                             // rhs is scalar and passed in the kernel arguments.
3840                             rhs = po.rhs_buf();
3841                             ir_assert(rhs.type().is_scalar());
3842                         }
3843 
3844                         if (rhs.type().elems() != cur_elems) {
3845                             rhs = shuffle_t::make_broadcast(rhs, cur_elems);
3846                         }
3847 
3848                         if (po.rhs_scale() != 1) {
3849                             // Scale rhs first.
3850                             rhs = binary_op_t::make(op_kind_t::_mul, rhs,
3851                                     shuffle_t::make_broadcast(
3852                                             po.rhs_scale(), cur_elems));
3853                         }
3854 
3855                         auto new_lhs
3856                                 = binary_op_t::make(po.op_kind(), lhs, rhs);
3857                         if (new_lhs.type().is_bool()) {
3858                             // Apply bool -> f32 cast when binary is a comparison op.
3859                             new_lhs = cast(new_lhs, type_t::f32(cur_elems));
3860                         }
3861                         auto store = store_t::make(lhs_buf, lhs_off, new_lhs);
3862                         stmt = stmt.append(store);
3863                     }
3864                 });
3865 
3866         // Reset rhs layout.
3867         rhs_reg_layout_ = layout_t();
3868         return stmt;
3869     }
3870 
allocs() const3871     std::vector<stmt_t> allocs() const {
3872         std::vector<stmt_t> allocs;
3873         for (auto &kv : rhs_bufs_)
3874             allocs.push_back(
3875                     alloc_t::make(kv.first, kv.second, alloc_kind_t::grf));
3876         return allocs;
3877     }
3878 
3879 private:
make_tmp_rhs_buffer() const3880     expr_t make_tmp_rhs_buffer() const {
3881         auto &rhs_name = post_op_.rhs_buf().as<var_t>().name;
3882         return ir_ctx_.create_tmp_var(type_t::byte_ptr(), "tmp_" + rhs_name);
3883     }
3884 
update_rhs_buf_size(const expr_t & buf,int size)3885     void update_rhs_buf_size(const expr_t &buf, int size) {
3886         rhs_bufs_[buf] = std::max(rhs_bufs_[buf], size);
3887     }
3888 
needs_rhs_convert() const3889     bool needs_rhs_convert() const {
3890         if (!post_op_.has_rhs()) return false;
3891         return post_op_.rhs_view().type() != type_t::f32();
3892     }
3893 
3894     ngen::HW hw_;
3895     ir_context_t &ir_ctx_;
3896     const constraint_set_t &cset_;
3897     post_op_t post_op_;
3898 
3899     bool do_preload_ = false;
3900 
3901     expr_t pre_load_rhs_reg_buf_;
3902     layout_t pre_load_rhs_reg_layout_;
3903 
3904     expr_t rhs_reg_buf_;
3905     layout_t rhs_reg_layout_;
3906 
3907     object_map_t<expr_t, int> rhs_bufs_;
3908 };
3909 
3910 // Zero pads a register buffer of f32 type.
3911 class zero_pad_builder_t {
3912 public:
zero_pad_builder_t(const constraint_set_t & cset,const post_op_context_t & post_op_ctx,const view_t & mem_view,const layout_t & reg_layout,const expr_t & reg_buf)3913     zero_pad_builder_t(const constraint_set_t &cset,
3914             const post_op_context_t &post_op_ctx, const view_t &mem_view,
3915             const layout_t &reg_layout, const expr_t &reg_buf)
3916         : cset_(cset)
3917         , post_op_ctx_(post_op_ctx)
3918         , mem_view_(mem_view)
3919         , reg_layout_(reg_layout)
3920         , reg_buf_(reg_buf) {
3921         ir_assert(mem_view_.nvdims() == reg_layout_.ndims())
3922                 << "Incompatible view/layout.";
3923         build();
3924     }
3925 
stmt() const3926     const stmt_t &stmt() const { return stmt_; }
3927 
3928 private:
build()3929     void build() {
3930         int max_step = 16; // Handle 16 elements at most in one step.
3931         auto base_tile = reg_layout_.split_into_max_tile(
3932                 max_step, /*is_dense_tile=*/true);
3933         reg_layout_.for_each_tile(
3934                 base_tile, [&](const std::vector<dim_t> &start) {
3935                     tensor_t tile(base_tile.dims(), start);
3936                     auto sub_layout = reg_layout_.map(tile);
3937                     auto sub_view = mem_view_.create_sub_view(tile);
3938                     int elems = tile.elems();
3939                     int off = reg_layout_(start) * reg_layout_.type().size();
3940                     auto mask_tensor = create_mask(sub_view, sub_layout);
3941                     auto mask = mask_tensor.to_expr(elems);
3942                     auto store = store_t::make(reg_buf_, off,
3943                             shuffle_t::make_broadcast(expr_t(0.0f), elems),
3944                             store_t::default_stride, -mask);
3945                     stmt_ = stmt_.append(store);
3946                 });
3947     }
3948 
create_mask(const view_t & view,const layout_t & layout) const3949     mask_tensor_t create_mask(
3950             const view_t &view, const layout_t &layout) const {
3951         mask_tensor_t mask_tensor(layout);
3952         std::vector<dim_t> args(layout.ndims());
3953         fill_mask_impl(mask_tensor, 0, args, view, layout);
3954         mask_tensor.simplify(cset_);
3955         return mask_tensor;
3956     }
3957 
fill_mask_impl(mask_tensor_t & mask_tensor,int idx,std::vector<dim_t> & args,const view_t & view,const layout_t & layout) const3958     void fill_mask_impl(mask_tensor_t &mask_tensor, int idx,
3959             std::vector<dim_t> &args, const view_t &view,
3960             const layout_t &layout) const {
3961         if (idx == layout.ndims()) {
3962             expr_t mask = bool_imm_t::make(true);
3963             for (int i = 0; i < layout.ndims(); i++) {
3964                 if (!post_op_ctx_.is_lhs_dim_zero_padded(i)) continue;
3965                 mask &= (view.vstart(i) + args[i] < post_op_ctx_.lhs_dim(i));
3966             }
3967             auto off = layout.offset(args, /*ignore_offset=*/true);
3968             mask_tensor.set_mask(off, mask);
3969             return;
3970         }
3971 
3972         for (int i = 0; i < int(layout.dims()[idx]); i++) {
3973             args[idx] = i;
3974             fill_mask_impl(mask_tensor, idx + 1, args, view, layout);
3975         }
3976     }
3977 
3978     const constraint_set_t &cset_;
3979     const post_op_context_t &post_op_ctx_;
3980 
3981     view_t mem_view_;
3982     layout_t reg_layout_;
3983     expr_t reg_buf_;
3984 
3985     stmt_t stmt_;
3986 };
3987 
3988 // Performs the following steps after the computation:
3989 // - Conversion
3990 // - Applying post-ops
3991 // - GRF reorder to match the memory layout
3992 // - Store to the destination
3993 class epilogue_builder_t {
3994 public:
epilogue_builder_t(const conv_config_t & cfg,ir_context_t & ir_ctx,const constraint_set_t & cset,const post_op_context_t & post_op_ctx,const tensor_t & tile,const view_t & mem_view,const layout_t & reg_layout,const expr_t & mem_buf,const expr_t & reg_buf)3995     epilogue_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
3996             const constraint_set_t &cset, const post_op_context_t &post_op_ctx,
3997             const tensor_t &tile, const view_t &mem_view,
3998             const layout_t &reg_layout, const expr_t &mem_buf,
3999             const expr_t &reg_buf)
4000         : cfg_(cfg)
4001         , ir_ctx_(ir_ctx)
4002         , cset_(cset)
4003         , post_op_ctx_(post_op_ctx)
4004         , mem_view_(mem_view)
4005         , reg_layout_(reg_layout)
4006         , mem_buf_(mem_buf)
4007         , reg_buf_(reg_buf) {
4008 
4009         int pre_load_size = pre_load_max_size_;
4010         for (auto &po : post_op_ctx_.post_ops()) {
4011             auto sub_po = po.create_sub_post_op(tile);
4012             post_op_builders_.emplace_back(
4013                     cfg.hw, ir_ctx, cset_, sub_po, pre_load_size);
4014         }
4015         build();
4016     }
4017 
stmt() const4018     const stmt_t &stmt() const { return stmt_; }
4019 
4020 private:
4021     // Represents one stage in the flow between multiplication and storing the
4022     // updated result to memory.
4023     //
4024     // Flow with post-ops:
4025     //   Multiplication ->
4026     //     M_x -> [R_f32] -> P0_f32 -> ... -> Pn_f32 -> [Z_f32] -> S_y ->
4027     //   GMEM
4028     // Flow without post-ops:
4029     //   Multiplication ->
4030     //     M_x -> S_y ->
4031     //   GMEM
4032     // Where:
4033     // - x      is data type after multiplication
4034     // - y      is destination data type
4035     // - M_x    is a stage after multiplication
4036     // - R_f32  is a stage after reordering from M_x to f32 (optional)
4037     // - Pi_f32 is a stage after applying Pi post-op
4038     // - Z_f32  is a stage after restoring zero padding (optional)
4039     // - S_y    is a stage before storing data to destination
4040     struct stage_t {
stage_tdnnl::impl::gpu::jit::epilogue_builder_t::stage_t4041         stage_t(const layout_t &layout, const expr_t &buf,
4042                 const stmt_t &stmt = stmt_t())
4043             : layout(layout), buf(buf), stmt(stmt) {}
4044 
set_nextdnnl::impl::gpu::jit::epilogue_builder_t::stage_t4045         void set_next(ngen::HW hw, ir_context_t &ir_ctx, stage_t *next,
4046                 bool force_reorder) {
4047             if (!next) return;
4048             bool do_reorder
4049                     = !layout.is_equal(next->layout, /*compare_offset=*/false);
4050             if (force_reorder) do_reorder = true;
4051             if (do_reorder) {
4052                 ir_assert(stmt.is_empty());
4053                 // Generate reorder between stages.
4054                 stmt = create_reorder_stmt(
4055                         layout, next->layout, buf, next->buf);
4056             } else {
4057                 // Reuse the same GRF buffer for the next stage.
4058                 int this_off = to_cpp<int>(layout.offset_in_bytes());
4059                 int next_off = to_cpp<int>(next->layout.offset_in_bytes());
4060                 ir_assert(next_off == 0);
4061                 MAYBE_UNUSED(next_off);
4062                 next->set_buf(buf[this_off]);
4063             }
4064         }
4065 
set_bufdnnl::impl::gpu::jit::epilogue_builder_t::stage_t4066         void set_buf(const expr_t &buf) {
4067             // Replace old buffer if there is an assigned statement.
4068             if (!stmt.is_empty()) { stmt = substitute(stmt, this->buf, buf); }
4069             this->buf = buf;
4070         }
4071 
buf_basednnl::impl::gpu::jit::epilogue_builder_t::stage_t4072         const expr_t &buf_base() const {
4073             if (buf.is<var_t>()) return buf;
4074             return buf.as<ptr_t>().base;
4075         }
4076 
buf_sizednnl::impl::gpu::jit::epilogue_builder_t::stage_t4077         int buf_size() const {
4078             ir_assert(buf.is_same(buf_base()))
4079                     << "Size must be queried from another stage.";
4080             return int(layout.size());
4081         }
4082 
prepend_stmtdnnl::impl::gpu::jit::epilogue_builder_t::stage_t4083         void prepend_stmt(const stmt_t &stmt) {
4084             this->stmt = stmt.append(this->stmt);
4085         }
4086 
4087         layout_t layout;
4088         expr_t buf;
4089         stmt_t stmt;
4090     };
4091 
build()4092     void build() {
4093         for (auto &po_builder : post_op_builders_) {
4094             stmt_ = stmt_.append(po_builder.build_pre_load());
4095         }
4096 
4097         for (auto &po_builder : post_op_builders_) {
4098             stmt_ = stmt_.append(po_builder.build_pre_convert());
4099         }
4100 
4101         auto tmp_type = (post_op_builders_.empty() ? mem_view_.type()
4102                                                    : type_t::f32());
4103         int tmp_buf_elems = tmp_buf_size_ / tmp_type.size();
4104         auto base_tile = mem_view_.split_into_max_tile(
4105                 tmp_buf_elems, /*is_dense=*/false);
4106         mem_view_.for_each_tile(
4107                 base_tile, [&](const std::vector<dim_t> &start) {
4108                     build_tile(tensor_t(base_tile.dims(), start));
4109                 });
4110 
4111         // Generate alloc statements for rhs post-op buffers.
4112         std::vector<stmt_t> allocs;
4113         for (auto &po_builder : post_op_builders_) {
4114             auto po_allocs = po_builder.allocs();
4115             allocs.insert(allocs.end(), po_allocs.begin(), po_allocs.end());
4116         }
4117         stmt_ = jit::inject_alloc_stmts(stmt_, allocs, /*put_innermost=*/true);
4118     }
4119 
4120     // Builds statements for a tile iterating through all stages (see stage_t
4121     // description).
build_tile(const tensor_t & tile)4122     void build_tile(const tensor_t &tile) {
4123         auto mem_sub_view = mem_view_.create_sub_view(tile);
4124         auto reg_sub_layout = reg_layout_.map(tile);
4125 
4126         auto tmp_reg_buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp");
4127         bool restore_zero_padding = post_op_ctx_.need_to_restore_zero_padding();
4128 
4129         // S_y -> GMEM.
4130         ngen_proxy::AtomicOp atomic_op
4131                 = (cfg_.do_atomic_update ? ngen_proxy::AtomicOp::fadd
4132                                          : ngen_proxy::AtomicOp::undef);
4133         write_builder_t r2g(cfg_.hw, ir_ctx_, cset_, mem_sub_view, mem_buf_,
4134                 tmp_reg_buf,
4135                 /*is_slm=*/false, /*atomic_op=*/atomic_op);
4136 
4137         // Initialize stages.
4138         std::vector<stage_t> stages;
4139         stages.emplace_back(reg_sub_layout, reg_buf_); // M_x
4140         if (!post_op_builders_.empty()) {
4141             auto po_layout
4142                     = r2g.reg_layout().retype(type_t::f32()).make_dense();
4143             for (int i = 0; i < int(post_op_builders_.size()); i++) {
4144                 auto buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "c_tmp");
4145                 stages.emplace_back(po_layout, buf); // Pi_f32
4146             }
4147             if (restore_zero_padding) {
4148                 auto &last = stages.back();
4149                 stages.emplace_back(last.layout, last.buf); // Z_f32.
4150             }
4151         }
4152         stages.emplace_back(r2g.reg_layout(), tmp_reg_buf, r2g.stmt()); // S_y
4153 
4154         int nstages = int(stages.size());
4155         int npost_ops = int(post_op_builders_.size());
4156 
4157         bool is_dpasw = (cfg_.fma_kind == fma_kind_t::dpasw);
4158 
4159         // Generate reorders between stages and create buffers.
4160         for (int i = 0; i < nstages; i++) {
4161             auto *next_stage = (i + 1 < nstages ? &stages[i + 1] : nullptr);
4162             // Always perform reorder when dpasw is used. This is to ensure
4163             // that C is properly restored and permuted after dpasw.
4164             stages[i].set_next(cfg_.hw, ir_ctx_, next_stage,
4165                     /*force_reorder=*/i == 0 && is_dpasw);
4166         }
4167 
4168         std::vector<stmt_t> tile_load_stmts;
4169         for (int i = 0; i < npost_ops; i++) {
4170             auto &po_builder = post_op_builders_[i];
4171             // Generate load for post-op.
4172             tile_load_stmts.push_back(po_builder.build_tile_load(tile));
4173 
4174             // Generate post-op statement.
4175             auto &s = stages[i + 1];
4176             s.prepend_stmt(po_builder.build_tile_stmt(tile, s.layout, s.buf));
4177         }
4178 
4179         // Restore zero padding if needed.
4180         if (restore_zero_padding) {
4181             auto &s = stages[nstages - 2];
4182             zero_pad_builder_t builder(
4183                     cset_, post_op_ctx_, mem_sub_view, s.layout, s.buf);
4184             s.prepend_stmt(builder.stmt());
4185         }
4186 
4187         stmt_t tile_stmt;
4188 
4189         // Add stage statements. Emit stages in blocks to reduce GRF
4190         // consumption.
4191         int stage_blk = 8;
4192         for (int i = 0; i < nstages; i += stage_blk) {
4193             int stage_beg = i;
4194             int stage_end = std::min(nstages, i + stage_blk);
4195             int po_beg = std::max(0, i - 1);
4196             int po_end = std::min(npost_ops, i + stage_blk - 1);
4197             stmt_t blk_stmt;
4198             for (int j = po_beg; j < po_end; j++) {
4199                 blk_stmt = blk_stmt.append(tile_load_stmts[j]);
4200             }
4201             for (int j = stage_beg; j < stage_end; j++) {
4202                 blk_stmt = blk_stmt.append(stages[j].stmt);
4203             }
4204             tile_stmt = tile_stmt.append(blk_stmt);
4205         }
4206 
4207         // Generate alloc statements for stage buffers.
4208         object_set_t<expr_t> seen;
4209         for (int i = 0; i < nstages; i++) {
4210             auto &s = stages[i];
4211             auto &buf = s.buf_base();
4212             auto ret = seen.insert(buf);
4213             if (i == 0 || !ret.second) continue;
4214             tile_stmt = alloc_t::make(
4215                     buf, s.buf_size(), alloc_kind_t::grf, {}, tile_stmt);
4216         }
4217 
4218         stmt_ = stmt_.append(tile_stmt);
4219     }
4220 
4221     const conv_config_t &cfg_;
4222     ir_context_t &ir_ctx_;
4223     const constraint_set_t &cset_;
4224     const post_op_context_t &post_op_ctx_;
4225 
4226     view_t mem_view_;
4227     layout_t reg_layout_;
4228 
4229     expr_t mem_buf_;
4230     expr_t reg_buf_;
4231 
4232     // Tile size in bytes. The tile data type is:
4233     // - the destination data type without post-ops
4234     // - f32 with post-ops
4235     static const int tmp_buf_size_ = 128;
4236     static const int pre_load_max_size_ = 256;
4237 
4238     std::vector<post_op_builder_t> post_op_builders_;
4239 
4240     stmt_t stmt_;
4241 };
4242 
4243 class multiply_builder_t {
4244 public:
4245     multiply_builder_t() = default;
4246 
multiply_builder_t(const conv_config_t & cfg,const bmnk_mapper_t & bmnk_mapper,const view_t & a_view,const view_t & b_view,const expr_t & a_buf,const expr_t & b_buf,const expr_t & c_buf)4247     multiply_builder_t(const conv_config_t &cfg,
4248             const bmnk_mapper_t &bmnk_mapper, const view_t &a_view,
4249             const view_t &b_view, const expr_t &a_buf, const expr_t &b_buf,
4250             const expr_t &c_buf)
4251         : hw_(cfg.hw)
4252         , simd_size_(cfg.simd_size)
4253         , bmnk_mapper_(bmnk_mapper)
4254         , a_view_(a_view)
4255         , b_view_(b_view)
4256         , a_buf_(a_buf)
4257         , b_buf_(b_buf)
4258         , c_buf_(c_buf) {
4259         switch (cfg.fma_kind) {
4260             case fma_kind_t::dpasw:
4261             case fma_kind_t::dpas:
4262                 if (try_build_dpas()) return;
4263                 break;
4264             case fma_kind_t::mad:
4265                 if (try_build_mad()) return;
4266                 break;
4267             default: ir_error_not_expected() << "Unknown FMA kind.";
4268         }
4269 
4270         ir_error_not_expected()
4271                 << "Can't decompose into multiplication instructions.";
4272     }
4273 
stmt() const4274     const stmt_t &stmt() const { return stmt_; }
4275 
c_layout() const4276     const layout_t &c_layout() const { return c_layout_; }
4277 
a_grf_bundle()4278     ngen_proxy::Bundle a_grf_bundle() {
4279         if (!do_transpose_) return ngen_proxy::Bundle();
4280         return ngen_proxy::Bundle(1, ngen_proxy::Bundle::any);
4281     }
4282 
b_grf_bundle()4283     ngen_proxy::Bundle b_grf_bundle() {
4284         if (do_transpose_) return ngen_proxy::Bundle();
4285         return ngen_proxy::Bundle(1, ngen_proxy::Bundle::any);
4286     }
4287 
c_grf_bundle()4288     ngen_proxy::Bundle c_grf_bundle() {
4289         return ngen_proxy::Bundle(0, ngen_proxy::Bundle::any);
4290     }
4291 
str() const4292     std::string str() const {
4293         std::ostringstream oss;
4294         oss << "A view:    " << a_view_ << std::endl;
4295         oss << "B view:    " << b_view_ << std::endl;
4296         oss << "C layout:  " << c_layout_ << std::endl;
4297         oss << "Statement: " << std::endl << stmt_;
4298         return oss.str();
4299     }
4300 
4301 private:
4302     struct loop_info_t {
4303         loop_info_t() = default;
4304 
loop_info_tdnnl::impl::gpu::jit::multiply_builder_t::loop_info_t4305         loop_info_t(const expr_t &var, bmnk_kind_t bmnk_kind, int dim)
4306             : var(var), bmnk_kind(bmnk_kind), dim(dim) {}
4307 
4308         expr_t var;
4309         bmnk_kind_t bmnk_kind;
4310 
4311         int dim;
4312         int a_idx = -1;
4313         int b_idx = -1;
4314         int c_idx = -1;
4315         int block = 1;
4316     };
4317 
try_build_dpas()4318     bool try_build_dpas() {
4319         ir_assert(a_view_.can_convert_to_vlayout())
4320                 << "Views are not supported with dpas/dpasw.";
4321         ir_assert(b_view_.can_convert_to_vlayout())
4322                 << "Views are not supported with dpas/dpasw.";
4323 
4324         auto a_layout = a_view_.create_vlayout();
4325         auto b_layout = b_view_.create_vlayout();
4326 
4327         bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper_);
4328         from_bmnk_mapper.push_blocks(abc_kind_t::a, a_layout.blocks());
4329         from_bmnk_mapper.push_blocks(abc_kind_t::b, b_layout.blocks());
4330 
4331         // Convert to MNK layouts.
4332         a_layout = bmnk_mapper_.map_to_bmnk(
4333                 abc_kind_t::a, {bmnk_kind_t::m, bmnk_kind_t::k}, a_layout);
4334         b_layout = bmnk_mapper_.map_to_bmnk(
4335                 abc_kind_t::b, {bmnk_kind_t::k, bmnk_kind_t::n}, b_layout);
4336 
4337         multiply_desc_t desc(a_layout, b_layout, /*force_c_upconvert=*/true);
4338         if (!dpas_t::matches_types(
4339                     hw_, desc.a_type(), desc.b_type(), desc.c_type()))
4340             return false;
4341 
4342         int sdepth = 8;
4343         int rcount = std::min(utils::rnd_up_pow2(desc.n()), 8);
4344         auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size_, sdepth,
4345                 rcount, desc.c_type(), desc.a_type(), desc.b_type());
4346         if (_dpas.as<dpas_t>().matches(desc)) {
4347             build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
4348             return true;
4349         }
4350 
4351         // Try to transpose and flip: C += A * B -> C^T = B^T * A^T.
4352         rcount = std::min(utils::rnd_up_pow2(desc.m()), 8);
4353         desc = multiply_desc_t(
4354                 b_layout.transpose(), a_layout.transpose(), true);
4355         _dpas = dpas_t::make(/*is_dpasw=*/false, /*exec_size=*/simd_size_,
4356                 sdepth, rcount, desc.c_type(), desc.a_type(), desc.b_type());
4357 
4358         if (_dpas.as<dpas_t>().matches(desc)) {
4359             do_transpose_ = true;
4360             build_dpas(from_bmnk_mapper, _dpas.as<dpas_t>(), desc);
4361             return true;
4362         }
4363         return false;
4364     }
4365 
build_dpas(const bmnk_block_mapper_t & from_bmnk_mapper,const dpas_t & dpas,const multiply_desc_t & desc)4366     void build_dpas(const bmnk_block_mapper_t &from_bmnk_mapper,
4367             const dpas_t &dpas, const multiply_desc_t &desc) {
4368         int m_blk = dpas.simd_size;
4369         int n_blk = dpas.rcount;
4370         int k_blk = dpas.sdepth * 4 / dpas.src1_type.size();
4371 
4372         c_layout_ = compute_dpas_c_layout(m_blk, n_blk, dpas.c_layout(), desc);
4373 
4374         expr_t a_buf = a_buf_;
4375         expr_t b_buf = b_buf_;
4376         if (do_transpose_) std::swap(a_buf, b_buf);
4377 
4378         for (int i_k = 0; i_k < desc.k(); i_k += k_blk) {
4379             for (int i_m = 0; i_m < desc.m(); i_m += m_blk) {
4380                 for (int i_n = 0; i_n < desc.n(); i_n += n_blk) {
4381                     std::vector<int> a_args = {i_m, 0};
4382                     std::vector<int> b_args = {0, i_n};
4383                     std::vector<int> c_args = {i_m, i_n};
4384                     auto a = a_buf[desc.a_layout()(a_args)
4385                             * desc.a_type().size()];
4386                     auto b = b_buf[desc.b_layout()(b_args)
4387                             * desc.b_type().size()];
4388                     auto c = c_buf_[c_layout_(c_args) * desc.c_type().size()];
4389                     stmt_ = stmt_.append(dpas(c, c, a, b));
4390                 }
4391             }
4392         }
4393 
4394         // Transpose C layout back if needed.
4395         if (do_transpose_) c_layout_ = c_layout_.transpose();
4396 
4397         // Convert C layout back to problem notation.
4398         c_layout_ = from_bmnk_mapper.map_from_bmnk(
4399                 abc_kind_t::c, {bmnk_kind_t::m, bmnk_kind_t::n}, c_layout_);
4400     }
4401 
compute_dpas_c_layout(int m_blk,int n_blk,const layout_t & blk_layout,const multiply_desc_t & desc)4402     static layout_t compute_dpas_c_layout(int m_blk, int n_blk,
4403             const layout_t &blk_layout, const multiply_desc_t &desc) {
4404         auto c_layout = blk_layout;
4405         c_layout = c_layout.add_outer_block(1, desc.n() / n_blk);
4406         c_layout = c_layout.add_outer_block(0, desc.m() / m_blk);
4407         return c_layout;
4408     }
4409 
try_build_mad()4410     bool try_build_mad() {
4411         auto loops = create_loop_nest();
4412 
4413         if (try_build_mad_kmn_block_by_n(loops)) return true;
4414         if (try_build_mad_kmn_block_by_b(loops)) return true;
4415 
4416         return false;
4417     }
4418 
create_loop_nest() const4419     std::vector<loop_info_t> create_loop_nest() const {
4420         object_map_t<expr_t, loop_info_t> loops;
4421         for (auto *view : {&a_view_, &b_view_}) {
4422             abc_kind_t abc_kind
4423                     = (view == &a_view_ ? abc_kind_t::a : abc_kind_t::b);
4424             for (int i = 0; i < view->nvdims(); i++) {
4425                 auto &var = bmnk_mapper_.var(abc_kind, i);
4426                 int dim = int(view->vdims()[i]);
4427                 if (dim == 1) continue;
4428 
4429                 if (loops.count(var) > 0) continue;
4430                 loops[var] = loop_info_t(var, bmnk_mapper_.bmnk_kind(var), dim);
4431             }
4432         }
4433 
4434         std::vector<loop_info_t> ret;
4435         for (auto &kv : loops) {
4436             auto &loop = kv.second;
4437             loop.a_idx = bmnk_mapper_.dim_idx(abc_kind_t::a, loop.var);
4438             loop.b_idx = bmnk_mapper_.dim_idx(abc_kind_t::b, loop.var);
4439             loop.c_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loop.var);
4440             ret.push_back(kv.second);
4441         }
4442         return ret;
4443     }
4444 
4445     // Order of loops: BKMN, block by N.
try_build_mad_kmn_block_by_n(std::vector<loop_info_t> & _loops)4446     bool try_build_mad_kmn_block_by_n(std::vector<loop_info_t> &_loops) {
4447         return try_build_mad_impl(_loops,
4448                 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
4449                         bmnk_kind_t::n},
4450                 bmnk_kind_t::n);
4451     }
4452 
4453     // Order of loops: BKMN, block by B.
try_build_mad_kmn_block_by_b(std::vector<loop_info_t> & _loops)4454     bool try_build_mad_kmn_block_by_b(std::vector<loop_info_t> &_loops) {
4455         return try_build_mad_impl(_loops,
4456                 {bmnk_kind_t::b, bmnk_kind_t::k, bmnk_kind_t::m,
4457                         bmnk_kind_t::n},
4458                 bmnk_kind_t::b);
4459     }
4460 
try_build_mad_impl(std::vector<loop_info_t> & _loops,const std::vector<bmnk_kind_t> & loop_order,bmnk_kind_t block_bmnk_kind)4461     bool try_build_mad_impl(std::vector<loop_info_t> &_loops,
4462             const std::vector<bmnk_kind_t> &loop_order,
4463             bmnk_kind_t block_bmnk_kind) {
4464         auto loops = _loops;
4465         int nloops = int(loops.size());
4466         std::sort(loops.begin(), loops.end(),
4467                 [&](const loop_info_t &a, const loop_info_t &b) {
4468                     int a_key = ir_utils::find_index(loop_order, a.bmnk_kind);
4469                     int b_key = ir_utils::find_index(loop_order, b.bmnk_kind);
4470                     ir_assert(a_key != -1);
4471                     ir_assert(b_key != -1);
4472                     return a_key < b_key;
4473                 });
4474 
4475         int block_idx = -1;
4476         for (int i = 0; i < nloops; i++) {
4477             auto &l = loops[i];
4478             if (l.bmnk_kind == block_bmnk_kind) {
4479                 ir_assert(block_idx == -1) << "Can't block 2+ dimensions.";
4480                 block_idx = i;
4481             }
4482         }
4483 
4484         // Couldn't find N dimension, try different blocking scheme.
4485         if (block_idx == -1) return false;
4486 
4487         auto &block_loop = loops[block_idx];
4488 
4489         int block = simd_size_;
4490         while (block >= 1) {
4491             if (block_loop.dim % block == 0) break;
4492             block /= 2;
4493         }
4494 
4495         ir_assert(block >= 1) << "Invalid block size.";
4496         block_loop.block = block;
4497 
4498         int a_stride = 0;
4499         int b_stride = 0;
4500 
4501         // Ensure that A tile is dense.
4502         if (block_loop.a_idx != -1) {
4503             std::vector<dim_t> tile_dims(a_view_.nvdims(), 1);
4504             tile_dims[block_loop.a_idx] = block;
4505             auto layout = a_view_.create_pseudo_vlayout();
4506             auto tile = layout.map(tensor_t(tile_dims));
4507             if (!is_1d_strided(tile)) return false;
4508             a_stride = tile.blocks()[0].stride;
4509         }
4510 
4511         // Ensure that B tile is dense.
4512         if (block_loop.b_idx != -1) {
4513             std::vector<dim_t> tile_dims(b_view_.nvdims(), 1);
4514             tile_dims[block_loop.b_idx] = block;
4515             auto layout = b_view_.create_pseudo_vlayout();
4516             auto tile = layout.map(tensor_t(tile_dims));
4517             if (!is_1d_strided(tile)) return false;
4518             b_stride = tile.blocks()[0].stride;
4519         }
4520 
4521         build_mad(loops, block_loop, a_stride, b_stride);
4522         return true;
4523     }
4524 
is_1d_strided(const layout_t & layout)4525     static bool is_1d_strided(const layout_t &layout) {
4526         auto &blocks = layout.blocks();
4527         if (blocks.size() > 1) return false;
4528         return true;
4529     }
4530 
build_mad(const std::vector<loop_info_t> & loops,const loop_info_t & block_loop,int a_stride,int b_stride)4531     void build_mad(const std::vector<loop_info_t> &loops,
4532             const loop_info_t &block_loop, int a_stride, int b_stride) {
4533         ir_assert(utils::one_of(
4534                 block_loop.bmnk_kind, bmnk_kind_t::b, bmnk_kind_t::n))
4535                 << "Unsupported blocking (expected blocking by B or N).";
4536 
4537         auto &a_type = a_view_.type();
4538         auto &b_type = b_view_.type();
4539         auto c_type = multiply_desc_t::get_c_type(a_type, b_type,
4540                 /*force_c_upconvert=*/false);
4541 
4542         int block = block_loop.block;
4543         auto _mad = mad_t::make(
4544                 c_type, block, a_type, a_stride, b_type, b_stride);
4545         auto &mad = _mad.as<mad_t>();
4546 
4547         c_layout_ = compute_mad_c_layout(c_type, loops, block_loop);
4548 
4549         int nloops = int(loops.size());
4550         std::vector<int> bounds(loops.size());
4551         for (int i = 0; i < nloops; i++) {
4552             bounds[i] = loops[i].dim / loops[i].block;
4553         }
4554         std::vector<int> a_idx(a_view_.nvdims());
4555         std::vector<int> b_idx(b_view_.nvdims());
4556         std::vector<int> c_idx(c_layout_.ndims());
4557         ir_utils::for_each(bounds, [&](const std::vector<int> &idx) {
4558             for (int i = 0; i < nloops; i++) {
4559                 int full_idx = idx[i] * loops[i].block;
4560                 auto &loop = loops[i];
4561                 if (loop.a_idx != -1) a_idx[loop.a_idx] = full_idx;
4562                 if (loop.b_idx != -1) b_idx[loop.b_idx] = full_idx;
4563                 if (loop.c_idx != -1) c_idx[loop.c_idx] = full_idx;
4564             }
4565             int a_off = a_view_(a_idx) * a_type.size();
4566             int b_off = b_view_(b_idx) * b_type.size();
4567             int c_off = c_layout_(c_idx) * c_type.size();
4568             stmt_ = stmt_.append(mad(c_buf_[c_off], c_buf_[c_off],
4569                     a_buf_[a_off], b_buf_[b_off]));
4570         });
4571     }
4572 
compute_mad_c_layout(const type_t & c_type,const std::vector<loop_info_t> & loops,const loop_info_t & block_loop) const4573     layout_t compute_mad_c_layout(const type_t &c_type,
4574             const std::vector<loop_info_t> &loops,
4575             const loop_info_t &block_loop) const {
4576         layout_t c_layout(c_type, bmnk_mapper_.ndims(abc_kind_t::c), 0, {});
4577 
4578         int c_dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, block_loop.var);
4579         c_layout = c_layout.add_outer_block(c_dim_idx, block_loop.block);
4580 
4581         for (size_t i = 0; i < loops.size(); i++) {
4582             if (loops[i].bmnk_kind == bmnk_kind_t::k) continue;
4583             int dim_idx = bmnk_mapper_.dim_idx(abc_kind_t::c, loops[i].var);
4584             int bound = loops[i].dim / loops[i].block;
4585             c_layout = c_layout.add_outer_block(dim_idx, bound);
4586         }
4587         return c_layout;
4588     }
4589 
4590     ngen::HW hw_;
4591     int simd_size_;
4592     bmnk_mapper_t bmnk_mapper_;
4593 
4594     bool do_transpose_ = false;
4595 
4596     view_t a_view_;
4597     view_t b_view_;
4598     layout_t c_layout_;
4599 
4600     expr_t a_buf_;
4601     expr_t b_buf_;
4602     expr_t c_buf_;
4603 
4604     stmt_t stmt_;
4605 };
4606 
get_fma_friendly_layout(abc_kind_t abc_kind,int simd_size,const layout_t & bmnk_layout,const type_t & a_type,const type_t & b_type)4607 layout_t get_fma_friendly_layout(abc_kind_t abc_kind, int simd_size,
4608         const layout_t &bmnk_layout, const type_t &a_type,
4609         const type_t &b_type) {
4610     bool is_a = (abc_kind == abc_kind_t::a);
4611     int mn_idx = (is_a ? 0 : 1);
4612     int k_idx = (is_a ? 1 : 0);
4613 
4614     dim_t mn_blk = bmnk_layout.dim(mn_idx);
4615     dim_t k_blk = bmnk_layout.dim(k_idx);
4616 
4617     // Cannot calculate correct r_count when !is_a, but rcount is effectively
4618     // ignored in that case as rcount mainly effects b_layout.
4619     int rcount = is_a && mn_blk < 8 ? utils::rnd_up_pow2(mn_blk) : 8;
4620     auto _dpas = dpas_t::make(/*is_dpasw=*/false, simd_size, /*sdepth=*/8,
4621             rcount, type_t::undef(), b_type, a_type);
4622     auto &dpas = _dpas.as<dpas_t>();
4623 
4624     auto dpas_layout = (is_a ? dpas.b_layout() : dpas.a_layout());
4625     dpas_layout = dpas_layout.transpose();
4626 
4627     ir_assert(dpas_layout.dim(k_idx) == k_blk);
4628     MAYBE_UNUSED(k_blk);
4629 
4630     dim_t dpas_mn_blk = dpas_layout.dim(mn_idx);
4631     dpas_layout = dpas_layout.add_outer_block(mn_idx, mn_blk / dpas_mn_blk);
4632 
4633     return dpas_layout;
4634 }
4635 
convert_to_fma_friendly_type(const conv_config_t & cfg,abc_kind_t abc_kind,const layout_t & layout,const type_t & a_type,const type_t & b_type,bool * changed=nullptr)4636 layout_t convert_to_fma_friendly_type(const conv_config_t &cfg,
4637         abc_kind_t abc_kind, const layout_t &layout, const type_t &a_type,
4638         const type_t &b_type, bool *changed = nullptr) {
4639     if (changed) *changed = false;
4640     if (cfg.fma_kind != fma_kind_t::mad) return layout;
4641 
4642     if (a_type.is_x8() && b_type.is_x8()) {
4643         if (changed) *changed = true;
4644         return layout.retype(type_t::s16()).make_strided(2);
4645     }
4646     // f16/bf16 mixed mode mad requires src2 to be f32
4647     if (a_type.is_bf16() || b_type.is_bf16()
4648             || (a_type.is_f32() && b_type.is_f16())
4649             || (a_type.is_f16() && b_type.is_f32())) {
4650         if (changed) *changed = true;
4651         return layout.retype(type_t::f32());
4652     }
4653     return layout;
4654 }
4655 
convert_to_fma_friendly_layout(const conv_config_t & cfg,abc_kind_t abc_kind,const bmnk_mapper_t & bmnk_mapper,const layout_t & layout,const type_t & a_type,const type_t & b_type,bool * changed=nullptr)4656 layout_t convert_to_fma_friendly_layout(const conv_config_t &cfg,
4657         abc_kind_t abc_kind, const bmnk_mapper_t &bmnk_mapper,
4658         const layout_t &layout, const type_t &a_type, const type_t &b_type,
4659         bool *changed = nullptr) {
4660     if (changed) *changed = false;
4661     if (!cfg.allow_grf_reorder) return layout;
4662 
4663     // GRF reorder is only supported with dpas/dpasw.
4664     if (!utils::one_of(cfg.fma_kind, fma_kind_t::dpas, fma_kind_t::dpasw)) {
4665         // mad may require type conversion.
4666         return convert_to_fma_friendly_type(
4667                 cfg, abc_kind, layout, a_type, b_type, changed);
4668     }
4669 
4670     std::vector<bmnk_kind_t> bmnk_kinds;
4671     if (abc_kind == abc_kind_t::a) {
4672         bmnk_kinds.push_back(bmnk_kind_t::m);
4673         bmnk_kinds.push_back(bmnk_kind_t::k);
4674     } else {
4675         bmnk_kinds.push_back(bmnk_kind_t::k);
4676         bmnk_kinds.push_back(bmnk_kind_t::n);
4677     }
4678 
4679     auto bmnk_layout = bmnk_mapper.map_to_bmnk(abc_kind, bmnk_kinds, layout);
4680 
4681     auto dpas_layout = get_fma_friendly_layout(
4682             abc_kind, cfg.simd_size, bmnk_layout, a_type, b_type);
4683     if (dpas_layout == bmnk_layout) return layout;
4684 
4685     if (changed) *changed = true;
4686 
4687     bmnk_block_mapper_t from_bmnk_mapper(bmnk_mapper);
4688     from_bmnk_mapper.push_blocks(abc_kind, layout.blocks());
4689 
4690     auto fma_layout
4691             = from_bmnk_mapper.map_from_bmnk(abc_kind, bmnk_kinds, dpas_layout);
4692     fma_layout = fma_layout.make_dense();
4693     return fma_layout;
4694 }
4695 
4696 class b_reduce_context_t {
4697 public:
b_reduce_context_t(const conv_config_t & cfg)4698     b_reduce_context_t(const conv_config_t &cfg)
4699         : cfg_(cfg), reduce_condition_(true) {
4700         if (cfg.do_b_reduction) b_reduced_reg_buf_ = make_buffer("b_reduced");
4701     }
4702 
4703     // Setters for B reduced memory buffer/view.
set_b_reduced_mem_buf(const expr_t & buf)4704     void set_b_reduced_mem_buf(const expr_t &buf) { b_reduced_mem_buf_ = buf; }
set_b_reduced_view(const view_t & v)4705     void set_b_reduced_view(const view_t &v) { b_reduced_view_ = v; }
4706 
4707     // Sets the condition to update B reduced output. Reduction is done across
4708     // K for B (KxN tensor) so M dimension should be checked before the update.
set_reduce_condition(const expr_t & cond)4709     void set_reduce_condition(const expr_t &cond) { reduce_condition_ = cond; }
4710 
4711     // Global memory buffer.
b_reduced_mem_buf() const4712     const expr_t &b_reduced_mem_buf() const { return b_reduced_mem_buf_; }
4713 
4714     // Register buffer.
b_reduced_reg_buf() const4715     const expr_t &b_reduced_reg_buf() const { return b_reduced_reg_buf_; }
b_reduced_size() const4716     int b_reduced_size() const { return b_reduced_size_; }
4717 
4718     // Memory view.
b_reduced_thr_view() const4719     const view_t &b_reduced_thr_view() const { return b_reduced_thr_view_; }
4720 
4721     // Register layout.
b_reduced_reg_layout() const4722     const layout_t &b_reduced_reg_layout() const {
4723         return b_reduced_reg_layout_;
4724     }
4725 
init_reduced_thr_view(const tensor_t & b_thr_tile,const expr_t & cond=expr_t ())4726     void init_reduced_thr_view(
4727             const tensor_t &b_thr_tile, const expr_t &cond = expr_t()) {
4728         ir_assert(b_reduced_thr_view_.is_empty()) << "Can't initialize twice.";
4729 
4730         auto b_reduced_thr_tile = b_to_b_reduced_tile(b_thr_tile);
4731         b_reduced_thr_view_
4732                 = b_reduced_view_.create_sub_view(b_reduced_thr_tile);
4733         b_reduced_reg_layout_ = b_reduced_thr_view_.create_dense_vlayout();
4734         b_reduced_size_ = b_reduced_reg_layout_.size();
4735         b_reduced_size_ = utils::rnd_up(b_reduced_size_, cfg_.grf_size());
4736 
4737         if (!cond.is_empty()) reduce_condition_ &= cond;
4738     }
4739 
create_reduce_stmt(const layout_t & b_layout,const expr_t & b_buf,const tensor_t & sub_tile=tensor_t ())4740     stmt_t create_reduce_stmt(const layout_t &b_layout, const expr_t &b_buf,
4741             const tensor_t &sub_tile = tensor_t()) {
4742         auto reduction_stmt
4743                 = jit::create_reduce_stmt(b_layout, b_reduced_reg_layout_,
4744                         b_buf, b_reduced_reg_buf_, sub_tile, (1 << 1));
4745         return reduction_stmt;
4746     }
4747 
create_store_stmt(ir_context_t & ir_ctx,const constraint_set_t & cset) const4748     stmt_t create_store_stmt(
4749             ir_context_t &ir_ctx, const constraint_set_t &cset) const {
4750         write_builder_t r2g(cfg_.hw, ir_ctx, cset, b_reduced_thr_view_,
4751                 b_reduced_mem_buf_, b_reduced_reg_buf_, /*is_slm=*/false,
4752                 ngen_proxy::AtomicOp::fadd);
4753         // TODO: Check that layouts match.
4754         auto ret = r2g.stmt();
4755         if (!reduce_condition_.is_empty()) {
4756             ret = if_t::make(reduce_condition_, ret);
4757         }
4758         return ret;
4759     }
4760 
4761 private:
b_to_b_reduced_tile(const tensor_t & b_tile) const4762     tensor_t b_to_b_reduced_tile(const tensor_t &b_tile) const {
4763         std::vector<dim_t> dims;
4764         std::vector<expr_t> start;
4765         for (int i = 0; i < b_tile.ndims(); i++) {
4766             if ((reduction_mask_ & (1 << i)) != 0) {
4767                 dims.push_back(b_tile(i));
4768                 start.push_back(b_tile.start(i));
4769             }
4770         }
4771         return tensor_t(dims, start);
4772     }
4773 
4774     const conv_config_t &cfg_;
4775 
4776     expr_t reduce_condition_;
4777 
4778     expr_t b_reduced_mem_buf_;
4779     expr_t b_reduced_reg_buf_;
4780 
4781     view_t b_reduced_view_;
4782     view_t b_reduced_thr_view_;
4783 
4784     layout_t b_reduced_reg_layout_;
4785     int b_reduced_size_ = 0;
4786 
4787     uint32_t reduction_mask_ = (1 << 1);
4788 };
4789 
4790 class load_multiply_builder_t {
4791 public:
load_multiply_builder_t(const conv_config_t & cfg,ir_context_t & ir_ctx,const constraint_set_t & cset,const gemm_schedule_t & gemm_schedule,b_reduce_context_t & b_reduce_ctx,const expr_t & ap_buf,const expr_t & a_slm_buf,const expr_t & bp_buf,const expr_t & b_slm_buf,const view_t & ap_x_view,const view_t & bp_x_view)4792     load_multiply_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
4793             const constraint_set_t &cset, const gemm_schedule_t &gemm_schedule,
4794             b_reduce_context_t &b_reduce_ctx, const expr_t &ap_buf,
4795             const expr_t &a_slm_buf, const expr_t &bp_buf,
4796             const expr_t &b_slm_buf, const view_t &ap_x_view,
4797             const view_t &bp_x_view)
4798         : cfg_(cfg)
4799         , ir_ctx_(ir_ctx)
4800         , cset_(cset)
4801         , gemm_schedule_(gemm_schedule)
4802         , b_reduce_ctx_(b_reduce_ctx)
4803         , ap_buf_(ap_buf)
4804         , a_slm_buf_(a_slm_buf)
4805         , bp_buf_(bp_buf)
4806         , b_slm_buf_(b_slm_buf) {
4807         ir_assert(cfg_.a_sub_tiles == 1 || cfg_.b_sub_tiles == 1)
4808                 << "At most one tensor can be tiled.";
4809 
4810         ab_tmp_buf_ = make_buffer("ab_tmp");
4811         a_buf_ = make_buffer("a");
4812         b_buf_ = make_buffer("b");
4813         c_buf_ = make_buffer("c");
4814 
4815         // Views to multiply by a thread.
4816         a_thr_view_ = ap_x_view.create_sub_view(gemm_schedule_.a_thr_tile());
4817         b_thr_view_ = bp_x_view.create_sub_view(gemm_schedule_.b_thr_tile());
4818 
4819         // Initialize view for reduced B.
4820         if (cfg_.do_b_reduction && !cfg_.use_b_slm) {
4821             b_reduce_ctx_.init_reduced_thr_view(
4822                     gemm_schedule_.b_thr_tile(/*is_relative=*/false));
4823         }
4824 
4825         // TODO: Specify loops over sub-tiles in the schedule, use unrolling.
4826         // Sub-tile indices.
4827         a_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "a_idx");
4828         b_idx_ = ir_ctx_.create_tmp_var(type_t::s32(), "b_idx");
4829 
4830         // Sub-tile views.
4831         a_i_view_ = create_sub_tile_view(abc_kind_t::a, a_thr_view_,
4832                 cfg_.a_sub_tiles, a_idx_, bmnk_kind_t::m, &a_i_outer_blocks_,
4833                 a_i_tile_);
4834         b_j_view_ = create_sub_tile_view(abc_kind_t::b, b_thr_view_,
4835                 cfg_.b_sub_tiles, b_idx_, bmnk_kind_t::n, &b_j_outer_blocks_,
4836                 b_j_tile_);
4837 
4838         build();
4839     }
4840 
allocs() const4841     const std::vector<stmt_t> &allocs() const { return allocs_; }
4842 
load_mul_stmt() const4843     const stmt_t &load_mul_stmt() const { return load_mul_stmt_; }
4844 
c_buf() const4845     const expr_t &c_buf() const { return c_buf_; }
4846 
c_reg_layout() const4847     const layout_t &c_reg_layout() const { return c_reg_layout_; }
4848 
c_attr() const4849     const alloc_attr_t &c_attr() const { return c_attr_; }
4850 
4851 private:
4852     struct sub_tile_info_t {
4853         bool is_loaded = false;
4854         view_t reg_view;
4855         int reg_buf_size;
4856     };
4857 
create_sub_tile_view(abc_kind_t abc_kind,const view_t & thr_view,int sub_tiles,const expr_t & idx,bmnk_kind_t bmnk_kind,std::vector<block_t> * outer_blocks,tensor_t & sub_tile) const4858     view_t create_sub_tile_view(abc_kind_t abc_kind, const view_t &thr_view,
4859             int sub_tiles, const expr_t &idx, bmnk_kind_t bmnk_kind,
4860             std::vector<block_t> *outer_blocks, tensor_t &sub_tile) const {
4861         auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4862         auto layout = thr_view.create_pseudo_vlayout();
4863         dim_t mn_dim = 1;
4864         for (auto &b : layout.blocks()) {
4865             auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
4866             if (b_bmnk_kind == bmnk_kind) mn_dim *= b.block;
4867         }
4868 
4869         std::vector<dim_t> sub_tile_dims(thr_view.nvdims(), 1);
4870         dim_t mn_sub_tile_dim = ir_utils::safe_divide(mn_dim, dim_t(sub_tiles));
4871         for (auto &b : layout.blocks()) {
4872             auto b_bmnk_kind = bmnk_mapper.bmnk_kind(abc_kind, b.dim_idx);
4873             if (b_bmnk_kind == bmnk_kind) {
4874                 if (mn_sub_tile_dim == 1) continue;
4875                 dim_t next_block;
4876                 if (mn_sub_tile_dim % b.block == 0) {
4877                     next_block = b.block;
4878                 } else {
4879                     next_block
4880                             = ir_utils::safe_divide(b.block, mn_sub_tile_dim);
4881                 }
4882                 sub_tile_dims[b.dim_idx] *= next_block;
4883                 mn_sub_tile_dim /= next_block;
4884             } else {
4885                 sub_tile_dims[b.dim_idx] *= b.block;
4886             }
4887         }
4888         grid_info_t grid({sub_tiles}, {idx});
4889         sub_tile = layout.split(tensor_t(sub_tile_dims), grid, outer_blocks);
4890         return thr_view.create_sub_view(sub_tile);
4891     }
4892 
a_type() const4893     const type_t &a_type() const { return a_i_view_.type(); }
b_type() const4894     const type_t &b_type() const { return b_j_view_.type(); }
4895 
build()4896     void build() {
4897         a_sub_tiles_.resize(cfg_.a_sub_tiles);
4898         b_sub_tiles_.resize(cfg_.b_sub_tiles);
4899         for (int i = 0; i < cfg_.a_sub_tiles; i++) {
4900             for (int j = 0; j < cfg_.b_sub_tiles; j++) {
4901                 build_sub_tile(i, j);
4902             }
4903         }
4904 
4905         if (tmp_buf_size_ > 0) {
4906             register_buffer(ab_tmp_buf_, tmp_buf_size_, alloc_kind_t::grf);
4907         }
4908 
4909         // C layout in problem notation.
4910         auto c_layout = c_sub_tile_layout_;
4911 
4912         // Add outer blocks coming from A/B sub-tiles.
4913         auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4914         for (auto &b : a_i_outer_blocks_) {
4915             auto &var = bmnk_mapper.var(abc_kind_t::a, b.dim_idx);
4916             int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
4917             c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
4918         }
4919         for (auto &b : b_j_outer_blocks_) {
4920             auto &var = bmnk_mapper.var(abc_kind_t::b, b.dim_idx);
4921             int c_dim_idx = bmnk_mapper.dim_idx(abc_kind_t::c, var);
4922             c_layout = c_layout.add_outer_block(c_dim_idx, b.block);
4923         }
4924 
4925         c_reg_layout_ = c_layout;
4926     }
4927 
build_sub_tile(int i,int j)4928     void build_sub_tile(int i, int j) {
4929         bool is_first = (i == 0 && j == 0);
4930 
4931         stmt_t ab_s2r_load;
4932         stmt_t ab_g2r_load;
4933         load_sub_tile(abc_kind_t::a, i, ab_s2r_load, ab_g2r_load);
4934         load_sub_tile(abc_kind_t::b, j, ab_s2r_load, ab_g2r_load);
4935 
4936         load_mul_stmt_ = load_mul_stmt_.append(
4937                 stmt_group_t::make(stmt_label_t::g2r_load(i + j), ab_g2r_load));
4938         load_mul_stmt_ = load_mul_stmt_.append(
4939                 stmt_group_t::make(stmt_label_t::s2r_load(i + j), ab_s2r_load));
4940 
4941         auto &a_i_view = a_sub_tiles_[i].reg_view;
4942         auto &b_j_view = b_sub_tiles_[j].reg_view;
4943 
4944         // Multiply C_i_j += A_i x B_j in GEMM notation.
4945         multiply_builder_t mul_builder(cfg_, gemm_schedule_.bmnk_mapper(),
4946                 a_i_view, b_j_view, a_buf_, b_buf_, c_buf_[c_buf_off_]);
4947         c_sub_tile_layout_ = mul_builder.c_layout();
4948         c_buf_off_ += c_sub_tile_layout_.size();
4949         ir_trace() << "Multiply (" << i << ", " << j << "):\n"
4950                    << mul_builder.str() << std::endl;
4951 
4952         load_mul_stmt_ = load_mul_stmt_.append(stmt_group_t::make(
4953                 stmt_label_t::mul(i + j), mul_builder.stmt()));
4954 
4955         if (!is_first) {
4956             ir_assert(mul_builder.c_layout() == c_sub_tile_layout_)
4957                     << "Sub-tile layouts must be equal.";
4958             return;
4959         }
4960 
4961         c_attr_ = grf_alloc_attr_t::make(mul_builder.c_grf_bundle());
4962 
4963         auto a_attr = grf_alloc_attr_t::make(mul_builder.a_grf_bundle());
4964         register_buffer(a_buf_, a_sub_tiles_[i].reg_buf_size, alloc_kind_t::grf,
4965                 a_attr);
4966 
4967         auto b_attr = grf_alloc_attr_t::make(mul_builder.b_grf_bundle());
4968         register_buffer(b_buf_, b_sub_tiles_[j].reg_buf_size, alloc_kind_t::grf,
4969                 b_attr);
4970     }
4971 
4972     // Loads A_i or B_j sub-tile.
load_sub_tile(abc_kind_t abc_kind,int i,stmt_t & ab_s2r_load,stmt_t & ab_g2r_load)4973     void load_sub_tile(abc_kind_t abc_kind, int i, stmt_t &ab_s2r_load,
4974             stmt_t &ab_g2r_load) {
4975         bool is_a = (abc_kind == abc_kind_t::a);
4976         auto &info = (is_a ? a_sub_tiles_[i] : b_sub_tiles_[i]);
4977         if (info.is_loaded) return;
4978 
4979         auto &bmnk_mapper = gemm_schedule_.bmnk_mapper();
4980 
4981         auto &x_view = (is_a ? a_i_view_ : b_j_view_);
4982         auto &x_tile = (is_a ? a_i_tile_ : b_j_tile_);
4983         auto &x_idx = (is_a ? a_idx_ : b_idx_);
4984 
4985         auto view = x_view.substitute(x_idx, i);
4986         auto tile = x_tile.substitute(x_idx, i);
4987 
4988         bool use_x_slm = (is_a ? cfg_.use_a_slm : cfg_.use_b_slm);
4989         auto &x_slm_buf = (is_a ? a_slm_buf_ : b_slm_buf_);
4990         auto &x_gmem_buf = (is_a ? ap_buf_ : bp_buf_);
4991         auto &x_buf = (use_x_slm ? x_slm_buf : x_gmem_buf);
4992         auto &x_reg_buf = (is_a ? a_buf_ : b_buf_);
4993 
4994         layout_t load_layout;
4995         view_t reg_view;
4996         stmt_t stmt;
4997         load_sub_tile_impl(abc_kind, i, view, x_buf, x_reg_buf, use_x_slm,
4998                 load_layout, reg_view, stmt);
4999 
5000         auto reg_layout = load_layout;
5001 
5002         if (!is_a && cfg_.do_b_reduction && !cfg_.use_b_slm) {
5003             auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt(
5004                     reg_layout, b_buf_, tile);
5005             stmt = stmt.append(reduce_stmt);
5006         }
5007 
5008         bool changed;
5009         auto fma_layout = convert_to_fma_friendly_layout(cfg_, abc_kind,
5010                 bmnk_mapper, reg_layout, a_type(), b_type(), &changed);
5011 
5012         if (changed) {
5013             if (fma_layout.type() != reg_layout.type()) {
5014                 reg_view = reg_view.retype(fma_layout.type());
5015             }
5016             reg_layout = fma_layout;
5017             reg_view.set_tlayout(reg_layout);
5018             stmt = substitute(stmt, x_reg_buf, ab_tmp_buf_);
5019             stmt = stmt.append(create_reorder_stmt(
5020                     load_layout, reg_layout, ab_tmp_buf_, x_reg_buf));
5021             tmp_buf_size_ = std::max(tmp_buf_size_, int(load_layout.size()));
5022         }
5023 
5024         if (use_x_slm) {
5025             ab_s2r_load = ab_s2r_load.append(stmt);
5026         } else {
5027             ab_g2r_load = ab_g2r_load.append(stmt);
5028         }
5029         info.is_loaded = true;
5030         info.reg_view = reg_view;
5031         info.reg_buf_size = reg_layout.size();
5032     }
5033 
load_sub_tile_impl(abc_kind_t abc_kind,int sub_tile_idx,const view_t & _mem_view,const expr_t & buf,const expr_t & reg_buf,bool is_slm,layout_t & reg_layout,view_t & reg_view,stmt_t & stmt)5034     void load_sub_tile_impl(abc_kind_t abc_kind, int sub_tile_idx,
5035             const view_t &_mem_view, const expr_t &buf, const expr_t &reg_buf,
5036             bool is_slm, layout_t &reg_layout, view_t &reg_view, stmt_t &stmt) {
5037         bool is_a = (abc_kind == abc_kind_t::a);
5038 
5039         view_t mem_view;
5040         bool load_buffered = false;
5041 
5042         // Using buffered view is enabled only when:
5043         // - Loading directly from global memory
5044         // - FMA kind is mad (dpas implementation is more strict and requires
5045         //   layouts, not views)
5046         // - Loading A tensor (A - activations for FWD/BWD_D where we may have
5047         //   overlapping when applying KW blocking )
5048         if (!is_slm && is_a && cfg_.fma_kind == fma_kind_t::mad) {
5049             load_buffered
5050                     = _mem_view.try_create_buffer_view(mem_view, reg_view);
5051         }
5052 
5053         if (!load_buffered) mem_view = _mem_view;
5054 
5055         read_builder_t read(
5056                 cfg_.hw, ir_ctx_, cset_, mem_view, buf, reg_buf, is_slm);
5057         ir_trace() << (is_a ? "A" : "B") << " GMEM/SLM to GRF load #"
5058                    << sub_tile_idx << ":\n"
5059                    << read.str() << std::endl;
5060 
5061         if (load_buffered) {
5062             reg_view.set_tlayout(read.reg_layout());
5063         } else {
5064             reg_view = view_t(read.reg_layout());
5065         }
5066 
5067         reg_layout = read.reg_layout();
5068         stmt = read.stmt();
5069     }
5070 
register_buffer(const stmt_t & alloc)5071     void register_buffer(const stmt_t &alloc) {
5072         ir_assert(alloc.is<alloc_t>());
5073         allocs_.push_back(alloc);
5074     }
5075 
register_buffer(const expr_t & buf,int size,alloc_kind_t kind,const alloc_attr_t & attr={})5076     void register_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5077             const alloc_attr_t &attr = {}) {
5078         register_buffer(alloc_t::make(buf, size, kind, attr));
5079     }
5080 
5081     const conv_config_t &cfg_;
5082     ir_context_t ir_ctx_;
5083     const constraint_set_t &cset_;
5084     const gemm_schedule_t &gemm_schedule_;
5085     b_reduce_context_t &b_reduce_ctx_;
5086 
5087     expr_t ap_buf_;
5088     expr_t a_slm_buf_;
5089 
5090     expr_t bp_buf_;
5091     expr_t b_slm_buf_;
5092 
5093     layout_t c_reg_layout_;
5094 
5095     expr_t ab_tmp_buf_;
5096     expr_t a_buf_;
5097     expr_t b_buf_;
5098     expr_t c_buf_;
5099 
5100     int tmp_buf_size_ = 0;
5101 
5102     // Per-thread views to multiply.
5103     view_t a_thr_view_;
5104     view_t b_thr_view_;
5105 
5106     // Sub-tile indices.
5107     expr_t a_idx_;
5108     expr_t b_idx_;
5109 
5110     // Sub-tile views.
5111     view_t a_i_view_;
5112     view_t b_j_view_;
5113 
5114     tensor_t a_i_tile_;
5115     tensor_t b_j_tile_;
5116 
5117     std::vector<sub_tile_info_t> a_sub_tiles_;
5118     std::vector<sub_tile_info_t> b_sub_tiles_;
5119 
5120     std::vector<block_t> a_i_outer_blocks_;
5121     std::vector<block_t> b_j_outer_blocks_;
5122 
5123     std::vector<stmt_t> allocs_;
5124 
5125     stmt_t load_mul_stmt_;
5126 
5127     int c_buf_off_ = 0;
5128     layout_t c_sub_tile_layout_;
5129     alloc_attr_t c_attr_;
5130 };
5131 
5132 class slm_reduce_builder_t {
5133 public:
slm_reduce_builder_t(const conv_config_t & cfg,ir_context_t & ir_ctx,const constraint_set_t & cset,const grid_info_t & tg_grid,const expr_t & reg_buf,const layout_t & reg_layout,const tensor_t & thr_tile)5134     slm_reduce_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
5135             const constraint_set_t &cset, const grid_info_t &tg_grid,
5136             const expr_t &reg_buf, const layout_t &reg_layout,
5137             const tensor_t &thr_tile)
5138         : cfg_(cfg)
5139         , ir_ctx_(ir_ctx)
5140         , cset_(cset)
5141         , tg_grid_(tg_grid)
5142         , reg_buf_(reg_buf)
5143         , reg_layout_(reg_layout)
5144         , thr_tile_(thr_tile) {
5145         ir_assert(tg_grid_.dim(2) > 1);
5146 
5147         tmp_reg_buf_ = ir_ctx_.create_tmp_var(type_t::byte_ptr());
5148         slm_buf_ = make_buffer("reduce_slm");
5149 
5150         build();
5151     }
5152 
reg_layout() const5153     layout_t reg_layout() const { return reg_layout_; }
5154 
thr_tile() const5155     tensor_t thr_tile() const { return thr_tile_; }
5156 
stmt() const5157     stmt_t stmt() const { return stmt_; }
5158 
5159 private:
build()5160     void build() {
5161         int ndims = reg_layout_.ndims();
5162         int tg_ndims = tg_grid_.ndims();
5163 
5164         // Create SLM layout to store all intermediate buffers from the thread
5165         // group.
5166         layout_t slm_layout(reg_layout_.type(), ndims + tg_ndims,
5167                 reg_layout_.offset(), reg_layout_.blocks());
5168         for (int i = tg_ndims - 1; i >= 0; i--) {
5169             slm_layout = slm_layout.add_outer_block(ndims + i, tg_grid_.dim(i));
5170         }
5171 
5172         slm_buf_size_ = slm_layout.size();
5173 
5174         // Write thread tile to SLM.
5175         std::vector<dim_t> write_dims = reg_layout_.dims();
5176         std::vector<expr_t> write_start(ndims + tg_ndims, 0);
5177         write_dims.resize(ndims + tg_ndims, 1);
5178         for (int i = tg_ndims - 1; i >= 0; i--) {
5179             write_start[ndims + i] = tg_grid_.idx(i);
5180         }
5181         auto write_tile = tensor_t(write_dims, write_start);
5182         write_builder_t write(cfg_.hw, ir_ctx_, cset_,
5183                 view_t(slm_layout.map(write_tile)), slm_buf_, reg_buf_,
5184                 /*is_slm=*/true);
5185         stmt_ = stmt_.append(funcs::barrier());
5186         stmt_ = stmt_.append(write.stmt());
5187         stmt_ = stmt_.append(funcs::barrier());
5188 
5189         auto &write_layout = write.reg_layout();
5190         ir_assert(write_layout == reg_layout_) << "Incompatible layouts.";
5191 
5192         // Redistribute the layout to read/reduce all k-axis tiles from every
5193         // thread.
5194         auto local_thr_tile = reg_layout_.split(tg_grid_.sub_grid({2}));
5195         reg_layout_ = reg_layout_.map(tensor_t(local_thr_tile.dims()));
5196 
5197         stmt_t reduce_stmt;
5198         std::vector<dim_t> read_dims(ndims + tg_ndims, 1);
5199         std::vector<expr_t> read_start(ndims + tg_ndims);
5200         for (int i = 0; i < ndims; i++) {
5201             read_dims[i] = local_thr_tile(i);
5202             read_start[i] = local_thr_tile.start(i);
5203         }
5204         read_start[ndims + 0] = tg_grid_.idx(0);
5205         read_start[ndims + 1] = tg_grid_.idx(1);
5206         for (int i = 0; i < tg_grid_.dim(2); i++) {
5207             read_start[ndims + 2] = i;
5208             tensor_t read_tile(read_dims, read_start);
5209             read_builder_t read(cfg_.hw, ir_ctx_, cset_,
5210                     view_t(slm_layout.map(read_tile)), slm_buf_, tmp_reg_buf_,
5211                     /*is_slm=*/true);
5212             reduce_stmt = reduce_stmt.append(read.stmt());
5213 
5214             tmp_reg_buf_size_
5215                     = std::max(tmp_reg_buf_size_, read.reg_buf_size());
5216             auto read_layout = read.reg_layout();
5217             for (int j = 0; j < 2; j++)
5218                 ir_assert(read_layout.dim(ndims + j) == 1);
5219             read_layout = layout_t(read_layout.type(), ndims + 1,
5220                     read_layout.offset(), read_layout.blocks());
5221             reduce_stmt = reduce_stmt.append(
5222                     create_reduce_stmt(read_layout, reg_layout_, tmp_reg_buf_,
5223                             reg_buf_, tensor_t(), reduction_mask()));
5224         }
5225 
5226         stmt_ = stmt_.append(
5227                 create_zero_out_stmt(cfg_.hw, reg_buf_, reg_layout_.size()));
5228         stmt_ = stmt_.append(reduce_stmt);
5229 
5230         stmt_ = alloc_t::make(
5231                 slm_buf_, slm_buf_size_, alloc_kind_t::slm, {}, stmt_);
5232         stmt_ = alloc_t::make(
5233                 tmp_reg_buf_, tmp_reg_buf_size_, alloc_kind_t::grf, {}, stmt_);
5234 
5235         thr_tile_ = thr_tile_.create_sub_tensor(local_thr_tile);
5236     }
5237 
reduction_mask() const5238     uint32_t reduction_mask() const {
5239         int k_dim_idx = reg_layout_.ndims();
5240         uint32_t mask = 0xFFFFFFFF;
5241         mask &= ~(1 << k_dim_idx);
5242         return mask;
5243     }
5244 
5245     const conv_config_t &cfg_;
5246     ir_context_t &ir_ctx_;
5247     const constraint_set_t &cset_;
5248     grid_info_t tg_grid_;
5249 
5250     expr_t reg_buf_;
5251     layout_t reg_layout_;
5252     tensor_t thr_tile_;
5253 
5254     expr_t tmp_reg_buf_;
5255     int tmp_reg_buf_size_ = 0;
5256 
5257     expr_t slm_buf_;
5258     int slm_buf_size_ = 0;
5259 
5260     stmt_t stmt_;
5261 };
5262 
5263 class compute_builder_t {
5264 public:
compute_builder_t(const conv_config_t & cfg,ir_context_t & ir_ctx,constraint_set_t & cset)5265     compute_builder_t(const conv_config_t &cfg, ir_context_t &ir_ctx,
5266             constraint_set_t &cset)
5267         : cfg_(cfg)
5268         , ir_ctx_(ir_ctx)
5269         , cset_(cset)
5270         , b_reduce_ctx_(cfg)
5271         , g2s_ctx_(ir_ctx) {}
5272 
ab_slm_size() const5273     int ab_slm_size() const { return ab_slm_size_; }
5274 
c_zero_out_stmt() const5275     const stmt_t &c_zero_out_stmt() const { return c_zero_out_stmt_; }
b_reduced_zero_out_stmt() const5276     const stmt_t &b_reduced_zero_out_stmt() const {
5277         return b_reduced_zero_out_stmt_;
5278     }
5279 
zero_out_stmt() const5280     stmt_t zero_out_stmt() const {
5281         stmt_t ret;
5282         ret = ret.append(c_zero_out_stmt());
5283         ret = ret.append(b_reduced_zero_out_stmt());
5284         return ret;
5285     }
5286 
iter_stmt() const5287     stmt_t iter_stmt() const {
5288         stmt_t stmt;
5289         bool use_prefetch = !prefetch_stmt_.is_empty();
5290         bool use_slm = !g2s_load_stmt_.is_empty();
5291         if (use_prefetch) {
5292             stmt = stmt.append(stmt_group_t::make(
5293                     stmt_label_t::prefetch(), prefetch_stmt_));
5294         } else if (use_slm) {
5295             stmt = stmt.append(stmt_group_t::make(
5296                     stmt_label_t::g2s_load(), g2s_load_stmt_));
5297             stmt = stmt.append(funcs::barrier());
5298             stmt = stmt.append(stmt_group_t::make(
5299                     stmt_label_t::g2s_store(), g2s_store_stmt_));
5300             stmt = stmt.append(funcs::barrier());
5301         }
5302         stmt = stmt.append(load_mul_stmt_);
5303         return stmt;
5304     }
5305 
c_store_stmt() const5306     const stmt_t &c_store_stmt() const { return c_store_stmt_; }
b_reduced_store_stmt() const5307     const stmt_t &b_reduced_store_stmt() const { return b_reduced_store_stmt_; }
5308 
inject_compute_alloc_stmts(const stmt_t & stmt) const5309     stmt_t inject_compute_alloc_stmts(const stmt_t &stmt) const {
5310         return jit::inject_alloc_stmts(stmt, compute_allocs_);
5311     }
5312 
inject_out_alloc_stmts(const stmt_t & stmt) const5313     stmt_t inject_out_alloc_stmts(const stmt_t &stmt) const {
5314         return jit::inject_alloc_stmts(stmt, out_allocs_);
5315     }
5316 
inject_let_stmts(const stmt_t & stmt) const5317     stmt_t inject_let_stmts(const stmt_t &stmt) const {
5318         return jit::inject_let_stmts(stmt, g2s_ctx_.grid_idx_lets);
5319     }
5320 
set_gemm_schedule(const gemm_schedule_t & gemm_schedule)5321     void set_gemm_schedule(const gemm_schedule_t &gemm_schedule) {
5322         gemm_schedule_ = gemm_schedule;
5323     }
5324 
5325     // Setters for original AP/BP/CP buffers (P - problem notation).
set_ap_buf(const expr_t & buf)5326     void set_ap_buf(const expr_t &buf) { ap_buf_ = buf; }
set_bp_buf(const expr_t & buf)5327     void set_bp_buf(const expr_t &buf) { bp_buf_ = buf; }
set_cp_buf(const expr_t & buf)5328     void set_cp_buf(const expr_t &buf) { cp_buf_ = buf; }
set_b_reduced_mem_buf(const expr_t & buf)5329     void set_b_reduced_mem_buf(const expr_t &buf) {
5330         b_reduce_ctx_.set_b_reduced_mem_buf(buf);
5331     }
5332 
set_b_reduced_view(const view_t & v)5333     void set_b_reduced_view(const view_t &v) {
5334         b_reduce_ctx_.set_b_reduced_view(v);
5335     }
5336 
set_post_op_context(const post_op_context_t & post_op_ctx)5337     void set_post_op_context(const post_op_context_t &post_op_ctx) {
5338         post_op_ctx_ = post_op_ctx;
5339     }
5340 
set_reduce_condition(const expr_t & cond)5341     void set_reduce_condition(const expr_t &cond) {
5342         b_reduce_ctx_.set_reduce_condition(cond);
5343     }
5344 
build()5345     void build() {
5346         // Initialize SLM buffers.
5347         expr_t a_slm_buf = make_buffer("a_slm");
5348         expr_t b_slm_buf = make_buffer("b_slm");
5349 
5350         view_t ap_gmem_view = gemm_schedule_.a_tg_view();
5351         view_t bp_gmem_view = gemm_schedule_.b_tg_view();
5352 
5353         // Views to multiply by a thread group (either GMEM or SLM).
5354         view_t ap_x_view;
5355         view_t bp_x_view;
5356         prepare_gmem_to_slm("A", cfg_.use_a_slm, gemm_schedule_.a_tg_tile(),
5357                 ap_gmem_view, ap_buf_, a_slm_buf, ap_x_view, g2s_ctx_);
5358         prepare_gmem_to_slm("B", cfg_.use_b_slm, gemm_schedule_.b_tg_tile(),
5359                 bp_gmem_view, bp_buf_, b_slm_buf, bp_x_view, g2s_ctx_);
5360         prepare_prefetch("A", cfg_.use_prefetch, ap_gmem_view, ap_buf_);
5361         prepare_prefetch("B", cfg_.use_prefetch, bp_gmem_view, bp_buf_);
5362 
5363         if (ap_x_view.is_empty()) ap_x_view = ap_gmem_view;
5364         if (bp_x_view.is_empty()) bp_x_view = bp_gmem_view;
5365 
5366         for (auto &bi : g2s_ctx_.bufs) {
5367             register_compute_buffer(bi.buf, bi.size, alloc_kind_t::grf);
5368         }
5369 
5370         load_multiply_builder_t load_mul_builder(cfg_, ir_ctx_, cset_,
5371                 gemm_schedule_, b_reduce_ctx_, ap_buf_, a_slm_buf, bp_buf_,
5372                 b_slm_buf, ap_x_view, bp_x_view);
5373 
5374         load_mul_stmt_ = load_mul_builder.load_mul_stmt();
5375         compute_allocs_.insert(compute_allocs_.end(),
5376                 load_mul_builder.allocs().begin(),
5377                 load_mul_builder.allocs().end());
5378 
5379         auto c_buf = load_mul_builder.c_buf();
5380         auto c_attr = load_mul_builder.c_attr();
5381         int c_size = load_mul_builder.c_reg_layout().size();
5382         register_out_buffer(c_buf, c_size, alloc_kind_t::grf, c_attr);
5383 
5384         auto c_thr_reg_layout = load_mul_builder.c_reg_layout();
5385         auto thr_tile = gemm_schedule_.c_thr_tile(/*is_relative=*/false);
5386 
5387         if (gemm_schedule_.with_thread_group_k_slicing()) {
5388             slm_reduce_builder_t slm_reduce_builder(cfg_, ir_ctx_, cset_,
5389                     gemm_schedule_.tg_grid(), c_buf, c_thr_reg_layout,
5390                     thr_tile);
5391             c_store_stmt_ = c_store_stmt_.append(slm_reduce_builder.stmt());
5392             c_thr_reg_layout = slm_reduce_builder.reg_layout();
5393             thr_tile = slm_reduce_builder.thr_tile();
5394         }
5395 
5396         auto c_thr_mem_view = gemm_schedule_.c_view().create_sub_view(thr_tile);
5397         epilogue_builder_t c_m2g(cfg_, ir_ctx_, cset_, post_op_ctx_, thr_tile,
5398                 c_thr_mem_view, c_thr_reg_layout, cp_buf_, c_buf);
5399         ir_trace() << "C GRF to GMEM store:\n" << c_m2g.stmt() << std::endl;
5400 
5401         c_zero_out_stmt_ = stmt_group_t::make(stmt_label_t::c_zero_out(),
5402                 create_zero_out_stmt(cfg_.hw, c_buf, c_size));
5403         c_store_stmt_ = c_store_stmt_.append(c_m2g.stmt());
5404 
5405         if (cfg_.do_b_reduction) {
5406             auto &ctx = b_reduce_ctx_;
5407             b_reduced_zero_out_stmt_ = create_zero_out_stmt(
5408                     cfg_.hw, ctx.b_reduced_reg_buf(), ctx.b_reduced_size());
5409             b_reduced_store_stmt_ = ctx.create_store_stmt(ir_ctx_, cset_);
5410             register_out_buffer(ctx.b_reduced_reg_buf(), ctx.b_reduced_size(),
5411                     alloc_kind_t::grf);
5412         }
5413 
5414         // Replace DPAS by DPASW when applicable.
5415         if (cfg_.fma_kind == fma_kind_t::dpasw) {
5416             alloc_updater_t alloc_updater;
5417             inject_dpasw(cfg_.hw, load_mul_stmt_, c_buf, c_store_stmt_,
5418                     alloc_updater, gemm_schedule_.tg_grid().idx(0));
5419             for (auto &a : compute_allocs_) {
5420                 a = alloc_updater.update(a);
5421             }
5422             for (auto &a : out_allocs_) {
5423                 a = alloc_updater.update(a);
5424             }
5425         }
5426 
5427         // Assign {Atomic} for DPAS(W) when applicable.
5428         load_mul_stmt_ = inject_atomic(load_mul_stmt_);
5429     }
5430 
5431 private:
5432     struct buf_info_t {
buf_info_tdnnl::impl::gpu::jit::compute_builder_t::buf_info_t5433         buf_info_t(const std::string &tag, const expr_t &buf)
5434             : tag(tag), buf(buf) {}
5435 
5436         std::string tag;
5437         expr_t buf;
5438         int size = 0;
5439     };
5440 
5441     struct g2s_context_t {
g2s_context_tdnnl::impl::gpu::jit::compute_builder_t::g2s_context_t5442         g2s_context_t(ir_context_t &ir_ctx) : ir_ctx(ir_ctx) {}
5443 
create_bufdnnl::impl::gpu::jit::compute_builder_t::g2s_context_t5444         expr_t create_buf(const char *tag, bool force_reuse = false) {
5445             if (reuse_buffers || force_reuse) {
5446                 for (auto &bi : bufs) {
5447                     if (bi.tag == tag) return bi.buf;
5448                 }
5449             }
5450             auto buf = ir_ctx.create_tmp_var(type_t::byte_ptr(), tag);
5451             bufs.emplace_back(tag, buf);
5452             return buf;
5453         }
5454 
set_buf_sizednnl::impl::gpu::jit::compute_builder_t::g2s_context_t5455         void set_buf_size(const expr_t &buf, int size) {
5456             for (auto &bi : bufs) {
5457                 if (bi.buf.is_same(buf)) bi.size = std::max(bi.size, size);
5458             }
5459         }
5460 
create_tmp_grid_idxdnnl::impl::gpu::jit::compute_builder_t::g2s_context_t5461         expr_t create_tmp_grid_idx() {
5462             auto var = ir_ctx.create_tmp_var(type_t::s32(), "idx");
5463             tmp_grid_idxs.insert({var, expr_t()});
5464             return var;
5465         }
5466 
set_grid_idx_valuednnl::impl::gpu::jit::compute_builder_t::g2s_context_t5467         void set_grid_idx_value(const expr_t &idx, const expr_t &value) {
5468             auto &old = tmp_grid_idxs[idx];
5469             ir_assert(old.is_empty());
5470             old = substitute_grid_idx_value(value);
5471         }
5472 
substitute_grid_idx_valuednnl::impl::gpu::jit::compute_builder_t::g2s_context_t5473         expr_t substitute_grid_idx_value(const expr_t &_e) {
5474             auto e = _e;
5475             auto vars = find_unique_objects<var_t>(e);
5476             for (auto &v : vars) {
5477                 auto it = tmp_grid_idxs.find(v);
5478                 if (it == tmp_grid_idxs.end()) continue;
5479                 e = substitute(e, v, it->second);
5480             }
5481             return e;
5482         }
5483 
register_griddnnl::impl::gpu::jit::compute_builder_t::g2s_context_t5484         void register_grid(const grid_info_t &grid) {
5485             for (int i = 0; i < grid.ndims(); i++) {
5486                 auto &idx = grid.idx(i);
5487                 auto it = tmp_grid_idxs.find(idx);
5488                 if (it == tmp_grid_idxs.end()) continue;
5489                 grid_idx_lets.emplace_back(let_t::make(idx, it->second));
5490             }
5491         }
5492 
5493         ir_context_t &ir_ctx;
5494         grid_info_t prev_load_grid;
5495         bool reuse_buffers = false;
5496         std::vector<buf_info_t> bufs;
5497 
5498         object_map_t<expr_t, expr_t> tmp_grid_idxs;
5499         std::vector<stmt_t> grid_idx_lets;
5500     };
5501 
register_compute_buffer(const expr_t & buf,int size,alloc_kind_t kind,const alloc_attr_t & attr={})5502     void register_compute_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5503             const alloc_attr_t &attr = {}) {
5504         compute_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
5505     }
5506 
register_out_buffer(const expr_t & buf,int size,alloc_kind_t kind,const alloc_attr_t & attr={})5507     void register_out_buffer(const expr_t &buf, int size, alloc_kind_t kind,
5508             const alloc_attr_t &attr = {}) {
5509         out_allocs_.push_back(alloc_t::make(buf, size, kind, attr));
5510     }
5511 
5512     // Handles GMEM to SLM load for A and B. Done in two steps:
5513     // 1. Load: GMEM -> GRF (temporary)
5514     // 2. Store: GRF (temporary) -> SLM
prepare_gmem_to_slm(const char * tag,bool use_x_slm,const tensor_t & tg_tile,const view_t & x_gmem_view,const expr_t & xp_buf,const expr_t & x_slm_buf,view_t & x_slm_view,g2s_context_t & g2s_ctx)5515     void prepare_gmem_to_slm(const char *tag, bool use_x_slm,
5516             const tensor_t &tg_tile, const view_t &x_gmem_view,
5517             const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
5518             g2s_context_t &g2s_ctx) {
5519         if (!use_x_slm) return;
5520 
5521         grid_info_t load_grid = gemm_schedule_.tg_grid();
5522         for (;;) {
5523             bool ok = prepare_gmem_to_slm_impl(tag, use_x_slm, tg_tile,
5524                     x_gmem_view, xp_buf, x_slm_buf, x_slm_view, load_grid,
5525                     g2s_ctx);
5526             if (ok) {
5527                 g2s_ctx.prev_load_grid = load_grid;
5528                 g2s_ctx.register_grid(load_grid);
5529                 return;
5530             }
5531 
5532             // Reduce grid and try again.
5533             auto grid_idx = g2s_ctx.create_tmp_grid_idx();
5534             int dim_idx;
5535             expr_t grid_idx_value;
5536             auto new_load_grid
5537                     = load_grid.halven(grid_idx, dim_idx, grid_idx_value);
5538             if (new_load_grid.is_empty()) break;
5539 
5540             if (new_load_grid == g2s_ctx.prev_load_grid) {
5541                 new_load_grid = load_grid.halven(
5542                         grid_idx, dim_idx, grid_idx_value, /*first=*/false);
5543                 g2s_ctx.reuse_buffers = true;
5544             }
5545             g2s_ctx.set_grid_idx_value(grid_idx, grid_idx_value);
5546 
5547             cset_.add_constraint(grid_idx >= 0);
5548             cset_.add_constraint(grid_idx < new_load_grid.dim(dim_idx));
5549 
5550             load_grid = new_load_grid;
5551         }
5552         ir_error_not_expected() << "Can't create GMEM -> SLM loads/stores.";
5553     }
5554 
prepare_gmem_to_slm_impl(const char * tag,bool use_x_slm,const tensor_t & tg_tile,const view_t & x_gmem_view,const expr_t & xp_buf,const expr_t & x_slm_buf,view_t & x_slm_view,const grid_info_t & load_grid,g2s_context_t & g2s_ctx)5555     bool prepare_gmem_to_slm_impl(const char *tag, bool use_x_slm,
5556             const tensor_t &tg_tile, const view_t &x_gmem_view,
5557             const expr_t &xp_buf, const expr_t &x_slm_buf, view_t &x_slm_view,
5558             const grid_info_t &load_grid, g2s_context_t &g2s_ctx) {
5559         bool is_a = (tag[0] == 'A');
5560 
5561         auto xp_slm_layout = create_slm_layout(
5562                 x_gmem_view, is_a ? abc_kind_t::a : abc_kind_t::b, load_grid);
5563 
5564         auto grid_cond = load_grid.slice_condition();
5565 
5566         tensor_t thr_tile;
5567         // Per-thread view to load from GMEM to SLM.
5568         auto x_g2s_view = x_gmem_view.split(load_grid, thr_tile);
5569         auto slm_thr_layout = xp_slm_layout.map(thr_tile);
5570 
5571         // Ensure that each thread writes a dense region to SLM. If the layout
5572         // is not dense, return and try with smaller grid.
5573         if (!slm_thr_layout.is_dense()) return false;
5574 
5575         register_compute_buffer(
5576                 x_slm_buf, xp_slm_layout.size(), alloc_kind_t::slm);
5577         ab_slm_size_ += xp_slm_layout.size();
5578 
5579         // Temporary GRF buffer.
5580         expr_t x_g2s_reg_buf = g2s_ctx.create_buf("g2s");
5581 
5582         // GMEM -> GRF load.
5583         read_builder_t x_read(cfg_.hw, ir_ctx_, cset_, x_g2s_view, xp_buf,
5584                 x_g2s_reg_buf, /*is_slm=*/false);
5585         ir_trace() << tag << " GMEM to GRF load:\n"
5586                    << x_read.str() << std::endl;
5587 
5588         g2s_ctx.set_buf_size(x_g2s_reg_buf, x_read.reg_buf_size());
5589 
5590         auto load_stmt = x_read.stmt();
5591         if (!grid_cond.is_empty()) load_stmt = if_t::make(grid_cond, load_stmt);
5592         g2s_load_stmt_ = g2s_load_stmt_.append(load_stmt);
5593 
5594         // GRF -> SLM store.
5595         write_builder_t x_write(cfg_.hw, ir_ctx_, cset_, view_t(slm_thr_layout),
5596                 x_slm_buf, x_g2s_reg_buf, /*is_slm=*/true);
5597         ir_trace() << tag << " GRF to SLM store:\n"
5598                    << x_write.str() << std::endl;
5599         auto store_stmt = x_write.stmt();
5600 
5601         auto &read_layout = x_read.reg_layout();
5602         auto &write_layout = x_write.reg_layout();
5603         if (read_layout != write_layout) {
5604             if (cfg_.allow_grf_reorder) {
5605                 // Temporary GRF buffer.
5606                 expr_t tmp_buf
5607                         = g2s_ctx.create_buf("g2s_tmp", /*force_reuse=*/true);
5608                 auto reorder_stmt = create_reorder_stmt(
5609                         read_layout, write_layout, x_g2s_reg_buf, tmp_buf);
5610                 g2s_ctx.set_buf_size(tmp_buf, x_write.reg_buf_size());
5611                 store_stmt = substitute(store_stmt, x_g2s_reg_buf, tmp_buf);
5612                 store_stmt = reorder_stmt.append(store_stmt);
5613             } else {
5614                 ir_error_not_expected() << "Requested register layouts for "
5615                                         << tag << " do not match: "
5616                                         << "read: " << read_layout
5617                                         << ", write: " << write_layout;
5618             }
5619         }
5620         // Generate reduction statement for B.
5621         if (!is_a && cfg_.do_b_reduction) {
5622             auto absolute_thr_tile = tg_tile.create_sub_tensor(thr_tile);
5623             b_reduce_ctx_.init_reduced_thr_view(absolute_thr_tile, grid_cond);
5624             auto reduce_stmt = b_reduce_ctx_.create_reduce_stmt(
5625                     read_layout, x_g2s_reg_buf);
5626             store_stmt = reduce_stmt.append(store_stmt);
5627         }
5628         if (!grid_cond.is_empty())
5629             store_stmt = if_t::make(grid_cond, store_stmt);
5630         g2s_store_stmt_ = g2s_store_stmt_.append(store_stmt);
5631 
5632         x_slm_view = view_t(xp_slm_layout);
5633 
5634         return true;
5635     }
5636 
prepare_prefetch(const char * tag,bool use_prefetch,const view_t & x_gmem_view,const expr_t & xp_buf)5637     void prepare_prefetch(const char *tag, bool use_prefetch,
5638             const view_t &x_gmem_view, const expr_t &xp_buf) {
5639         if (!use_prefetch) return;
5640 
5641         // Per-thread view to prefetch from GMEM.
5642         auto thr_view = x_gmem_view.split(gemm_schedule_.tg_grid());
5643 
5644         // GMEM prefetch.
5645         read_builder_t x_prefetch(cfg_.hw, ir_ctx_, cset_, thr_view, xp_buf,
5646                 expr_t(), /*is_slm=*/false, /*is_prefetch=*/true);
5647         ir_trace() << tag << " GMEM prefetch:\n"
5648                    << x_prefetch.str() << std::endl;
5649 
5650         prefetch_stmt_ = prefetch_stmt_.append(x_prefetch.stmt());
5651     }
5652 
create_slm_layout(const view_t & tg_view,abc_kind_t abc_kind,const grid_info_t & load_grid) const5653     layout_t create_slm_layout(const view_t &tg_view, abc_kind_t abc_kind,
5654             const grid_info_t &load_grid) const {
5655         auto layout = tg_view.create_dense_vlayout();
5656         auto &a_type = gemm_schedule_.a_view().type();
5657         auto &b_type = gemm_schedule_.b_view().type();
5658         auto ret = convert_to_fma_friendly_layout(cfg_, abc_kind,
5659                 gemm_schedule_.bmnk_mapper(), layout, a_type, b_type);
5660         if (cfg_.pad_slm) ret = pad_slm_layout(ret, load_grid);
5661         return ret;
5662     }
5663 
5664     // SLM has 65 dword-granularity banks (Xe_HP):
5665     //      banks:   [bank 0] [bank 1] [bank 2] ... [bank 0]
5666     // byte offsets: | 0      | 4      | 8      ... | 4 * 65
5667     // SLM reads don't have conflicts. During SLM writes each fused EU writes
5668     // 64 bytes (in total 128 bytes per clock). If there are repeating banks
5669     // between 128 bytes the write takes 2 clocks to complete.
5670     // Assume that every X-axis thread (across tg_dim[0]) writes the
5671     // corresponding outer block of the layout. The goal is to ensure that the
5672     // stride between outer blocks allows to avoid duplicated banks.
pad_slm_layout(const layout_t & layout,const grid_info_t & load_grid) const5673     layout_t pad_slm_layout(
5674             const layout_t &layout, const grid_info_t &load_grid) const {
5675         auto tg_dim0 = load_grid.dim(0);
5676         auto tg_dim1 = load_grid.dim(1);
5677         int type_size = layout.type().size();
5678 
5679         ir_assert(layout.elems() % tg_dim0 == 0) << layout;
5680         dim_t inner_block = layout.elems() / tg_dim0;
5681 
5682         ir_assert((inner_block * type_size) % tg_dim1 == 0) << layout;
5683         dim_t per_thr_bytes = (inner_block * type_size) / tg_dim1;
5684 
5685         std::vector<dim_t> multi_blocks = {inner_block, tg_dim0};
5686         auto l = layout.split_into_multi_blocks(multi_blocks);
5687 
5688         auto padded_blocks = l.blocks();
5689         dim_t stride = -1;
5690         dim_t remaining_elems = inner_block;
5691         bool past_inner_block = false;
5692         for (auto &b : padded_blocks) {
5693             if (past_inner_block) {
5694                 if (stride == -1) {
5695                     dim_t stride_bytes = find_min_stride_without_conflicts(
5696                             per_thr_bytes, dim_t(b.stride) * type_size);
5697                     ir_assert(stride_bytes % type_size == 0);
5698                     stride = stride_bytes / type_size;
5699                 }
5700                 b.stride = stride;
5701                 stride = b.stride * b.block;
5702                 continue;
5703             }
5704             ir_assert(remaining_elems % b.block == 0);
5705             remaining_elems /= b.block;
5706             if (remaining_elems == 1) past_inner_block = true;
5707         }
5708         return layout_t(
5709                 layout.type(), layout.ndims(), layout.offset(), padded_blocks);
5710     }
5711 
find_min_stride_without_conflicts(dim_t inner_bytes,dim_t dense_stride_bytes) const5712     dim_t find_min_stride_without_conflicts(
5713             dim_t inner_bytes, dim_t dense_stride_bytes) const {
5714         int write_step = 64;
5715         int stride_step = 16;
5716         dim_t stride_beg = dense_stride_bytes;
5717         dim_t stride_end = 2 * dense_stride_bytes;
5718         const int slm_banks = 65;
5719         for (dim_t s = stride_beg; s < stride_end; s += stride_step) {
5720             bool ok = true;
5721             for (dim_t off0 = 0; off0 < inner_bytes; off0 += write_step) {
5722                 // Check banks for a single SLM write.
5723                 bool found[slm_banks] = {false};
5724                 for (dim_t off = off0; off < off0 + write_step;
5725                         off += sizeof(uint32_t)) {
5726                     int bank0 = (off / sizeof(uint32_t)) % slm_banks;
5727                     int bank1 = ((off + s) / sizeof(uint32_t)) % slm_banks;
5728                     if (found[bank0]) {
5729                         ok = false;
5730                         break;
5731                     }
5732                     found[bank0] = true;
5733                     if (found[bank1]) {
5734                         ok = false;
5735                         break;
5736                     }
5737                     found[bank1] = true;
5738                 }
5739                 if (ok) return s;
5740             }
5741         }
5742 
5743         ir_warning()
5744                 << "Couldn't find stride without conflicts for SLM padding."
5745                 << std::endl;
5746 
5747         return dense_stride_bytes;
5748     }
5749 
5750     const conv_config_t &cfg_;
5751     ir_context_t &ir_ctx_;
5752     constraint_set_t &cset_;
5753     post_op_context_t post_op_ctx_;
5754     b_reduce_context_t b_reduce_ctx_;
5755 
5756     g2s_context_t g2s_ctx_;
5757 
5758     gemm_schedule_t gemm_schedule_;
5759 
5760     expr_t ap_buf_;
5761     expr_t bp_buf_;
5762     expr_t cp_buf_;
5763 
5764     std::vector<stmt_t> compute_allocs_;
5765     std::vector<stmt_t> out_allocs_;
5766     int ab_slm_size_ = 0;
5767 
5768     stmt_t g2s_load_stmt_;
5769     stmt_t g2s_store_stmt_;
5770     stmt_t prefetch_stmt_;
5771     stmt_t load_mul_stmt_;
5772 
5773     stmt_t c_zero_out_stmt_;
5774     stmt_t c_store_stmt_;
5775 
5776     stmt_t b_reduced_zero_out_stmt_;
5777     stmt_t b_reduced_store_stmt_;
5778 };
5779 
build()5780 void kernel_builder_t::build() {
5781     ir_context_t ir_ctx;
5782     constraint_set_t init_cset;
5783 
5784     int grid_ndims = 3;
5785     kernel_grid_ = grid_info_t(grid_ndims);
5786     tg_grid_ = grid_info_t(grid_ndims);
5787     for (int i = 0; i < grid_ndims; i++) {
5788         local_id_[i]
5789                 = var_t::make(type_t::u16(), "local_id" + std::to_string(i));
5790         kernel_grid_.dim(i) = cfg_.kernel_grid_dim[i];
5791         kernel_grid_.idx(i)
5792                 = var_t::make(type_t::s32(), "grid_idx" + std::to_string(i));
5793         tg_grid_.dim(i) = cfg_.tg_grid_dim[i];
5794         tg_grid_.idx(i)
5795                 = var_t::make(type_t::s32(), "tg_idx" + std::to_string(i));
5796 
5797         init_cset.add_constraint(kernel_grid_.idx(i) >= 0);
5798         init_cset.add_constraint(kernel_grid_.idx(i) < cfg_.kernel_grid_dim[i]);
5799         init_cset.add_constraint(tg_grid_.idx(i) >= 0);
5800         init_cset.add_constraint(tg_grid_.idx(i) < cfg_.tg_grid_dim[i]);
5801     }
5802 
5803     gemm_schedule_t gemm_schedule(init_cset, kernel_grid_, tg_grid_);
5804 
5805     std::vector<stmt_t> init_stmts;
5806     for (int i = 0; i < grid_ndims; i++) {
5807         auto value = local_id_[i];
5808         if (i == 0) value /= cfg_.simd_size;
5809         init_stmts.push_back(let_t::make(tg_grid_.idx(i), value));
5810     }
5811 
5812     // Initialize memory buffers.
5813     std::vector<stmt_t> inner_lets;
5814 
5815     view_t a_view;
5816     view_t b_view;
5817     view_t c_view;
5818     view_t bp_reduced_view;
5819 
5820     expr_t ap_buf;
5821     expr_t bp_buf;
5822     expr_t cp_buf;
5823     expr_t b_reduced_mem_buf;
5824     expr_t b_reduction_condition;
5825 
5826     if (cfg_.is_fwd) {
5827         init_fwd(gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
5828     } else if (cfg_.is_bwd_d) {
5829         init_bwd_d(
5830                 gemm_schedule, a_view, b_view, c_view, ap_buf, bp_buf, cp_buf);
5831     } else if (cfg_.is_bwd_w) {
5832         init_bwd_w(gemm_schedule, a_view, b_view, c_view, bp_reduced_view,
5833                 ap_buf, bp_buf, cp_buf, b_reduced_mem_buf,
5834                 b_reduction_condition);
5835     } else {
5836         ir_error_not_expected();
5837     }
5838 
5839     gemm_schedule.finalize();
5840 
5841     post_op_context_t post_op_ctx(
5842             pd_, cfg_, gemm_schedule.c_view(), kernel_arg_info_);
5843     compute_builder_t cb(cfg_, ir_ctx, init_cset);
5844 
5845     cb.set_gemm_schedule(gemm_schedule);
5846     cb.set_ap_buf(ap_buf);
5847     cb.set_bp_buf(bp_buf);
5848     cb.set_cp_buf(cp_buf);
5849     cb.set_b_reduced_mem_buf(b_reduced_mem_buf);
5850     cb.set_b_reduced_view(bp_reduced_view);
5851     cb.set_post_op_context(post_op_ctx);
5852     cb.set_reduce_condition(b_reduction_condition);
5853 
5854     cb.build();
5855 
5856     std::vector<stmt_t> allocs;
5857     for (int i = 0; i < kernel_arg_info_.nargs(); i++) {
5858         auto &var = kernel_arg_info_.arg_var(i);
5859         if (!var.type().is_ptr()) continue;
5860         allocs.push_back(alloc_t::make(var, 0, alloc_kind_t::global));
5861     }
5862 
5863     // Create IR statements.
5864     stmt_t loop_stmt = cb.iter_stmt();
5865     loop_stmt = gemm_schedule.create_loop_nest(loop_stmt);
5866     loop_stmt = stmt_group_t::make(stmt_label_t::compute_loop(), loop_stmt);
5867     loop_stmt = cb.inject_compute_alloc_stmts(loop_stmt);
5868 
5869     auto c_store_stmt
5870             = stmt_group_t::make(stmt_label_t::c_store(), cb.c_store_stmt());
5871     stmt_ = loop_stmt;
5872     stmt_ = stmt_seq_t::make(cb.zero_out_stmt(), stmt_);
5873     stmt_ = stmt_seq_t::make(stmt_, cb.b_reduced_store_stmt());
5874     stmt_ = stmt_seq_t::make(stmt_, c_store_stmt);
5875 
5876     stmt_ = cb.inject_out_alloc_stmts(stmt_);
5877     stmt_ = cb.inject_let_stmts(stmt_);
5878 
5879     stmt_ = gemm_schedule.create_bind_stmt(stmt_);
5880     stmt_ = inject_let_stmts(stmt_, init_stmts);
5881     stmt_ = inject_alloc_stmts(stmt_, allocs);
5882 
5883     stmt_ = inject_external_var_let(stmt_);
5884     stmt_ = merge_slm_buffers(stmt_);
5885     if (!cfg_.do_loop_unroll && (cfg_.use_a_slm || cfg_.use_b_slm)) {
5886         stmt_ = inject_simple_slm_buffering(
5887                 cfg_.hw, stmt_, cfg_, ir_ctx, cb.ab_slm_size());
5888     }
5889     stmt_ = lift_buffer_offsets_in_send(stmt_);
5890     stmt_ = simplify_pass(stmt_, init_cset);
5891     stmt_ = inject_send(stmt_, ir_ctx, init_cset);
5892     stmt_ = split_wide_stores(cfg_.hw, stmt_);
5893     stmt_ = lift_alloc(stmt_, cfg_);
5894     stmt_ = eliminate_common_subexprs(stmt_, ir_ctx);
5895     stmt_ = hoist_exprs(stmt_, ir_ctx);
5896     if (cfg_.do_loop_unroll) stmt_ = loop_strength_reduce(stmt_);
5897     stmt_ = optimize_alloc_let(stmt_);
5898     if (cfg_.do_loop_unroll) {
5899         stmt_ = update_loops_for_unrolling(stmt_, cfg_);
5900         stmt_ = inject_unrolling(stmt_, cfg_, ir_ctx, cb.ab_slm_size());
5901     }
5902     stmt_ = fixup_if_conditions(stmt_, cfg_);
5903     stmt_ = unroll_loops(stmt_, ir_ctx);
5904     stmt_ = simplify_pass(stmt_, init_cset);
5905     stmt_ = optimize_alloc_let(stmt_);
5906     stmt_ = optimize_peephole(stmt_);
5907     stmt_ = stmt_group_t::make(stmt_label_t::kernel(), stmt_);
5908 
5909     ir_trace() << "Kernel body:\n" << stmt_ << std::endl;
5910 }
5911 
5912 namespace {
need_src_or_dst_check(bool is_fwd,int o,int i,int k,int p,int s,int d)5913 bool need_src_or_dst_check(
5914         bool is_fwd, int o, int i, int k, int p, int s, int d) {
5915     if (is_fwd) {
5916         int i_min = -p;
5917         int i_max = (o - 1) * s - p + (k - 1) * (1 + d);
5918         return (i_min < 0) || (i_max >= i);
5919     }
5920     // Backward.
5921     int os_min = p - (k - 1) * (1 + d);
5922     int os_max = (o - 1) + p;
5923     return (os_min < 0) || (os_max >= i * s);
5924 }
5925 
5926 } // namespace
5927 
init_fwd(gemm_schedule_t & gemm_schedule,view_t & src_view,view_t & wei_view,view_t & dst_view,expr_t & src_buf,expr_t & wei_buf,expr_t & dst_buf)5928 void kernel_builder_t::init_fwd(gemm_schedule_t &gemm_schedule,
5929         view_t &src_view, view_t &wei_view, view_t &dst_view, expr_t &src_buf,
5930         expr_t &wei_buf, expr_t &dst_buf) {
5931     // Unify layouts.
5932     auto src_layout = cfg_.src_layout;
5933     auto wei_layout = cfg_.wei_layout;
5934     auto dst_layout = cfg_.dst_layout;
5935     normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
5936             cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/true);
5937 
5938     // Initialize views.
5939     auto mb = var_t::make(type_t::s32(), "mb");
5940     auto ic = var_t::make(type_t::s32(), "ic");
5941     auto oc = var_t::make(type_t::s32(), "oc");
5942     auto od = var_t::make(type_t::s32(), "od");
5943     auto oh = var_t::make(type_t::s32(), "oh");
5944     auto ow = var_t::make(type_t::s32(), "ow");
5945     auto kd = var_t::make(type_t::s32(), "kd");
5946     auto kh = var_t::make(type_t::s32(), "kh");
5947     auto kw = var_t::make(type_t::s32(), "kw");
5948     auto g = var_t::make(type_t::s32(), "g");
5949 
5950     // Initialize masks.
5951     expr_t id_mask, ih_mask, iw_mask;
5952     expr_t od_mask, oh_mask, ow_mask;
5953     expr_t src_mb_mask, dst_mb_mask;
5954     expr_t wei_oc_mask, dst_oc_mask;
5955     expr_t src_g_mask, wei_g_mask, dst_g_mask;
5956     expr_t kw_mask;
5957 
5958     bool check_ow = (cfg_.ow % cfg_.ow_tg_blk != 0);
5959     bool check_iw = check_ow
5960             || need_src_or_dst_check(cfg_.is_fwd, cfg_.ow, cfg_.iw, cfg_.kw,
5961                     cfg_.pw, cfg_.sw, cfg_.dw);
5962     bool check_ih = need_src_or_dst_check(
5963             cfg_.is_fwd, cfg_.oh, cfg_.ih, cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
5964     bool check_id = need_src_or_dst_check(
5965             cfg_.is_fwd, cfg_.od, cfg_.id, cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
5966     bool check_kw = (cfg_.kw % cfg_.kw_blk != 0);
5967 
5968     int src_g = int(src_layout.dim(1));
5969     int src_g_inner_blk = ir_utils::max_pow2_divisor(src_g);
5970     src_g_inner_blk = std::min(src_g_inner_blk, cfg_.g_thr_blk);
5971 
5972     int wei_g = int(wei_layout.dim(0));
5973     int wei_g_inner_blk = ir_utils::max_pow2_divisor(wei_g);
5974     wei_g_inner_blk = std::min(wei_g_inner_blk, cfg_.g_thr_blk);
5975 
5976     int wei_oc = int(wei_layout.dim(1));
5977     int wei_oc_inner_blk = ir_utils::max_pow2_divisor(wei_oc);
5978     wei_oc_inner_blk = std::min(wei_oc_inner_blk, cfg_.oc_thr_blk);
5979 
5980     int dst_g = int(dst_layout.dim(1));
5981     int dst_g_inner_blk = ir_utils::max_pow2_divisor(dst_g);
5982     dst_g_inner_blk = std::min(dst_g_inner_blk, cfg_.g_thr_blk);
5983 
5984     int dst_oc = int(dst_layout.dim(2));
5985     int dst_oc_inner_blk = ir_utils::max_pow2_divisor(dst_oc);
5986     dst_oc_inner_blk = std::min(dst_oc_inner_blk, cfg_.oc_thr_blk);
5987 
5988     bool check_src_g = (src_g % cfg_.g_tg_blk != 0);
5989     bool check_wei_g = (wei_g % cfg_.g_tg_blk != 0);
5990     bool check_wei_oc = (wei_oc % cfg_.oc_tg_blk != 0);
5991     bool check_dst_g = (dst_g % cfg_.g_tg_blk != 0);
5992     bool check_dst_oc = (dst_oc % cfg_.oc_tg_blk != 0);
5993 
5994     int src_mb = int(src_layout.dim(0));
5995     int dst_mb = int(dst_layout.dim(0));
5996 
5997     bool check_src_mb = (src_mb % cfg_.mb_tg_blk != 0);
5998     bool check_dst_mb = (dst_mb % cfg_.mb_tg_blk != 0);
5999 
6000     auto &x = view_t::placeholder_var();
6001     if (check_id) id_mask = (x >= 0) & (x < cfg_.id);
6002     if (check_ih) ih_mask = (x >= 0) & (x < cfg_.ih);
6003     if (check_iw) iw_mask = (x >= 0) & (x < cfg_.iw);
6004     if (check_ow) ow_mask = (x >= 0) & (x < cfg_.ow);
6005     if (check_src_g)
6006         src_g_mask = (x / src_g_inner_blk < src_g / src_g_inner_blk);
6007     if (check_wei_g)
6008         wei_g_mask = (x / wei_g_inner_blk < wei_g / wei_g_inner_blk);
6009     if (check_wei_oc)
6010         wei_oc_mask = (x / wei_oc_inner_blk < wei_oc / wei_oc_inner_blk);
6011     if (check_dst_g)
6012         dst_g_mask = (x / dst_g_inner_blk < dst_g / dst_g_inner_blk);
6013     if (check_dst_oc)
6014         dst_oc_mask = (x / dst_oc_inner_blk < dst_oc / dst_oc_inner_blk);
6015     if (check_kw) kw_mask = (x < cfg_.kw);
6016     if (check_src_mb) src_mb_mask = (x < src_mb);
6017     if (check_dst_mb) dst_mb_mask = (x < dst_mb);
6018 
6019     // Source.
6020     src_view = view_t({mb, g, ic, od, oh, ow, kd, kh, kw}, 6);
6021     src_view.set_vdim(mb, cfg_.mb);
6022     src_view.set_vdim(g, cfg_.g);
6023     src_view.set_vdim(ic, cfg_.ic);
6024     src_view.set_vdim(od, cfg_.od);
6025     src_view.set_vdim(oh, cfg_.oh);
6026     src_view.set_vdim(ow, cfg_.ow);
6027     src_view.set_vdim(kd, cfg_.kd);
6028     src_view.set_vdim(kh, cfg_.kh);
6029     src_view.set_vdim(kw, cfg_.kw);
6030     src_view.set_tdim(0, mb, src_mb_mask);
6031     src_view.set_tdim(1, g, src_g_mask);
6032     src_view.set_tdim(2, ic);
6033     src_view.set_tdim(3, od * cfg_.sd - cfg_.pd + kd * (1 + cfg_.dd), id_mask);
6034     src_view.set_tdim(4, oh * cfg_.sh - cfg_.ph + kh * (1 + cfg_.dh), ih_mask);
6035     src_view.set_tdim(5, ow * cfg_.sw - cfg_.pw + kw * (1 + cfg_.dw), iw_mask);
6036     src_view.set_tlayout(src_layout);
6037 
6038     // Weights.
6039     wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
6040     wei_view.set_vdim(g, cfg_.g);
6041     wei_view.set_vdim(oc, cfg_.oc);
6042     wei_view.set_vdim(ic, cfg_.ic);
6043     wei_view.set_vdim(kd, cfg_.kd);
6044     wei_view.set_vdim(kh, cfg_.kh);
6045     wei_view.set_vdim(kw, cfg_.kw);
6046     wei_view.set_tdim(0, g, wei_g_mask);
6047     wei_view.set_tdim(1, oc, wei_oc_mask);
6048     wei_view.set_tdim(2, ic);
6049     wei_view.set_tdim(3, kd);
6050     wei_view.set_tdim(4, kh);
6051     wei_view.set_tdim(5, kw, kw_mask);
6052     wei_view.set_tlayout(wei_layout);
6053 
6054     // Destination.
6055     dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
6056     dst_view.set_vdim(mb, cfg_.mb);
6057     dst_view.set_vdim(g, cfg_.g);
6058     dst_view.set_vdim(oc, cfg_.oc);
6059     dst_view.set_vdim(od, cfg_.od);
6060     dst_view.set_vdim(oh, cfg_.oh);
6061     dst_view.set_vdim(ow, cfg_.ow);
6062     dst_view.set_tdim(0, mb, dst_mb_mask);
6063     dst_view.set_tdim(1, g, dst_g_mask);
6064     dst_view.set_tdim(2, oc, dst_oc_mask);
6065     dst_view.set_tdim(3, od, od_mask);
6066     dst_view.set_tdim(4, oh, oh_mask);
6067     dst_view.set_tdim(5, ow, ow_mask);
6068     dst_view.set_tlayout(dst_layout);
6069 
6070     // Initialize GEMM schedule.
6071     gemm_schedule.set_a_view(src_view);
6072     gemm_schedule.set_b_view(wei_view);
6073     gemm_schedule.set_c_view(dst_view);
6074     gemm_schedule.set_b_vars({g});
6075     gemm_schedule.set_m_vars({mb, od, oh, ow});
6076     gemm_schedule.set_n_vars({oc});
6077     gemm_schedule.set_k_vars({ic, kd, kh, kw});
6078 
6079     expr_t g_tg_blk_idx, g_inner;
6080     expr_t oc_tg_blk_idx, oc_thr_blk_idx, oc_inner;
6081     expr_t mb_tg_blk_idx, mb_thr_blk_idx, mb_inner;
6082     expr_t ow_tg_blk_idx, ow_thr_blk_idx, ow_inner;
6083     expr_t kw_outer, kw_inner;
6084     expr_t ic_thr_blk_idx, ic_outer, ic_inner;
6085 
6086     gemm_schedule.split(g, cfg_.g_tg_blk, g_tg_blk_idx, g_inner);
6087     gemm_schedule.split(oc, cfg_.oc_tg_blk, cfg_.oc_thr_blk, oc_tg_blk_idx,
6088             oc_thr_blk_idx, oc_inner);
6089     gemm_schedule.split(mb, cfg_.mb_tg_blk, cfg_.mb_thr_blk, mb_tg_blk_idx,
6090             mb_thr_blk_idx, mb_inner);
6091     gemm_schedule.split(ow, cfg_.ow_tg_blk, cfg_.ow_thr_blk, ow_tg_blk_idx,
6092             ow_thr_blk_idx, ow_inner);
6093     gemm_schedule.split(ic, cfg_.ic_blk * cfg_.ic_thr_dim, cfg_.ic_blk,
6094             ic_outer, ic_thr_blk_idx, ic_inner);
6095     gemm_schedule.split(kw, cfg_.kw_blk, kw_outer, kw_inner);
6096 
6097     auto g_odhw_idx = gemm_schedule.fuse({g_tg_blk_idx, od, oh, ow_tg_blk_idx});
6098     auto mb_ow_thr_blk_idx = gemm_schedule.fuse(mb_thr_blk_idx, ow_thr_blk_idx);
6099 
6100     gemm_schedule.bind(oc_tg_blk_idx, kernel_grid_.idx(0));
6101     gemm_schedule.bind(g_odhw_idx, kernel_grid_.idx(1));
6102     gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));
6103     gemm_schedule.bind(oc_thr_blk_idx, tg_grid_.idx(0));
6104     gemm_schedule.bind(mb_ow_thr_blk_idx, tg_grid_.idx(1));
6105     gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(2));
6106 
6107     gemm_schedule.tensorize(g_inner);
6108     gemm_schedule.tensorize(oc_inner);
6109     gemm_schedule.tensorize(mb_inner);
6110     gemm_schedule.tensorize(ow_inner);
6111     gemm_schedule.tensorize(kw_inner);
6112     gemm_schedule.tensorize(ic_inner);
6113 
6114     gemm_schedule.reorder({ic_outer, kd, kh, kw_outer, oc_thr_blk_idx,
6115             mb_ow_thr_blk_idx, ic_thr_blk_idx});
6116 
6117     src_buf = kernel_arg_info_.find_arg("src");
6118     wei_buf = kernel_arg_info_.find_arg("wei");
6119     dst_buf = kernel_arg_info_.find_arg("dst");
6120 }
6121 
init_bwd_d(gemm_schedule_t & gemm_schedule,view_t & dst_view,view_t & wei_view,view_t & src_view,expr_t & dst_buf,expr_t & wei_buf,expr_t & src_buf)6122 void kernel_builder_t::init_bwd_d(gemm_schedule_t &gemm_schedule,
6123         view_t &dst_view, view_t &wei_view, view_t &src_view, expr_t &dst_buf,
6124         expr_t &wei_buf, expr_t &src_buf) {
6125     // Unify layouts.
6126     auto src_layout = cfg_.src_layout;
6127     auto wei_layout = cfg_.wei_layout;
6128     auto dst_layout = cfg_.dst_layout;
6129     normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
6130             cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/false);
6131 
6132     // Initialize views.
6133     auto mb = var_t::make(type_t::s32(), "mb");
6134     auto ic = var_t::make(type_t::s32(), "ic");
6135     auto oc = var_t::make(type_t::s32(), "oc");
6136     auto id = var_t::make(type_t::s32(), "id");
6137     auto ih = var_t::make(type_t::s32(), "ih");
6138     auto iw = var_t::make(type_t::s32(), "iw");
6139     auto kd = var_t::make(type_t::s32(), "kd");
6140     auto kh = var_t::make(type_t::s32(), "kh");
6141     auto kw = var_t::make(type_t::s32(), "kw");
6142 
6143     // Initialize masks.
6144     expr_t id_mask, ih_mask, iw_mask;
6145     expr_t od_mask(true), oh_mask(true), ow_mask(true);
6146     expr_t src_mb_mask, dst_mb_mask;
6147     expr_t wei_oc_mask, dst_oc_mask;
6148     expr_t wei_ic_mask, src_ic_mask;
6149 
6150     bool check_iw = (cfg_.iw % cfg_.iw_tg_blk != 0);
6151     bool check_ow = check_iw
6152             || need_src_or_dst_check(cfg_.is_fwd, cfg_.ow, cfg_.iw, cfg_.kw,
6153                     cfg_.pw, cfg_.sw, cfg_.dw);
6154     bool check_oh = need_src_or_dst_check(
6155             cfg_.is_fwd, cfg_.oh, cfg_.ih, cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
6156     bool check_od = need_src_or_dst_check(
6157             cfg_.is_fwd, cfg_.od, cfg_.id, cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
6158 
6159     int wei_ic = int(cfg_.wei_layout.dim(cfg_.with_groups ? 2 : 1));
6160     int src_ic = int(cfg_.src_layout.dim(1));
6161 
6162     int wei_ic_inner_blk = ir_utils::max_pow2_divisor(wei_ic);
6163     int src_ic_inner_blk = ir_utils::max_pow2_divisor(src_ic);
6164     wei_ic_inner_blk = std::min(wei_ic_inner_blk, cfg_.ic_thr_blk);
6165     src_ic_inner_blk = std::min(src_ic_inner_blk, cfg_.ic_thr_blk);
6166 
6167     bool check_wei_ic = (wei_ic % cfg_.ic_tg_blk != 0);
6168     bool check_src_ic = (src_ic % cfg_.ic_tg_blk != 0);
6169 
6170     int src_mb = int(cfg_.src_layout.dim(0));
6171     int dst_mb = int(cfg_.src_layout.dim(0));
6172 
6173     bool check_src_mb = (src_mb % cfg_.mb_tg_blk != 0);
6174     bool check_dst_mb = (dst_mb % cfg_.mb_tg_blk != 0);
6175 
6176     auto &x = view_t::placeholder_var();
6177     if (check_od) od_mask = (x >= 0) & (x < cfg_.od);
6178     if (check_oh) oh_mask = (x >= 0) & (x < cfg_.oh);
6179     if (check_ow) ow_mask = (x >= 0) & (x < cfg_.ow);
6180     if (check_iw) iw_mask = (x >= 0) & (x < cfg_.iw);
6181     if (check_wei_ic)
6182         wei_ic_mask = (x / wei_ic_inner_blk < wei_ic / wei_ic_inner_blk);
6183     if (check_src_ic)
6184         src_ic_mask = (x / src_ic_inner_blk < src_ic / src_ic_inner_blk);
6185     if (check_src_mb) src_mb_mask = (x < src_mb);
6186     if (check_dst_mb) dst_mb_mask = (x < dst_mb);
6187 
6188     // Destination.
6189     dst_view = view_t({mb, oc, id, ih, iw, kd, kh, kw}, 5);
6190     dst_view.set_vdim(mb, cfg_.mb);
6191     dst_view.set_vdim(oc, cfg_.oc);
6192     dst_view.set_vdim(id, cfg_.id);
6193     dst_view.set_vdim(ih, cfg_.ih);
6194     dst_view.set_vdim(iw, cfg_.iw);
6195     dst_view.set_vdim(kd, cfg_.kd);
6196     dst_view.set_vdim(kh, cfg_.kh);
6197     dst_view.set_vdim(kw, cfg_.kw);
6198     dst_view.set_tdim(0, mb, src_mb_mask);
6199     dst_view.set_tdim(1, oc);
6200 
6201     auto od = id - kd * (1 + cfg_.dd) + cfg_.pd;
6202     auto oh = ih - kh * (1 + cfg_.dh) + cfg_.ph;
6203     auto ow = iw - kw * (1 + cfg_.dw) + cfg_.pw;
6204     dst_view.set_tdim(2, od / cfg_.sd, od_mask & (od % cfg_.sd == 0));
6205     dst_view.set_tdim(3, oh / cfg_.sh, oh_mask & (oh % cfg_.sh == 0));
6206     dst_view.set_tdim(4, ow / cfg_.sw, ow_mask & (ow % cfg_.sw == 0));
6207 
6208     dst_view.set_tlayout(dst_layout);
6209 
6210     // Weights.
6211     wei_view = view_t({oc, ic, kd, kh, kw}, 5);
6212     wei_view.set_vdim(ic, cfg_.ic);
6213     wei_view.set_vdim(oc, cfg_.oc);
6214     wei_view.set_vdim(kd, cfg_.kd);
6215     wei_view.set_vdim(kh, cfg_.kh);
6216     wei_view.set_vdim(kw, cfg_.kw);
6217     wei_view.set_tdim(0, oc);
6218     wei_view.set_tdim(1, ic, wei_ic_mask);
6219     wei_view.set_tdim(2, kd);
6220     wei_view.set_tdim(3, kh);
6221     wei_view.set_tdim(4, kw);
6222     wei_view.set_tlayout(wei_layout);
6223 
6224     // Source.
6225     src_view = view_t({mb, ic, id, ih, iw}, 5);
6226     src_view.set_vdim(mb, cfg_.mb);
6227     src_view.set_vdim(ic, cfg_.ic);
6228     src_view.set_vdim(id, cfg_.id);
6229     src_view.set_vdim(ih, cfg_.ih);
6230     src_view.set_vdim(iw, cfg_.iw);
6231     src_view.set_tdim(0, mb, dst_mb_mask);
6232     src_view.set_tdim(1, ic, src_ic_mask);
6233     src_view.set_tdim(2, id, id_mask);
6234     src_view.set_tdim(3, ih, ih_mask);
6235     src_view.set_tdim(4, iw, iw_mask);
6236     src_view.set_tlayout(src_layout);
6237 
6238     // Initialize GEMM schedule.
6239     gemm_schedule.set_a_view(dst_view);
6240     gemm_schedule.set_b_view(wei_view);
6241     gemm_schedule.set_c_view(src_view);
6242     gemm_schedule.set_m_vars({mb, id, ih, iw});
6243     gemm_schedule.set_n_vars({ic});
6244     gemm_schedule.set_k_vars({oc, kd, kh, kw});
6245 
6246     expr_t ic_tg_blk_idx, ic_thr_blk_idx, ic_inner;
6247     expr_t mb_tg_blk_idx, mb_inner;
6248     expr_t iw_tg_blk_idx, iw_thr_blk_idx, iw_inner;
6249     expr_t oc_blk_idx, oc_inner;
6250 
6251     gemm_schedule.split(ic, cfg_.ic_tg_blk, cfg_.ic_thr_blk, ic_tg_blk_idx,
6252             ic_thr_blk_idx, ic_inner);
6253     gemm_schedule.split(mb, cfg_.mb_tg_blk, mb_tg_blk_idx, mb_inner);
6254     gemm_schedule.split(iw, cfg_.iw_tg_blk, cfg_.iw_thr_blk, iw_tg_blk_idx,
6255             iw_thr_blk_idx, iw_inner);
6256     gemm_schedule.split(oc, cfg_.oc_blk, oc_blk_idx, oc_inner);
6257 
6258     auto idhw_idx = gemm_schedule.fuse(id, ih, iw_tg_blk_idx);
6259     gemm_schedule.bind(ic_tg_blk_idx, kernel_grid_.idx(0));
6260     gemm_schedule.bind(idhw_idx, kernel_grid_.idx(1));
6261     gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));
6262     gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(0));
6263     gemm_schedule.bind(iw_thr_blk_idx, tg_grid_.idx(1));
6264 
6265     gemm_schedule.tensorize(ic_inner);
6266     gemm_schedule.tensorize(mb_inner);
6267     gemm_schedule.tensorize(iw_inner);
6268     gemm_schedule.tensorize(oc_inner);
6269 
6270     gemm_schedule.reorder({oc_blk_idx, kd, kh, kw});
6271 
6272     src_buf = kernel_arg_info_.find_arg("src");
6273     wei_buf = kernel_arg_info_.find_arg("wei");
6274     dst_buf = kernel_arg_info_.find_arg("dst");
6275 }
6276 
init_bwd_w(gemm_schedule_t & gemm_schedule,view_t & src_view,view_t & dst_view,view_t & wei_view,view_t & bia_view,expr_t & src_buf,expr_t & dst_buf,expr_t & wei_buf,expr_t & bia_buf,expr_t & bia_reduction_condition)6277 void kernel_builder_t::init_bwd_w(gemm_schedule_t &gemm_schedule,
6278         view_t &src_view, view_t &dst_view, view_t &wei_view, view_t &bia_view,
6279         expr_t &src_buf, expr_t &dst_buf, expr_t &wei_buf, expr_t &bia_buf,
6280         expr_t &bia_reduction_condition) {
6281     // Unify layouts.
6282     auto src_layout = cfg_.src_layout;
6283     auto wei_layout = cfg_.wei_layout;
6284     auto dst_layout = cfg_.dst_layout;
6285     normalize_conv_layouts(src_layout, wei_layout, dst_layout, cfg_.with_groups,
6286             cfg_.g, cfg_.is_dw, cfg_.reduced_to_1d, /*add_groups=*/false);
6287 
6288     // Initialize thread group views.
6289     auto mb = var_t::make(type_t::s32(), "mb");
6290     auto ic = var_t::make(type_t::s32(), "ic");
6291     auto oc = var_t::make(type_t::s32(), "oc");
6292     auto od = var_t::make(type_t::s32(), "od");
6293     auto oh = var_t::make(type_t::s32(), "oh");
6294     auto ow = var_t::make(type_t::s32(), "ow");
6295     auto kd = var_t::make(type_t::s32(), "kd");
6296     auto kh = var_t::make(type_t::s32(), "kh");
6297     auto kw = var_t::make(type_t::s32(), "kw");
6298 
6299     // Initialize masks.
6300     expr_t id_mask(true), ih_mask(true), iw_mask(true);
6301     expr_t od_mask, oh_mask, ow_mask;
6302     expr_t src_mb_mask, src_ic_mask;
6303     expr_t dst_mb_mask, dst_oc_mask;
6304     expr_t wei_oc_mask, wei_ic_mask;
6305     expr_t kw_mask;
6306 
6307     bool check_ow = (cfg_.ow % cfg_.ow_tg_blk != 0);
6308     bool check_oh = (cfg_.oh % cfg_.oh_tg_blk != 0);
6309     bool check_od = (cfg_.od % cfg_.od_tg_blk != 0);
6310     bool check_kw = (cfg_.kw % cfg_.kw_blk != 0);
6311     bool check_iw = check_kw
6312             || need_src_or_dst_check(/*is_fwd=*/true, cfg_.ow, cfg_.iw, cfg_.kw,
6313                     cfg_.pw, cfg_.sw, cfg_.dw);
6314     bool check_ih = need_src_or_dst_check(/*is_fwd=*/true, cfg_.oh, cfg_.ih,
6315             cfg_.kh, cfg_.ph, cfg_.sh, cfg_.dh);
6316     bool check_id = need_src_or_dst_check(/*is_fwd=*/true, cfg_.od, cfg_.id,
6317             cfg_.kd, cfg_.pd, cfg_.sd, cfg_.dd);
6318     bool check_iw_min = check_iw;
6319     bool check_ih_min = check_ih;
6320     bool check_id_min = check_id;
6321     bool check_iw_max = (check_iw || check_ow);
6322     bool check_ih_max = (check_ih || check_oh);
6323     bool check_id_max = (check_id || check_od);
6324 
6325     int src_ic = int(cfg_.src_layout.dim(1));
6326     int dst_oc = int(cfg_.dst_layout.dim(1));
6327     int wei_oc = int(cfg_.wei_layout.dim(cfg_.with_groups ? 1 : 0));
6328     int wei_ic = int(cfg_.wei_layout.dim(cfg_.with_groups ? 2 : 1));
6329 
6330     int src_ic_inner_blk = ir_utils::max_pow2_divisor(src_ic);
6331     int dst_oc_inner_blk = ir_utils::max_pow2_divisor(dst_oc);
6332     int wei_oc_inner_blk = ir_utils::max_pow2_divisor(wei_oc);
6333     int wei_ic_inner_blk = ir_utils::max_pow2_divisor(wei_ic);
6334     src_ic_inner_blk = std::min(src_ic_inner_blk, cfg_.ic_thr_blk);
6335     dst_oc_inner_blk = std::min(dst_oc_inner_blk, cfg_.oc_thr_blk);
6336     wei_oc_inner_blk = std::min(wei_oc_inner_blk, cfg_.oc_thr_blk);
6337     wei_ic_inner_blk = std::min(wei_ic_inner_blk, cfg_.ic_thr_blk);
6338 
6339     bool check_src_ic = (src_ic % cfg_.ic_tg_blk != 0);
6340     bool check_dst_oc = (dst_oc % cfg_.oc_tg_blk != 0);
6341     bool check_wei_oc = (wei_oc % cfg_.oc_tg_blk != 0);
6342     bool check_wei_ic = (wei_ic % cfg_.ic_tg_blk != 0);
6343 
6344     auto &x = view_t::placeholder_var();
6345     if (check_id_min) id_mask &= (x >= 0);
6346     if (check_ih_min) ih_mask &= (x >= 0);
6347     if (check_iw_min) iw_mask &= (x >= 0);
6348     if (check_id_max) id_mask &= (x < cfg_.id);
6349     if (check_ih_max) ih_mask &= (x < cfg_.ih);
6350     if (check_iw_max) iw_mask &= (x < cfg_.iw);
6351     if (check_od) od_mask = (x < cfg_.od);
6352     if (check_oh) oh_mask = (x < cfg_.oh);
6353     if (check_ow) ow_mask = (x < cfg_.ow);
6354     if (check_src_ic)
6355         src_ic_mask = (x / src_ic_inner_blk < src_ic / src_ic_inner_blk);
6356     if (check_dst_oc)
6357         dst_oc_mask = (x / dst_oc_inner_blk < dst_oc / dst_oc_inner_blk);
6358     if (check_wei_oc)
6359         wei_oc_mask = (x / wei_oc_inner_blk < wei_oc / wei_oc_inner_blk);
6360     if (check_wei_ic)
6361         wei_ic_mask = (x / wei_ic_inner_blk < wei_ic / wei_ic_inner_blk);
6362     if (check_kw) kw_mask = (x < cfg_.kw);
6363 
6364     // Source.
6365     src_view = view_t({mb, ic, od, oh, ow, kw}, 5);
6366     src_view.set_vdim(mb, cfg_.mb);
6367     src_view.set_vdim(ic, cfg_.ic);
6368     src_view.set_vdim(od, cfg_.od);
6369     src_view.set_vdim(oh, cfg_.oh);
6370     src_view.set_vdim(ow, cfg_.ow);
6371     src_view.set_vdim(kw, cfg_.kw);
6372     src_view.set_tdim(0, mb, src_mb_mask);
6373     src_view.set_tdim(1, ic, src_ic_mask);
6374     src_view.set_tdim(2, od * cfg_.sd - cfg_.pd + kd * (1 + cfg_.dd), id_mask);
6375     src_view.set_tdim(3, oh * cfg_.sh - cfg_.ph + kh * (1 + cfg_.dh), ih_mask);
6376     src_view.set_tdim(4, ow * cfg_.sw - cfg_.pw + kw * (1 + cfg_.dw), iw_mask);
6377     src_view.set_tlayout(src_layout);
6378 
6379     // Weights.
6380     wei_view = view_t({oc, ic, kd, kh, kw}, 5);
6381     wei_view.set_vdim(oc, cfg_.oc);
6382     wei_view.set_vdim(ic, cfg_.ic);
6383     wei_view.set_vdim(kd, cfg_.kd);
6384     wei_view.set_vdim(kh, cfg_.kh);
6385     wei_view.set_vdim(kw, cfg_.kw);
6386     wei_view.set_tdim(0, oc, wei_oc_mask);
6387     wei_view.set_tdim(1, ic, wei_ic_mask);
6388     wei_view.set_tdim(2, kd);
6389     wei_view.set_tdim(3, kh);
6390     wei_view.set_tdim(4, kw, kw_mask);
6391     wei_view.set_tlayout(wei_layout);
6392 
6393     // Destination.
6394     dst_view = view_t({mb, oc, od, oh, ow}, 5);
6395     dst_view.set_vdim(mb, cfg_.mb);
6396     dst_view.set_vdim(oc, cfg_.oc);
6397     dst_view.set_vdim(od, cfg_.od);
6398     dst_view.set_vdim(oh, cfg_.oh);
6399     dst_view.set_vdim(ow, cfg_.ow);
6400     dst_view.set_tdim(0, mb, dst_mb_mask);
6401     dst_view.set_tdim(1, oc, dst_oc_mask);
6402     dst_view.set_tdim(2, od, od_mask);
6403     dst_view.set_tdim(3, oh, oh_mask);
6404     dst_view.set_tdim(4, ow, ow_mask);
6405     dst_view.set_tlayout(dst_layout);
6406 
6407     // Bias.
6408     if (cfg_.with_bias) {
6409         expr_t bia_oc_mask;
6410         if (cfg_.oc % cfg_.oc_tg_blk != 0) bia_oc_mask = (x < cfg_.oc);
6411         bia_view = view_t({oc}, 1);
6412         bia_view.set_vdim(oc, cfg_.oc, 0);
6413         bia_view.set_tdim(0, oc, bia_oc_mask);
6414         bia_view.set_tlayout(cfg_.bia_layout);
6415     }
6416 
6417     // Initialize GEMM schedule.
6418     gemm_schedule.set_a_view(src_view);
6419     gemm_schedule.set_b_view(dst_view);
6420     gemm_schedule.set_c_view(wei_view);
6421     gemm_schedule.set_m_vars({ic, kw});
6422     gemm_schedule.set_n_vars({oc});
6423     gemm_schedule.set_k_vars({mb, od, oh, ow});
6424 
6425     expr_t mb_tg_blk_idx, mb_thr_blk_idx, mb_inner;
6426     expr_t oc_tg_blk_idx, oc_thr_blk_idx, oc_inner;
6427     expr_t ic_tg_blk_idx, ic_thr_blk_idx, ic_inner;
6428     expr_t od_tg_blk_idx, od_inner;
6429     expr_t oh_tg_blk_idx, oh_inner;
6430     expr_t ow_tg_blk_idx, ow_thr_blk_idx, ow_inner;
6431     expr_t kw_tg_blk_idx, kw_inner;
6432 
6433     gemm_schedule.split(mb, cfg_.mb_tg_blk, cfg_.mb_blk, mb_tg_blk_idx,
6434             mb_thr_blk_idx, mb_inner);
6435     gemm_schedule.split(ic, cfg_.ic_tg_blk, cfg_.ic_thr_blk, ic_tg_blk_idx,
6436             ic_thr_blk_idx, ic_inner);
6437     gemm_schedule.split(oc, cfg_.oc_tg_blk, cfg_.oc_thr_blk, oc_tg_blk_idx,
6438             oc_thr_blk_idx, oc_inner);
6439     gemm_schedule.split(od, cfg_.od_tg_blk, od_tg_blk_idx, od_inner);
6440     gemm_schedule.split(oh, cfg_.oh_tg_blk, oh_tg_blk_idx, oh_inner);
6441     gemm_schedule.split(ow, cfg_.ow_tg_blk, cfg_.ow_thr_blk, ow_tg_blk_idx,
6442             ow_thr_blk_idx, ow_inner);
6443     gemm_schedule.split(kw, cfg_.kw_tg_blk, kw_tg_blk_idx, kw_inner);
6444 
6445     auto odhw_tg_blk_kdhw_ic_tg_blk_idx
6446             = gemm_schedule.fuse({od_tg_blk_idx, oh_tg_blk_idx, ow_tg_blk_idx,
6447                     kd, kh, kw_tg_blk_idx, ic_tg_blk_idx});
6448 
6449     gemm_schedule.bind(oc_tg_blk_idx, kernel_grid_.idx(0));
6450     gemm_schedule.bind(odhw_tg_blk_kdhw_ic_tg_blk_idx, kernel_grid_.idx(1));
6451     gemm_schedule.bind(mb_tg_blk_idx, kernel_grid_.idx(2));
6452 
6453     gemm_schedule.bind(oc_thr_blk_idx, tg_grid_.idx(0));
6454     gemm_schedule.bind(ic_thr_blk_idx, tg_grid_.idx(1));
6455 
6456     gemm_schedule.reorder({od_inner, oh_inner, ow_inner, mb_thr_blk_idx});
6457 
6458     gemm_schedule.unroll(mb_thr_blk_idx, cfg_.mb_unroll);
6459     gemm_schedule.unroll(ow_thr_blk_idx, cfg_.ow_unroll);
6460     gemm_schedule.tensorize(oc_inner);
6461     gemm_schedule.tensorize(ic_inner);
6462     gemm_schedule.tensorize(mb_inner);
6463     gemm_schedule.tensorize(ow_inner);
6464     gemm_schedule.tensorize(kw_inner);
6465 
6466     src_buf = kernel_arg_info_.find_arg("src");
6467     wei_buf = kernel_arg_info_.find_arg("wei");
6468     dst_buf = kernel_arg_info_.find_arg("dst");
6469 
6470     if (cfg_.with_bias) {
6471         bia_buf = kernel_arg_info_.find_arg("bia");
6472         bia_reduction_condition = expr_t(true);
6473         if (cfg_.kd > 1) bia_reduction_condition &= (kd == 0);
6474         if (cfg_.kh > 1) bia_reduction_condition &= (kh == 0);
6475         if (cfg_.kw > 1) bia_reduction_condition &= (kw_tg_blk_idx == 0);
6476         if (cfg_.ic_tg_dim > 1) bia_reduction_condition &= (ic_tg_blk_idx == 0);
6477         if (!cfg_.use_b_slm && tg_grid_.dim(1) > 1) {
6478             bia_reduction_condition &= (tg_grid_.idx(1) == 0);
6479         }
6480     }
6481 }
6482 
6483 } // namespace jit
6484 } // namespace gpu
6485 } // namespace impl
6486 } // namespace dnnl
6487