1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_JIT_CONV_GEMM_SCHEDULE_HPP
18 #define GPU_JIT_CONV_GEMM_SCHEDULE_HPP
19 
20 #include <functional>
21 #include <limits>
22 #include <sstream>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 #include <initializer_list>
27 
28 #include "gpu/jit/conv/ir.hpp"
29 #include "gpu/jit/conv/tensor.hpp"
30 #include "gpu/jit/conv/utils.hpp"
31 
32 namespace dnnl {
33 namespace impl {
34 namespace gpu {
35 namespace jit {
36 
// Semantic role of a dimension when the computation is viewed as a GEMM,
// C = A x B:
// - b: batch dimension, present in all of A, B and C
// - m: present only in A and C
// - n: present only in B and C
// - k: reduction dimension, present only in A and B
enum class bmnk_kind_t {
    undef = -1,
    b = 0,
    m = 1,
    n = 2,
    k = 3,
};

// Identifies a GEMM operand (A or B) or the result (C).
enum class abc_kind_t {
    undef = 0,
    a = 1,
    b = 2,
    c = 3,
};
47 
48 class bmnk_mapper_t {
49 public:
50     bmnk_mapper_t() = default;
51 
bmnk_mapper_t(const object_map_t<expr_t,bmnk_kind_t> & bmnk_kinds)52     bmnk_mapper_t(const object_map_t<expr_t, bmnk_kind_t> &bmnk_kinds)
53         : bmnk_kinds_(bmnk_kinds) {}
54 
bmnk_kind(const expr_t & var) const55     bmnk_kind_t bmnk_kind(const expr_t &var) const {
56         auto it = bmnk_kinds_.find(var);
57         if (it == bmnk_kinds_.end()) return bmnk_kind_t::undef;
58         return it->second;
59     }
60 
bmnk_kind(abc_kind_t abc_kind,int dim_idx) const61     bmnk_kind_t bmnk_kind(abc_kind_t abc_kind, int dim_idx) const {
62         return bmnk_kind(var(abc_kind, dim_idx));
63     }
64 
ndims(abc_kind_t abc_kind) const65     int ndims(abc_kind_t abc_kind) const {
66         return int(get_vars(abc_kind).size());
67     }
68 
set_a_vars(const std::vector<expr_t> & vars)69     void set_a_vars(const std::vector<expr_t> &vars) { a_vars_ = vars; }
set_b_vars(const std::vector<expr_t> & vars)70     void set_b_vars(const std::vector<expr_t> &vars) { b_vars_ = vars; }
set_c_vars(const std::vector<expr_t> & vars)71     void set_c_vars(const std::vector<expr_t> &vars) { c_vars_ = vars; }
72 
set_bmnk_kind(const expr_t & var,bmnk_kind_t bmnk_kind)73     void set_bmnk_kind(const expr_t &var, bmnk_kind_t bmnk_kind) {
74         auto ret = bmnk_kinds_.insert({var, bmnk_kind});
75         ir_assert(ret.second) << "Can't set variable twice: " << var;
76     }
77 
var(abc_kind_t abc_kind,int dim_idx) const78     const expr_t &var(abc_kind_t abc_kind, int dim_idx) const {
79         return get_vars(abc_kind)[dim_idx];
80     }
81 
dim_idx(abc_kind_t abc_kind,const expr_t & var) const82     int dim_idx(abc_kind_t abc_kind, const expr_t &var) const {
83         auto &vars = get_vars(abc_kind);
84         for (int i = 0; i < int(vars.size()); i++) {
85             if (vars[i].is_same(var)) return i;
86         }
87         return -1;
88     }
89 
90     layout_t map_to_bmnk(abc_kind_t abc_kind,
91             const std::vector<bmnk_kind_t> &bmnk_kinds,
92             const view_t &view) const;
93 
94     layout_t map_to_bmnk(abc_kind_t abc_kind,
95             const std::vector<bmnk_kind_t> &bmnk_kinds,
96             const layout_t &layout) const;
97 
98 private:
get_vars(abc_kind_t abc_kind) const99     const std::vector<expr_t> &get_vars(abc_kind_t abc_kind) const {
100         switch (abc_kind) {
101             case abc_kind_t::a: return a_vars_;
102             case abc_kind_t::b: return b_vars_;
103             case abc_kind_t::c: return c_vars_;
104             default: ir_error_not_expected() << "Unknown ABC kind.";
105         }
106         return a_vars_;
107     }
108 
get_vars(abc_kind_t abc_kind)109     std::vector<expr_t> &get_vars(abc_kind_t abc_kind) {
110         auto &vars
111                 = const_cast<const bmnk_mapper_t *>(this)->get_vars(abc_kind);
112         return const_cast<std::vector<expr_t> &>(vars);
113     }
114 
115     std::vector<expr_t> a_vars_;
116     std::vector<expr_t> b_vars_;
117     std::vector<expr_t> c_vars_;
118     object_map_t<expr_t, bmnk_kind_t> bmnk_kinds_;
119 };
120 
121 class bmnk_block_mapper_t {
122 public:
bmnk_block_mapper_t(const bmnk_mapper_t & bmnk_mapper)123     bmnk_block_mapper_t(const bmnk_mapper_t &bmnk_mapper)
124         : bmnk_mapper_(bmnk_mapper) {}
125 
push_blocks(abc_kind_t abc_kind,const std::vector<block_t> & blocks)126     void push_blocks(abc_kind_t abc_kind, const std::vector<block_t> &blocks) {
127         for (auto &b : blocks)
128             push_block(abc_kind, b);
129     }
130 
131     void push_block(abc_kind_t abc_kind, const block_t &b);
132 
133     layout_t map_from_bmnk(abc_kind_t abc_kind,
134             const std::vector<bmnk_kind_t> &bmnk_kinds,
135             const layout_t &bmnk_layout) const;
136 
137 private:
pop_size_1_blocks(std::vector<block_t> & blocks)138     static void pop_size_1_blocks(std::vector<block_t> &blocks) {
139         while (!blocks.empty() && blocks.front().block == 1) {
140             blocks.erase(blocks.begin());
141         }
142     }
143 
create_prb_blocks(abc_kind_t abc_kind,const std::vector<std::pair<abc_kind_t,block_t>> & mn_blocks) const144     std::vector<block_t> create_prb_blocks(abc_kind_t abc_kind,
145             const std::vector<std::pair<abc_kind_t, block_t>> &mn_blocks)
146             const {
147         std::vector<block_t> ret;
148         ret.reserve(mn_blocks.size());
149         for (auto &p : mn_blocks) {
150             auto b = p.second;
151             const auto &var = bmnk_mapper_.var(p.first, b.dim_idx);
152             b.dim_idx = bmnk_mapper_.dim_idx(abc_kind, var);
153             ret.push_back(b);
154         }
155         return ret;
156     }
157 
158     bool pop_block(std::vector<block_t> &bmnk_blocks,
159             std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const;
160 
161     bmnk_mapper_t bmnk_mapper_;
162 
163     // Ordered from innermost to outermost.
164     std::vector<std::pair<abc_kind_t, block_t>> m_blocks_;
165     std::vector<std::pair<abc_kind_t, block_t>> n_blocks_;
166     std::vector<std::pair<abc_kind_t, block_t>> k_blocks_;
167 };
168 
// How a loop of the schedule is executed.
enum class loop_kind_t : int {
    undef = 0,
    // Bound to the kernel grid.
    kernel_grid = 1,
    // Runs inside a thread (either unrolled or as a regular loop).
    serial = 2,
    // Bound to the thread group grid.
    tg_grid = 3,
    // Fully unrolled/vectorized and converted to a blocked multiplication.
    tensorized = 4,
};
176 
to_string(loop_kind_t kind)177 static std::string to_string(loop_kind_t kind) {
178     switch (kind) {
179         case loop_kind_t::undef: return "undef";
180         case loop_kind_t::kernel_grid: return "kernel_grid";
181         case loop_kind_t::serial: return "serial";
182         case loop_kind_t::tg_grid: return "tg_grid";
183         case loop_kind_t::tensorized: return "tensorized";
184         default: ir_error_not_expected();
185     }
186     return "unknown";
187 }
188 
operator <<(std::ostream & out,loop_kind_t kind)189 inline std::ostream &operator<<(std::ostream &out, loop_kind_t kind) {
190     out << to_string(kind);
191     return out;
192 }
193 
// Tile granularity: a whole thread group or a single thread.
enum class tile_level_t {
    thread_group = 0,
    thread = 1,
};
195 
196 class loop_t {
197 public:
loop_t()198     loop_t() : kind_(loop_kind_t::undef) {}
199 
loop_t(const expr_t & var,const expr_t & bound,bool is_root)200     loop_t(const expr_t &var, const expr_t &bound, bool is_root)
201         : var_(var)
202         , kind_(loop_kind_t::serial)
203         , bound_(bound)
204         , is_root_(is_root) {}
205 
var() const206     const expr_t &var() const { return var_; }
207 
kind() const208     loop_kind_t kind() const { return kind_; }
209 
set_kind(loop_kind_t kind)210     void set_kind(loop_kind_t kind) { kind_ = kind; }
211 
unroll_factor() const212     int unroll_factor() const { return unroll_factor_; }
213 
set_unroll_factor(int factor)214     void set_unroll_factor(int factor) { unroll_factor_ = factor; }
215 
is_kernel_grid() const216     bool is_kernel_grid() const { return kind() == loop_kind_t::kernel_grid; }
217 
is_serial() const218     bool is_serial() const { return kind() == loop_kind_t::serial; }
219 
is_tg_grid() const220     bool is_tg_grid() const { return kind() == loop_kind_t::tg_grid; }
221 
is_tensorized() const222     bool is_tensorized() const { return kind() == loop_kind_t::tensorized; }
223 
bound() const224     const expr_t &bound() const { return bound_; }
225 
set_bound(const expr_t & bound)226     void set_bound(const expr_t &bound) { bound_ = bound; }
227 
is_bound() const228     bool is_bound() const { return !bound_var().is_empty(); }
229 
bound_var() const230     const expr_t &bound_var() const { return bound_var_; }
231 
set_bound_var(const expr_t & v)232     void set_bound_var(const expr_t &v) { bound_var_ = v; }
233 
is_root() const234     bool is_root() const { return is_root_; }
235 
236     // Returns true for loops that were neither split, nor fused with other loops.
is_leaf() const237     bool is_leaf() const { return is_leaf_; }
238 
239     // Returns true if this loop was split into outer/inner loops.
is_split_parent() const240     bool is_split_parent() const { return is_split_parent_; }
241 
242     // Returns true if this loop was the result of a split.
is_split_child() const243     bool is_split_child() const { return is_split_child_; }
244 
245     // Returns true if this loop was fused with other loops.
is_fused_parent() const246     bool is_fused_parent() const { return is_fused_parent_; }
247 
248     // Returns true if this loop was the result of a fusion.
is_fused_child() const249     bool is_fused_child() const { return is_fused_child_; }
250 
parent_vars() const251     const std::vector<expr_t> &parent_vars() const { return parent_vars_; }
child_vars() const252     const std::vector<expr_t> &child_vars() const { return child_vars_; }
253 
set_split(loop_t & outer_loop,loop_t & inner_loop)254     void set_split(loop_t &outer_loop, loop_t &inner_loop) {
255         outer_loop.parent_vars_.push_back(var());
256         child_vars_.push_back(outer_loop.var());
257         outer_loop.is_split_child_ = true;
258 
259         inner_loop.parent_vars_.push_back(var());
260         child_vars_.push_back(inner_loop.var());
261         inner_loop.is_split_child_ = true;
262 
263         is_split_parent_ = true;
264         is_leaf_ = false;
265     }
266 
set_fuse(std::vector<std::reference_wrapper<loop_t>> & loops)267     void set_fuse(std::vector<std::reference_wrapper<loop_t>> &loops) {
268         for (auto &l_ref : loops) {
269             auto &l = l_ref.get();
270             parent_vars_.push_back(l.var());
271             l.child_vars_.push_back(var());
272             l.is_fused_parent_ = true;
273             l.is_leaf_ = false;
274         }
275         is_fused_child_ = true;
276     }
277 
278     // Returns a loop variable expressed in the variables of the leaf loops.
expand_var(const object_map_t<expr_t,loop_t> & all_loops,bool skip_fused=false) const279     expr_t expand_var(const object_map_t<expr_t, loop_t> &all_loops,
280             bool skip_fused = false) const {
281         if (is_leaf()) return var();
282         if (is_split_parent()) {
283             ir_assert(child_vars_.size() == 2);
284             auto &outer_loop = all_loops.at(child_vars_[0]);
285             auto &inner_loop = all_loops.at(child_vars_[1]);
286             auto outer_var = outer_loop.expand_var(all_loops, skip_fused);
287             auto inner_var = inner_loop.expand_var(all_loops, skip_fused);
288             return outer_var * inner_loop.bound() + inner_var;
289         }
290         if (is_fused_parent()) {
291             if (skip_fused) return var();
292             // Example of "unpacking":
293             //     fused_var = (a * b * c * d)
294             //     b = (fused_var / (D * C)) % B
295             ir_assert(child_vars_.size() == 1);
296             auto &fused_loop = all_loops.at(child_vars_[0]);
297             int nvars = int(fused_loop.parent_vars_.size());
298             expr_t denom = 1;
299             for (int i = nvars - 1; i >= 0; i--) {
300                 auto &v = fused_loop.parent_vars_[i];
301                 auto &child_loop = all_loops.at(v);
302                 auto &bound = child_loop.bound();
303                 if (v.is_same(var())) {
304                     auto e = fused_loop.expand_var(all_loops, skip_fused)
305                             / denom;
306                     return (i == 0 ? e : e % bound);
307                 }
308                 denom *= bound;
309             }
310         }
311 
312         ir_error_not_expected();
313         return expr_t();
314     }
315 
str() const316     std::string str() const {
317         using namespace ir_utils;
318 
319         std::ostringstream oss;
320         oss << "var: " << var_;
321         oss << " bound: " << bound_;
322         oss << " kind: " << kind_;
323         if (unroll_factor_ != 1) oss << " unroll: " << unroll_factor_;
324         std::vector<std::string> props;
325         if (is_root()) props.push_back("root");
326         if (is_fused_child()) props.push_back("fused");
327         if (is_split_parent()) props.push_back("split");
328         oss << "(" << make_seq_print_helper(props, ", ") << ")";
329         return oss.str();
330     }
331 
332     IR_DEFINE_DUMP()
333 
334 private:
335     expr_t var_; // Loop index variable.
336     loop_kind_t kind_; // Loop kind.
337     expr_t bound_; // Loop bound (exclusive).
338 
339     expr_t bound_var_; // External variable this loop bound to.
340 
341     int unroll_factor_ = 1;
342 
343     bool is_root_ = false;
344     bool is_leaf_ = true;
345 
346     bool is_split_parent_ = false;
347     bool is_split_child_ = false;
348 
349     bool is_fused_parent_ = false;
350     bool is_fused_child_ = false;
351 
352     // For variables there were split or fused.
353     // Fusion: i x j -> k
354     //     i.child_vars _= [k]
355     //     j.child_vars _= [k]
356     //     k.parent_vars_ = [i, j]
357     // Split: i -> j x k
358     //     i.child_vars_ = [j, k]
359     //     j.parent_vars_ = [i]
360     //     k.parent_vars_ = [i]
361     std::vector<expr_t> parent_vars_;
362     std::vector<expr_t> child_vars_;
363 };
364 
365 // Defines GEMM computation including:
366 // - Blocking scheme (order of loops, tiles per thread group/thread)
367 // - Mapping of problem dimensions to GEMM dimensions (BMNK)
368 class gemm_schedule_t {
369 public:
370     gemm_schedule_t() = default;
371 
gemm_schedule_t(constraint_set_t & cset,const grid_info_t & kernel_grid,const grid_info_t & tg_grid)372     gemm_schedule_t(constraint_set_t &cset, const grid_info_t &kernel_grid,
373             const grid_info_t &tg_grid)
374         : cset_(&cset), kernel_grid_(kernel_grid), tg_grid_(tg_grid) {}
375 
kernel_grid() const376     const grid_info_t &kernel_grid() const { return kernel_grid_; }
tg_grid() const377     const grid_info_t &tg_grid() const { return tg_grid_; }
378 
bmnk_kind(const expr_t & var) const379     bmnk_kind_t bmnk_kind(const expr_t &var) const {
380         return bmnk_kind(std::vector<expr_t>({var}));
381     }
382 
bmnk_mapper() const383     const bmnk_mapper_t &bmnk_mapper() const { return bmnk_mapper_; }
384 
set_b_vars(const std::vector<expr_t> & vars)385     void set_b_vars(const std::vector<expr_t> &vars) {
386         for (auto &v : vars)
387             set_bmnk_kind(v, bmnk_kind_t::b);
388     }
389 
set_m_vars(const std::vector<expr_t> & vars)390     void set_m_vars(const std::vector<expr_t> &vars) {
391         for (auto &v : vars)
392             set_bmnk_kind(v, bmnk_kind_t::m);
393     }
394 
set_n_vars(const std::vector<expr_t> & vars)395     void set_n_vars(const std::vector<expr_t> &vars) {
396         for (auto &v : vars)
397             set_bmnk_kind(v, bmnk_kind_t::n);
398     }
399 
set_k_vars(const std::vector<expr_t> & vars)400     void set_k_vars(const std::vector<expr_t> &vars) {
401         for (auto &v : vars)
402             set_bmnk_kind(v, bmnk_kind_t::k);
403     }
404 
405     // A/B/C views in the problem notation.
a_view() const406     const view_t &a_view() const { return a_view_; }
b_view() const407     const view_t &b_view() const { return b_view_; }
c_view() const408     const view_t &c_view() const { return c_view_; }
409 
set_a_view(const view_t & v)410     void set_a_view(const view_t &v) {
411         set_view(v, a_view_);
412         bmnk_mapper_.set_a_vars(a_view_.vvars());
413     }
414 
set_b_view(const view_t & v)415     void set_b_view(const view_t &v) {
416         set_view(v, b_view_);
417         bmnk_mapper_.set_b_vars(b_view_.vvars());
418     }
419 
set_c_view(const view_t & v)420     void set_c_view(const view_t &v) {
421         set_view(v, c_view_);
422         bmnk_mapper_.set_c_vars(c_view_.vvars());
423     }
424 
a_tg_view() const425     view_t a_tg_view() const {
426         ir_assert(is_finalized_);
427         return a_view_.create_sub_view(a_tg_tile_);
428     }
429 
b_tg_view() const430     view_t b_tg_view() const {
431         ir_assert(is_finalized_);
432         return b_view_.create_sub_view(b_tg_tile_);
433     }
434 
c_tg_view() const435     view_t c_tg_view() const {
436         ir_assert(is_finalized_);
437         return c_view_.create_sub_view(c_tg_tile_);
438     }
439 
440     // Thread group tiles for A, B, C.
a_tg_tile() const441     const tensor_t &a_tg_tile() const { return a_tg_tile_; }
b_tg_tile() const442     const tensor_t &b_tg_tile() const { return b_tg_tile_; }
c_tg_tile() const443     const tensor_t &c_tg_tile() const { return c_tg_tile_; }
444 
445     // Thread tiles for A, B, C.
a_thr_tile(bool is_relative=true) const446     tensor_t a_thr_tile(bool is_relative = true) const {
447         if (is_relative) return a_thr_tile_;
448         return a_tg_tile_.create_sub_tensor(a_thr_tile_);
449     }
450 
b_thr_tile(bool is_relative=true) const451     tensor_t b_thr_tile(bool is_relative = true) const {
452         if (is_relative) return b_thr_tile_;
453         return b_tg_tile_.create_sub_tensor(b_thr_tile_);
454     }
455 
c_thr_tile(bool is_relative=true) const456     tensor_t c_thr_tile(bool is_relative = true) const {
457         if (is_relative) return c_thr_tile_;
458         return c_tg_tile_.create_sub_tensor(c_thr_tile_);
459     }
460 
461     // Splits loop defined by `var` into two new loops based on `factor`.
462     // Before:
463     //     for (int var = 0; var < I; var++) { ... }
464     // After:
465     //   for (int outer_var = 0; outer_var < I / factor; outer_var++) {
466     //     for (int inner_var = 0; inner_var < factor; inner_var++) {
467     //       ...
468     //     }
469     //   }
split(const expr_t & var,int factor,expr_t & outer_var,expr_t & inner_var)470     void split(const expr_t &var, int factor, expr_t &outer_var,
471             expr_t &inner_var) {
472         auto &loop = find_loop(var);
473         ir_assert(loop.is_leaf()) << "Can't split, non-leaf loop.";
474 
475         int bound = to_cpp<int>(loop.bound());
476         if (loop.is_root() && (bound % factor != 0)) {
477             // Auto round-up bounds for the root loops.
478             bound = utils::rnd_up(bound, factor);
479             loop.set_bound(bound);
480         }
481 
482         ir_assert(bound % factor == 0) << "Can't split.";
483 
484         outer_var = create_var({var}, "outer");
485         inner_var = create_var({var}, "inner");
486         auto &outer_loop = create_loop(outer_var, bound / factor);
487         auto &inner_loop = create_loop(inner_var, factor);
488         loop.set_split(outer_loop, inner_loop);
489         set_bmnk_kind(outer_var, bmnk_kind(var));
490         set_bmnk_kind(inner_var, bmnk_kind(var));
491     }
492 
493     // Double split.
split(const expr_t & var,int factor0,int factor1,expr_t & outer_var0,expr_t & outer_var1,expr_t & inner_var)494     void split(const expr_t &var, int factor0, int factor1, expr_t &outer_var0,
495             expr_t &outer_var1, expr_t &inner_var) {
496         expr_t dummy_inner_var;
497         split(var, factor0, outer_var0, dummy_inner_var);
498         split(dummy_inner_var, factor1, outer_var1, inner_var);
499     }
500 
501     // Fuses loops defined by `v0` and `v1` variables, v0 - outer variable, v1
502     // - inner variable.
503     // Before:
504     //   for (int v0 = 0; v0 < V0; v0++) {
505     //     for (int v1 = 0; v1 < V1; v1++) { ... }
506     //   }
507     // After:
508     //   for (int v = 0; v < V0 * V1; v++) {
509     //       int v0 = v / V1;
510     //       int v1 = v % V1;
511     //       ...
512     //   }
fuse(const expr_t & v0,const expr_t & v1)513     expr_t fuse(const expr_t &v0, const expr_t &v1) { return fuse({v0, v1}); }
514 
515     // Double fuse, v0 - outermost variable, v2 - innermost variable.
fuse(const expr_t & v0,const expr_t & v1,const expr_t & v2)516     expr_t fuse(const expr_t &v0, const expr_t &v1, const expr_t &v2) {
517         return fuse({v0, v1, v2});
518     }
519 
520     // Fusion of multiple loops.
fuse(const std::vector<expr_t> & vars)521     expr_t fuse(const std::vector<expr_t> &vars) {
522         auto fused_var = create_var(vars, "fused");
523         expr_t fused_bound = find_loop(vars[0]).bound();
524         for (int i = 1; i < int(vars.size()); i++) {
525             auto &loop = find_loop(vars[i]);
526             fused_bound *= loop.bound();
527         }
528         auto &fused_loop = create_loop(fused_var, fused_bound);
529         std::vector<std::reference_wrapper<loop_t>> loop_refs;
530         for (auto &v : vars) {
531             loop_refs.push_back(find_loop(v));
532         }
533         fused_loop.set_fuse(loop_refs);
534         set_bmnk_kind(fused_var, bmnk_kind(vars));
535         return fused_var;
536     }
537 
538     // Sets unrolling factor for the given loop.
unroll(const expr_t & v,int factor)539     void unroll(const expr_t &v, int factor) {
540         auto &loop = find_loop(v);
541         loop.set_unroll_factor(factor);
542     }
543 
544     // Marks the loop defined by `v` as tensorized.
tensorize(const expr_t & v)545     void tensorize(const expr_t &v) {
546         auto &loop = find_loop(v);
547         loop.set_kind(loop_kind_t::tensorized);
548     }
549 
550     // Binds the loop defined by `v` to an external variable.
bind(const expr_t & v,const expr_t & bound_var)551     void bind(const expr_t &v, const expr_t &bound_var) {
552         auto &loop = find_loop(v);
553         ir_assert(loop.is_leaf()) << "Can't bind non-leaf loop: " << v;
554         loop.set_bound_var(bound_var);
555         loop.set_kind(bound_var_to_loop_kind(bound_var));
556 
557         int var_dim = bound_var_to_dim(bound_var);
558         ir_assert(to_cpp<int>(loop.bound()) == var_dim)
559                 << "Dimension size doesn't match.";
560     }
561 
562     // Reorders loops defiend by given variables.
reorder(const std::vector<expr_t> & ordered_vars)563     void reorder(const std::vector<expr_t> &ordered_vars) {
564         for (auto &v : ordered_vars) {
565             auto &loop = find_loop(v);
566             ir_assert(loop.is_leaf()) << "Can't reorder non-leaf loop: " << v;
567         }
568         std::vector<bool> found(vars_.size());
569         for (size_t i = 0; i < vars_.size(); i++) {
570             for (size_t j = 0; j < ordered_vars.size(); j++) {
571                 if (ordered_vars[j].is_same(vars_[i])) {
572                     found[i] = true;
573                     break;
574                 }
575             }
576         }
577 
578         for (size_t i = 0, j = 0; i < vars_.size(); i++) {
579             if (!found[i]) continue;
580             vars_[i] = ordered_vars[j++];
581         }
582     }
583 
with_thread_group_k_slicing() const584     bool with_thread_group_k_slicing() const {
585         ir_assert(is_finalized_);
586         dim_t k_thr = 1;
587         dim_t k_tg = 1;
588         for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
589             if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
590                 continue;
591             k_thr *= a_thr_tile_(i);
592             k_tg *= a_tg_tile_(i);
593         }
594         ir_assert(k_tg % k_thr == 0);
595         return k_thr < k_tg;
596     }
597 
finalize()598     void finalize() {
599         sort_vars();
600         init_problem_tiles();
601         init_constraint_set();
602         is_finalized_ = true;
603     }
604 
605     // Returns a statement describing the loop nest of the schedule.
create_loop_nest(const stmt_t & _body=stmt_t ()) const606     stmt_t create_loop_nest(const stmt_t &_body = stmt_t()) const {
607         stmt_t body = _body;
608         for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
609             auto &var = *it;
610             auto &loop = find_loop(var);
611             if (!loop.is_leaf() || loop.is_tensorized() || loop.is_bound())
612                 continue;
613             body = maybe_inject_let_for_fused_vars(body, loop);
614             body = for_t::make(
615                     var, 0, loop.bound(), body, loop.unroll_factor());
616         }
617         return body;
618     }
619 
create_bind_stmt(const stmt_t & _body=stmt_t ()) const620     stmt_t create_bind_stmt(const stmt_t &_body = stmt_t()) const {
621         stmt_t body = _body;
622         for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
623             auto &var = *it;
624             auto &loop = find_loop(var);
625             if (!loop.is_leaf() || !loop.is_bound()) continue;
626             body = maybe_inject_let_for_fused_vars(body, loop);
627             body = let_t::make(var, loop.bound_var(), body);
628         }
629         return body;
630     }
631 
632 private:
633     // Describes split of a root loop into sub-loops.
634     class split_info_t {
635     public:
split_info_t(const loop_t * root_loop)636         split_info_t(const loop_t *root_loop) : root_loop_(root_loop) {}
637 
nloops() const638         int nloops() const { return int(loops_.size()); }
639 
add_sub_loop(const loop_t * loop,loop_kind_t loop_kind,int loop_level)640         void add_sub_loop(
641                 const loop_t *loop, loop_kind_t loop_kind, int loop_level) {
642             loops_.push_back(loop);
643             loop_kinds_.push_back(loop_kind);
644             loop_levels_.push_back(loop_level);
645         }
646 
647         // Verifies that sub-loops are ordered from outermost to innermost
648         // according to the schedule conventions. There are three set of loops:
649         // 1) Loops bound to kernel grid
650         // 2) Loops bound to thread group grid and serial loops
651         // 3) Tensorized loops
652         // Sets of loops must be ordered from outermost to innermost going from
653         // 1 to 3. Inside a set loops can be ordered arbitrarily.
is_valid() const654         bool is_valid() const {
655             auto get_loop_key = [&](int loop_idx) {
656                 switch (loop_kinds_[loop_idx]) {
657                     case loop_kind_t::kernel_grid: return -1;
658                     case loop_kind_t::tg_grid:
659                     case loop_kind_t::serial: return loop_levels_[loop_idx];
660                     case loop_kind_t::tensorized:
661                         return std::numeric_limits<int>::max();
662                     default: ir_error_not_expected();
663                 }
664                 return -1;
665             };
666             int prev_key = -1;
667             for (int i = 0; i < nloops(); i++) {
668                 int key = get_loop_key(i);
669                 if (key < prev_key) return false;
670                 prev_key = key;
671             }
672             return true;
673         }
674 
675         // Returns total extent of all loops at a given tile level.
dim(tile_level_t tile_level) const676         dim_t dim(tile_level_t tile_level) const {
677             dim_t ret = 1;
678             for (int i = 0; i < nloops(); i++) {
679                 switch (loop_kinds_[i]) {
680                     case loop_kind_t::kernel_grid:
681                     case loop_kind_t::serial: continue;
682                     case loop_kind_t::tg_grid:
683                         if (tile_level == tile_level_t::thread) continue;
684                         break;
685                     case loop_kind_t::tensorized: break;
686                     default: ir_error_not_expected();
687                 }
688                 ret *= to_cpp<dim_t>(loops_[i]->bound());
689             }
690             return ret;
691         }
692 
693         // Returns initial offset expressed in the outer variables at a given
694         // tile level.
start(const object_map_t<expr_t,loop_t> & all_loops,tile_level_t tile_level) const695         expr_t start(const object_map_t<expr_t, loop_t> &all_loops,
696                 tile_level_t tile_level) const {
697             auto ret = root_loop_->expand_var(all_loops, /*skip_fused=*/true);
698             for (int i = 0; i < nloops(); i++) {
699                 switch (loop_kinds_[i]) {
700                     case loop_kind_t::kernel_grid:
701                     case loop_kind_t::serial:
702                         if (tile_level == tile_level_t::thread) break;
703                         continue;
704                     case loop_kind_t::tg_grid:
705                         if (tile_level == tile_level_t::thread) continue;
706                         break;
707                     case loop_kind_t::tensorized: break;
708                     default: ir_error_not_expected();
709                 }
710                 ret = substitute(ret, loops_[i]->var(), expr_t(0));
711             }
712             return simplify(ret);
713         }
714 
715     private:
716         const loop_t *root_loop_;
717         std::vector<const loop_t *> loops_;
718         std::vector<loop_kind_t> loop_kinds_;
719         std::vector<int> loop_levels_;
720     };
721 
bmnk_kind(const std::vector<expr_t> & vars) const722     bmnk_kind_t bmnk_kind(const std::vector<expr_t> &vars) const {
723         if (vars.empty()) return bmnk_kind_t::undef;
724         if (vars.size() == 1) return bmnk_mapper_.bmnk_kind(vars[0]);
725         bmnk_kind_t ret = bmnk_kind(vars[0]);
726         for (size_t i = 1; i < vars.size(); i++) {
727             if (bmnk_kind(vars[i]) != ret) return bmnk_kind_t::undef;
728         }
729         return ret;
730     }
731 
set_bmnk_kind(const expr_t & var,bmnk_kind_t kind)732     void set_bmnk_kind(const expr_t &var, bmnk_kind_t kind) {
733         bmnk_mapper_.set_bmnk_kind(var, kind);
734     }
735 
set_view(const view_t & view,view_t & this_view)736     void set_view(const view_t &view, view_t &this_view) {
737         this_view = view;
738         // Create missing loops.
739         for (int i = 0; i < view.nvdims(); i++) {
740             auto &v = view.vvars()[i];
741             dim_t bound = view.vdims()[i];
742             if (has_loop(v)) {
743                 auto &loop = find_loop(v);
744                 ir_assert(bound == to_cpp<dim_t>(loop.bound()))
745                         << "Inconsistent sizes.";
746                 continue;
747             }
748             create_loop(v, bound, /*is_root=*/true);
749         }
750     }
751 
bound_var_to_loop_kind(const expr_t & v) const752     loop_kind_t bound_var_to_loop_kind(const expr_t &v) const {
753         for (int i = 0; i < kernel_grid_.ndims(); i++) {
754             if (kernel_grid_.idx(i).is_same(v)) return loop_kind_t::kernel_grid;
755         }
756         for (int i = 0; i < tg_grid_.ndims(); i++) {
757             if (tg_grid_.idx(i).is_same(v)) return loop_kind_t::tg_grid;
758         }
759         ir_error_not_expected() << "Unknown external variable: " << v;
760         return loop_kind_t::undef;
761     }
762 
bound_var_to_dim(const expr_t & v) const763     int bound_var_to_dim(const expr_t &v) const {
764         for (int i = 0; i < kernel_grid_.ndims(); i++) {
765             if (kernel_grid_.idx(i).is_same(v)) return kernel_grid_.dim(i);
766         }
767         for (int i = 0; i < tg_grid_.ndims(); i++) {
768             if (tg_grid_.idx(i).is_same(v)) return tg_grid_.dim(i);
769         }
770         ir_error_not_expected() << "Unknown external variable: " << v;
771         return -1;
772     }
773 
has_loop(const expr_t & var) const774     bool has_loop(const expr_t &var) const {
775         auto it = loops_.find(var);
776         return it != loops_.end();
777     }
778 
find_loop(const expr_t & var) const779     const loop_t &find_loop(const expr_t &var) const {
780         ir_assert(has_loop(var)) << "Var not found: " << var;
781         return loops_.at(var);
782     }
783 
find_loop(const expr_t & var)784     loop_t &find_loop(const expr_t &var) {
785         ir_assert(has_loop(var)) << "Var not found: " << var;
786         return loops_[var];
787     }
788 
loop_level(const expr_t & var) const789     int loop_level(const expr_t &var) const {
790         for (int i = 0; i < int(vars_.size()); i++) {
791             if (vars_[i].is_same(var)) return i;
792         }
793         return -1;
794     }
795 
create_loop(const expr_t & var,const expr_t & bound,bool is_root=false)796     loop_t &create_loop(
797             const expr_t &var, const expr_t &bound, bool is_root = false) {
798         loop_t loop(var, bound, is_root);
799         auto ret = loops_.insert({var, loop});
800         ir_assert(ret.second) << "Variable already exists: " << var;
801         vars_.push_back(var);
802         return ret.first->second;
803     }
804 
strip_suffix(const std::string & s,const std::string & suffix)805     static std::string strip_suffix(
806             const std::string &s, const std::string &suffix) {
807         auto pos = s.find(suffix);
808         if (pos == std::string::npos) return s;
809         if (pos + suffix.length() != s.length()) return s;
810         return s.substr(0, pos);
811     }
812 
create_var(const std::vector<expr_t> & vars,const std::string & suffix)813     static expr_t create_var(
814             const std::vector<expr_t> &vars, const std::string &suffix) {
815         std::string var_name;
816         for (auto &v : vars) {
817             auto name = strip_suffix(v.as<var_t>().name, "_idx");
818             var_name += name + "_";
819         }
820         var_name += suffix;
821         return var_t::make(type_t::s32(), var_name);
822     }
823 
get_var_key(const expr_t & v) const824     int get_var_key(const expr_t &v) const {
825         int key_max = std::numeric_limits<int>::max();
826         auto &loop = find_loop(v);
827         if (!loop.is_leaf()) return key_max;
828         // Loops bound to the kernel grid.
829         if (loop.is_kernel_grid()) {
830             return kernel_grid_.ndims()
831                     - kernel_grid_.dim_idx(loop.bound_var());
832         }
833         // Loops bound to the thread group grid or serial loop.
834         if (loop.is_tg_grid() || loop.is_serial()) return 10;
835 
836         // Tensorized loops are the innermost.
837         if (loop.is_tensorized()) return key_max - 1;
838         ir_error_not_expected() << "Unknown loop";
839         return -1;
840     }
841 
sort_vars()842     void sort_vars() {
843         std::stable_sort(vars_.end(), vars_.end(),
844                 [&](const expr_t &a_var, const expr_t &b_var) {
845                     int a_key = get_var_key(a_var);
846                     int b_key = get_var_key(b_var);
847                     return a_key < b_key;
848                 });
849     }
850 
init_problem_tiles()851     void init_problem_tiles() {
852         object_map_t<expr_t, split_info_t> split_infos;
853         for (auto *view : {&a_view_, &b_view_, &c_view_}) {
854             for (auto &v : view->vvars()) {
855                 if (split_infos.count(v) > 0) continue;
856                 split_infos.insert({v, get_split_info(v)});
857             }
858         }
859         a_tg_tile_ = compute_problem_tile(
860                 a_view_.vvars(), split_infos, tile_level_t::thread_group);
861         b_tg_tile_ = compute_problem_tile(
862                 b_view_.vvars(), split_infos, tile_level_t::thread_group);
863         c_tg_tile_ = compute_problem_tile(
864                 c_view_.vvars(), split_infos, tile_level_t::thread_group);
865         a_thr_tile_ = compute_problem_tile(
866                 a_view_.vvars(), split_infos, tile_level_t::thread);
867         b_thr_tile_ = compute_problem_tile(
868                 b_view_.vvars(), split_infos, tile_level_t::thread);
869         c_thr_tile_ = compute_problem_tile(
870                 c_view_.vvars(), split_infos, tile_level_t::thread);
871     }
872 
init_constraint_set()873     void init_constraint_set() {
874         for (auto &v : vars_) {
875             auto &loop = find_loop(v);
876             if (loop.is_fused_parent()) {
877                 cset_->add_constraint(v >= 0);
878                 cset_->add_constraint(v < loop.bound());
879                 continue;
880             }
881             if (!loop.is_leaf()) continue;
882 
883             // Fused variables are used only to initialize fused parents.
884             if (loop.is_fused_child()) continue;
885 
886             if (loop.is_bound()) {
887                 cset_->add_constraint(v == loop.bound_var());
888                 continue;
889             }
890 
891             cset_->add_constraint(v >= 0);
892             cset_->add_constraint(v < loop.bound());
893         }
894     }
895 
get_split_info(const expr_t & root_var) const896     split_info_t get_split_info(const expr_t &root_var) const {
897         split_info_t ret(&find_loop(root_var));
898         std::function<void(const expr_t &)> walk_down;
899         walk_down = [&](const expr_t &v) {
900             auto &loop = find_loop(v);
901             if (loop.is_leaf() || loop.is_fused_parent()) {
902                 // Treat a fused var as leaf as it can't be split into other
903                 // vars.
904                 loop_kind_t kind = loop.kind();
905                 int level;
906                 if (loop.is_fused_parent()) {
907                     auto &child_var = loop.child_vars()[0];
908                     ir_assert(find_loop(child_var).is_leaf());
909                     kind = find_loop(child_var).kind();
910                     level = loop_level(child_var);
911                 } else {
912                     level = loop_level(v);
913                 }
914                 ret.add_sub_loop(&loop, kind, level);
915             } else if (loop.is_split_parent()) {
916                 walk_down(loop.child_vars()[0]);
917                 walk_down(loop.child_vars()[1]);
918             } else {
919                 ir_error_not_expected();
920             }
921         };
922         walk_down(root_var);
923         ir_assert(ret.is_valid()) << "Invalid loop nest.";
924         return ret;
925     }
926 
compute_problem_tile(const std::vector<expr_t> & vars,const object_map_t<expr_t,split_info_t> & split_infos,tile_level_t tile_level)927     tensor_t compute_problem_tile(const std::vector<expr_t> &vars,
928             const object_map_t<expr_t, split_info_t> &split_infos,
929             tile_level_t tile_level) {
930         std::vector<dim_t> tile_dims;
931         std::vector<expr_t> tile_start;
932         for (auto &v : vars) {
933             auto &split_info = split_infos.at(v);
934             tile_dims.push_back(split_info.dim(tile_level));
935             tile_start.push_back(split_info.start(loops_, tile_level));
936         }
937         return tensor_t(tile_dims, tile_start);
938     }
939 
maybe_inject_let_for_fused_vars(const stmt_t & _body,const loop_t & loop) const940     stmt_t maybe_inject_let_for_fused_vars(
941             const stmt_t &_body, const loop_t &loop) const {
942         auto body = _body;
943         if (!loop.is_leaf() || !loop.is_fused_child()) return body;
944         auto &pvars = loop.parent_vars();
945         for (auto it = pvars.rbegin(); it != pvars.rend(); it++) {
946             auto &ploop = find_loop(*it);
947             body = let_t::make(*it, ploop.expand_var(loops_), body);
948         }
949         return body;
950     }
951 
    // NOTE(review): not referenced in this part of the file; presumably set
    // once the schedule has been finalized. Confirm against the rest of the
    // class.
    bool is_finalized_ = false;

    // Constraint set that init_constraint_set() fills with loop-variable
    // range/binding constraints. Raw pointer with no initializer here;
    // presumably non-owning and set by a constructor. TODO confirm.
    constraint_set_t *cset_;
    // Grids whose index variables loops may be bound to; searched by
    // bound_var_to_loop_kind() and bound_var_to_dim().
    grid_info_t kernel_grid_;
    grid_info_t tg_grid_;

    // Loop indices, ordered from outermost to innermost.
    std::vector<expr_t> vars_;

    // Loop descriptors keyed by loop variable (see create_loop()/find_loop()).
    object_map_t<expr_t, loop_t> loops_;

    // Maps variables to their GEMM dimension kind (B, M, N or K).
    bmnk_mapper_t bmnk_mapper_;

    // Full views for A, B, C.
    view_t a_view_;
    view_t b_view_;
    view_t c_view_;

    // Thread group tiles for A, B, C.
    tensor_t a_tg_tile_;
    tensor_t b_tg_tile_;
    tensor_t c_tg_tile_;

    // Thread tiles for A, B, C (relative to thread group tiles).
    tensor_t a_thr_tile_;
    tensor_t b_thr_tile_;
    tensor_t c_thr_tile_;
979 };
980 
981 } // namespace jit
982 } // namespace gpu
983 } // namespace impl
984 } // namespace dnnl
985 
986 #endif
987