1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_GEMM_SCHEDULE_HPP
18 #define GPU_JIT_CONV_GEMM_SCHEDULE_HPP
19
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
27
28 #include "gpu/jit/conv/ir.hpp"
29 #include "gpu/jit/conv/tensor.hpp"
30 #include "gpu/jit/conv/utils.hpp"
31
32 namespace dnnl {
33 namespace impl {
34 namespace gpu {
35 namespace jit {
36
// Used to describe semantics of a dimension in the GEMM context.
// GEMM operation is defined as C = A x B
// GEMM dimension kinds:
// - B: shared by all tensors A, B, C (batch dimension)
// - M: shared only by A and C
// - N: shared only by B and C
// - K: shared only by A and B (reduction dimension)
enum class bmnk_kind_t { undef = -1, b = 0, m = 1, n = 2, k = 3 };

// Identifies the role of a tensor in the GEMM operation C = A x B.
enum class abc_kind_t { undef, a, b, c };
47
48 class bmnk_mapper_t {
49 public:
50 bmnk_mapper_t() = default;
51
bmnk_mapper_t(const object_map_t<expr_t,bmnk_kind_t> & bmnk_kinds)52 bmnk_mapper_t(const object_map_t<expr_t, bmnk_kind_t> &bmnk_kinds)
53 : bmnk_kinds_(bmnk_kinds) {}
54
bmnk_kind(const expr_t & var) const55 bmnk_kind_t bmnk_kind(const expr_t &var) const {
56 auto it = bmnk_kinds_.find(var);
57 if (it == bmnk_kinds_.end()) return bmnk_kind_t::undef;
58 return it->second;
59 }
60
bmnk_kind(abc_kind_t abc_kind,int dim_idx) const61 bmnk_kind_t bmnk_kind(abc_kind_t abc_kind, int dim_idx) const {
62 return bmnk_kind(var(abc_kind, dim_idx));
63 }
64
ndims(abc_kind_t abc_kind) const65 int ndims(abc_kind_t abc_kind) const {
66 return int(get_vars(abc_kind).size());
67 }
68
set_a_vars(const std::vector<expr_t> & vars)69 void set_a_vars(const std::vector<expr_t> &vars) { a_vars_ = vars; }
set_b_vars(const std::vector<expr_t> & vars)70 void set_b_vars(const std::vector<expr_t> &vars) { b_vars_ = vars; }
set_c_vars(const std::vector<expr_t> & vars)71 void set_c_vars(const std::vector<expr_t> &vars) { c_vars_ = vars; }
72
set_bmnk_kind(const expr_t & var,bmnk_kind_t bmnk_kind)73 void set_bmnk_kind(const expr_t &var, bmnk_kind_t bmnk_kind) {
74 auto ret = bmnk_kinds_.insert({var, bmnk_kind});
75 ir_assert(ret.second) << "Can't set variable twice: " << var;
76 }
77
var(abc_kind_t abc_kind,int dim_idx) const78 const expr_t &var(abc_kind_t abc_kind, int dim_idx) const {
79 return get_vars(abc_kind)[dim_idx];
80 }
81
dim_idx(abc_kind_t abc_kind,const expr_t & var) const82 int dim_idx(abc_kind_t abc_kind, const expr_t &var) const {
83 auto &vars = get_vars(abc_kind);
84 for (int i = 0; i < int(vars.size()); i++) {
85 if (vars[i].is_same(var)) return i;
86 }
87 return -1;
88 }
89
90 layout_t map_to_bmnk(abc_kind_t abc_kind,
91 const std::vector<bmnk_kind_t> &bmnk_kinds,
92 const view_t &view) const;
93
94 layout_t map_to_bmnk(abc_kind_t abc_kind,
95 const std::vector<bmnk_kind_t> &bmnk_kinds,
96 const layout_t &layout) const;
97
98 private:
get_vars(abc_kind_t abc_kind) const99 const std::vector<expr_t> &get_vars(abc_kind_t abc_kind) const {
100 switch (abc_kind) {
101 case abc_kind_t::a: return a_vars_;
102 case abc_kind_t::b: return b_vars_;
103 case abc_kind_t::c: return c_vars_;
104 default: ir_error_not_expected() << "Unknown ABC kind.";
105 }
106 return a_vars_;
107 }
108
get_vars(abc_kind_t abc_kind)109 std::vector<expr_t> &get_vars(abc_kind_t abc_kind) {
110 auto &vars
111 = const_cast<const bmnk_mapper_t *>(this)->get_vars(abc_kind);
112 return const_cast<std::vector<expr_t> &>(vars);
113 }
114
115 std::vector<expr_t> a_vars_;
116 std::vector<expr_t> b_vars_;
117 std::vector<expr_t> c_vars_;
118 object_map_t<expr_t, bmnk_kind_t> bmnk_kinds_;
119 };
120
121 class bmnk_block_mapper_t {
122 public:
bmnk_block_mapper_t(const bmnk_mapper_t & bmnk_mapper)123 bmnk_block_mapper_t(const bmnk_mapper_t &bmnk_mapper)
124 : bmnk_mapper_(bmnk_mapper) {}
125
push_blocks(abc_kind_t abc_kind,const std::vector<block_t> & blocks)126 void push_blocks(abc_kind_t abc_kind, const std::vector<block_t> &blocks) {
127 for (auto &b : blocks)
128 push_block(abc_kind, b);
129 }
130
131 void push_block(abc_kind_t abc_kind, const block_t &b);
132
133 layout_t map_from_bmnk(abc_kind_t abc_kind,
134 const std::vector<bmnk_kind_t> &bmnk_kinds,
135 const layout_t &bmnk_layout) const;
136
137 private:
pop_size_1_blocks(std::vector<block_t> & blocks)138 static void pop_size_1_blocks(std::vector<block_t> &blocks) {
139 while (!blocks.empty() && blocks.front().block == 1) {
140 blocks.erase(blocks.begin());
141 }
142 }
143
create_prb_blocks(abc_kind_t abc_kind,const std::vector<std::pair<abc_kind_t,block_t>> & mn_blocks) const144 std::vector<block_t> create_prb_blocks(abc_kind_t abc_kind,
145 const std::vector<std::pair<abc_kind_t, block_t>> &mn_blocks)
146 const {
147 std::vector<block_t> ret;
148 ret.reserve(mn_blocks.size());
149 for (auto &p : mn_blocks) {
150 auto b = p.second;
151 const auto &var = bmnk_mapper_.var(p.first, b.dim_idx);
152 b.dim_idx = bmnk_mapper_.dim_idx(abc_kind, var);
153 ret.push_back(b);
154 }
155 return ret;
156 }
157
158 bool pop_block(std::vector<block_t> &bmnk_blocks,
159 std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const;
160
161 bmnk_mapper_t bmnk_mapper_;
162
163 // Ordered from innermost to outermost.
164 std::vector<std::pair<abc_kind_t, block_t>> m_blocks_;
165 std::vector<std::pair<abc_kind_t, block_t>> n_blocks_;
166 std::vector<std::pair<abc_kind_t, block_t>> k_blocks_;
167 };
168
// Describes how a loop of the schedule is executed.
enum class loop_kind_t : int {
    undef,
    kernel_grid, // Loop is bound to the kernel grid.
    serial, // Loop is inside a thread (may be unrolled or just a regular loop).
    tg_grid, // Loop is bound to the thread group grid.
    tensorized, // Such loops are fully unrolled/vectorized and converted to blocked multiplication.
};
176
to_string(loop_kind_t kind)177 static std::string to_string(loop_kind_t kind) {
178 switch (kind) {
179 case loop_kind_t::undef: return "undef";
180 case loop_kind_t::kernel_grid: return "kernel_grid";
181 case loop_kind_t::serial: return "serial";
182 case loop_kind_t::tg_grid: return "tg_grid";
183 case loop_kind_t::tensorized: return "tensorized";
184 default: ir_error_not_expected();
185 }
186 return "unknown";
187 }
188
operator <<(std::ostream & out,loop_kind_t kind)189 inline std::ostream &operator<<(std::ostream &out, loop_kind_t kind) {
190 out << to_string(kind);
191 return out;
192 }
193
// Levels at which problem tiles are computed: per thread group or per thread.
enum class tile_level_t { thread_group, thread };
195
// Represents a single loop of the schedule. A loop has an index variable, a
// bound and a kind describing how it is executed (bound to a grid, serial,
// tensorized). Loops form a forest through split/fuse relations tracked via
// parent_vars_/child_vars_.
class loop_t {
public:
    loop_t() : kind_(loop_kind_t::undef) {}

    loop_t(const expr_t &var, const expr_t &bound, bool is_root)
        : var_(var)
        , kind_(loop_kind_t::serial)
        , bound_(bound)
        , is_root_(is_root) {}

    // Loop index variable.
    const expr_t &var() const { return var_; }

    loop_kind_t kind() const { return kind_; }

    void set_kind(loop_kind_t kind) { kind_ = kind; }

    int unroll_factor() const { return unroll_factor_; }

    void set_unroll_factor(int factor) { unroll_factor_ = factor; }

    bool is_kernel_grid() const { return kind() == loop_kind_t::kernel_grid; }

    bool is_serial() const { return kind() == loop_kind_t::serial; }

    bool is_tg_grid() const { return kind() == loop_kind_t::tg_grid; }

    bool is_tensorized() const { return kind() == loop_kind_t::tensorized; }

    // Loop bound (exclusive upper limit of the index).
    const expr_t &bound() const { return bound_; }

    void set_bound(const expr_t &bound) { bound_ = bound; }

    // Returns true if the loop is bound to an external (grid) variable.
    bool is_bound() const { return !bound_var().is_empty(); }

    const expr_t &bound_var() const { return bound_var_; }

    void set_bound_var(const expr_t &v) { bound_var_ = v; }

    // Returns true for loops created directly from a tensor view, as
    // opposed to loops produced by split/fuse.
    bool is_root() const { return is_root_; }

    // Returns true for loops that were neither split, nor fused with other loops.
    bool is_leaf() const { return is_leaf_; }

    // Returns true if this loop was split into outer/inner loops.
    bool is_split_parent() const { return is_split_parent_; }

    // Returns true if this loop was the result of a split.
    bool is_split_child() const { return is_split_child_; }

    // Returns true if this loop was fused with other loops.
    bool is_fused_parent() const { return is_fused_parent_; }

    // Returns true if this loop was the result of a fusion.
    bool is_fused_child() const { return is_fused_child_; }

    const std::vector<expr_t> &parent_vars() const { return parent_vars_; }
    const std::vector<expr_t> &child_vars() const { return child_vars_; }

    // Records that this loop was split into `outer_loop` and `inner_loop`,
    // updating the parent/child links on all three loops.
    void set_split(loop_t &outer_loop, loop_t &inner_loop) {
        outer_loop.parent_vars_.push_back(var());
        child_vars_.push_back(outer_loop.var());
        outer_loop.is_split_child_ = true;

        inner_loop.parent_vars_.push_back(var());
        child_vars_.push_back(inner_loop.var());
        inner_loop.is_split_child_ = true;

        is_split_parent_ = true;
        is_leaf_ = false;
    }

    // Records that this loop is the result of fusing `loops` and updates
    // the parent/child links on all loops involved.
    void set_fuse(std::vector<std::reference_wrapper<loop_t>> &loops) {
        for (auto &l_ref : loops) {
            auto &l = l_ref.get();
            parent_vars_.push_back(l.var());
            l.child_vars_.push_back(var());
            l.is_fused_parent_ = true;
            l.is_leaf_ = false;
        }
        is_fused_child_ = true;
    }

    // Returns a loop variable expressed in the variables of the leaf loops.
    // When `skip_fused` is set, fused parent loops are not unpacked and
    // their own variables are returned as-is.
    expr_t expand_var(const object_map_t<expr_t, loop_t> &all_loops,
            bool skip_fused = false) const {
        if (is_leaf()) return var();
        if (is_split_parent()) {
            // var = outer_var * inner_bound + inner_var.
            ir_assert(child_vars_.size() == 2);
            auto &outer_loop = all_loops.at(child_vars_[0]);
            auto &inner_loop = all_loops.at(child_vars_[1]);
            auto outer_var = outer_loop.expand_var(all_loops, skip_fused);
            auto inner_var = inner_loop.expand_var(all_loops, skip_fused);
            return outer_var * inner_loop.bound() + inner_var;
        }
        if (is_fused_parent()) {
            if (skip_fused) return var();
            // Example of "unpacking":
            // fused_var = (a * b * c * d)
            // b = (fused_var / (D * C)) % B
            ir_assert(child_vars_.size() == 1);
            auto &fused_loop = all_loops.at(child_vars_[0]);
            int nvars = int(fused_loop.parent_vars_.size());
            expr_t denom = 1;
            // Walk the fused parents from innermost to outermost,
            // accumulating the product of the inner bounds until this
            // loop's variable is found.
            for (int i = nvars - 1; i >= 0; i--) {
                auto &v = fused_loop.parent_vars_[i];
                auto &child_loop = all_loops.at(v);
                auto &bound = child_loop.bound();
                if (v.is_same(var())) {
                    auto e = fused_loop.expand_var(all_loops, skip_fused)
                            / denom;
                    // The outermost variable needs no modulus.
                    return (i == 0 ? e : e % bound);
                }
                denom *= bound;
            }
        }

        ir_error_not_expected();
        return expr_t();
    }

    std::string str() const {
        using namespace ir_utils;

        std::ostringstream oss;
        oss << "var: " << var_;
        oss << " bound: " << bound_;
        oss << " kind: " << kind_;
        if (unroll_factor_ != 1) oss << " unroll: " << unroll_factor_;
        std::vector<std::string> props;
        if (is_root()) props.push_back("root");
        if (is_fused_child()) props.push_back("fused");
        if (is_split_parent()) props.push_back("split");
        oss << "(" << make_seq_print_helper(props, ", ") << ")";
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    expr_t var_; // Loop index variable.
    loop_kind_t kind_; // Loop kind.
    expr_t bound_; // Loop bound (exclusive).

    expr_t bound_var_; // External variable this loop is bound to.

    int unroll_factor_ = 1;

    bool is_root_ = false;
    bool is_leaf_ = true;

    bool is_split_parent_ = false;
    bool is_split_child_ = false;

    bool is_fused_parent_ = false;
    bool is_fused_child_ = false;

    // For variables that were split or fused.
    // Fusion: i x j -> k
    //     i.child_vars_ = [k]
    //     j.child_vars_ = [k]
    //     k.parent_vars_ = [i, j]
    // Split: i -> j x k
    //     i.child_vars_ = [j, k]
    //     j.parent_vars_ = [i]
    //     k.parent_vars_ = [i]
    std::vector<expr_t> parent_vars_;
    std::vector<expr_t> child_vars_;
};
364
365 // Defines GEMM computation including:
366 // - Blocking scheme (order of loops, tiles per thread group/thread)
367 // - Mapping of problem dimensions to GEMM dimensions (BMNK)
368 class gemm_schedule_t {
369 public:
370 gemm_schedule_t() = default;
371
gemm_schedule_t(constraint_set_t & cset,const grid_info_t & kernel_grid,const grid_info_t & tg_grid)372 gemm_schedule_t(constraint_set_t &cset, const grid_info_t &kernel_grid,
373 const grid_info_t &tg_grid)
374 : cset_(&cset), kernel_grid_(kernel_grid), tg_grid_(tg_grid) {}
375
kernel_grid() const376 const grid_info_t &kernel_grid() const { return kernel_grid_; }
tg_grid() const377 const grid_info_t &tg_grid() const { return tg_grid_; }
378
bmnk_kind(const expr_t & var) const379 bmnk_kind_t bmnk_kind(const expr_t &var) const {
380 return bmnk_kind(std::vector<expr_t>({var}));
381 }
382
bmnk_mapper() const383 const bmnk_mapper_t &bmnk_mapper() const { return bmnk_mapper_; }
384
set_b_vars(const std::vector<expr_t> & vars)385 void set_b_vars(const std::vector<expr_t> &vars) {
386 for (auto &v : vars)
387 set_bmnk_kind(v, bmnk_kind_t::b);
388 }
389
set_m_vars(const std::vector<expr_t> & vars)390 void set_m_vars(const std::vector<expr_t> &vars) {
391 for (auto &v : vars)
392 set_bmnk_kind(v, bmnk_kind_t::m);
393 }
394
set_n_vars(const std::vector<expr_t> & vars)395 void set_n_vars(const std::vector<expr_t> &vars) {
396 for (auto &v : vars)
397 set_bmnk_kind(v, bmnk_kind_t::n);
398 }
399
set_k_vars(const std::vector<expr_t> & vars)400 void set_k_vars(const std::vector<expr_t> &vars) {
401 for (auto &v : vars)
402 set_bmnk_kind(v, bmnk_kind_t::k);
403 }
404
405 // A/B/C views in the problem notation.
a_view() const406 const view_t &a_view() const { return a_view_; }
b_view() const407 const view_t &b_view() const { return b_view_; }
c_view() const408 const view_t &c_view() const { return c_view_; }
409
set_a_view(const view_t & v)410 void set_a_view(const view_t &v) {
411 set_view(v, a_view_);
412 bmnk_mapper_.set_a_vars(a_view_.vvars());
413 }
414
set_b_view(const view_t & v)415 void set_b_view(const view_t &v) {
416 set_view(v, b_view_);
417 bmnk_mapper_.set_b_vars(b_view_.vvars());
418 }
419
set_c_view(const view_t & v)420 void set_c_view(const view_t &v) {
421 set_view(v, c_view_);
422 bmnk_mapper_.set_c_vars(c_view_.vvars());
423 }
424
a_tg_view() const425 view_t a_tg_view() const {
426 ir_assert(is_finalized_);
427 return a_view_.create_sub_view(a_tg_tile_);
428 }
429
b_tg_view() const430 view_t b_tg_view() const {
431 ir_assert(is_finalized_);
432 return b_view_.create_sub_view(b_tg_tile_);
433 }
434
c_tg_view() const435 view_t c_tg_view() const {
436 ir_assert(is_finalized_);
437 return c_view_.create_sub_view(c_tg_tile_);
438 }
439
440 // Thread group tiles for A, B, C.
a_tg_tile() const441 const tensor_t &a_tg_tile() const { return a_tg_tile_; }
b_tg_tile() const442 const tensor_t &b_tg_tile() const { return b_tg_tile_; }
c_tg_tile() const443 const tensor_t &c_tg_tile() const { return c_tg_tile_; }
444
445 // Thread tiles for A, B, C.
a_thr_tile(bool is_relative=true) const446 tensor_t a_thr_tile(bool is_relative = true) const {
447 if (is_relative) return a_thr_tile_;
448 return a_tg_tile_.create_sub_tensor(a_thr_tile_);
449 }
450
b_thr_tile(bool is_relative=true) const451 tensor_t b_thr_tile(bool is_relative = true) const {
452 if (is_relative) return b_thr_tile_;
453 return b_tg_tile_.create_sub_tensor(b_thr_tile_);
454 }
455
c_thr_tile(bool is_relative=true) const456 tensor_t c_thr_tile(bool is_relative = true) const {
457 if (is_relative) return c_thr_tile_;
458 return c_tg_tile_.create_sub_tensor(c_thr_tile_);
459 }
460
461 // Splits loop defined by `var` into two new loops based on `factor`.
462 // Before:
463 // for (int var = 0; var < I; var++) { ... }
464 // After:
465 // for (int outer_var = 0; outer_var < I / factor; outer_var++) {
466 // for (int inner_var = 0; inner_var < factor; inner_var++) {
467 // ...
468 // }
469 // }
split(const expr_t & var,int factor,expr_t & outer_var,expr_t & inner_var)470 void split(const expr_t &var, int factor, expr_t &outer_var,
471 expr_t &inner_var) {
472 auto &loop = find_loop(var);
473 ir_assert(loop.is_leaf()) << "Can't split, non-leaf loop.";
474
475 int bound = to_cpp<int>(loop.bound());
476 if (loop.is_root() && (bound % factor != 0)) {
477 // Auto round-up bounds for the root loops.
478 bound = utils::rnd_up(bound, factor);
479 loop.set_bound(bound);
480 }
481
482 ir_assert(bound % factor == 0) << "Can't split.";
483
484 outer_var = create_var({var}, "outer");
485 inner_var = create_var({var}, "inner");
486 auto &outer_loop = create_loop(outer_var, bound / factor);
487 auto &inner_loop = create_loop(inner_var, factor);
488 loop.set_split(outer_loop, inner_loop);
489 set_bmnk_kind(outer_var, bmnk_kind(var));
490 set_bmnk_kind(inner_var, bmnk_kind(var));
491 }
492
493 // Double split.
split(const expr_t & var,int factor0,int factor1,expr_t & outer_var0,expr_t & outer_var1,expr_t & inner_var)494 void split(const expr_t &var, int factor0, int factor1, expr_t &outer_var0,
495 expr_t &outer_var1, expr_t &inner_var) {
496 expr_t dummy_inner_var;
497 split(var, factor0, outer_var0, dummy_inner_var);
498 split(dummy_inner_var, factor1, outer_var1, inner_var);
499 }
500
501 // Fuses loops defined by `v0` and `v1` variables, v0 - outer variable, v1
502 // - inner variable.
503 // Before:
504 // for (int v0 = 0; v0 < V0; v0++) {
505 // for (int v1 = 0; v1 < V1; v1++) { ... }
506 // }
507 // After:
508 // for (int v = 0; v < V0 * V1; v++) {
509 // int v0 = v / V1;
510 // int v1 = v % V1;
511 // ...
512 // }
fuse(const expr_t & v0,const expr_t & v1)513 expr_t fuse(const expr_t &v0, const expr_t &v1) { return fuse({v0, v1}); }
514
515 // Double fuse, v0 - outermost variable, v2 - innermost variable.
fuse(const expr_t & v0,const expr_t & v1,const expr_t & v2)516 expr_t fuse(const expr_t &v0, const expr_t &v1, const expr_t &v2) {
517 return fuse({v0, v1, v2});
518 }
519
520 // Fusion of multiple loops.
fuse(const std::vector<expr_t> & vars)521 expr_t fuse(const std::vector<expr_t> &vars) {
522 auto fused_var = create_var(vars, "fused");
523 expr_t fused_bound = find_loop(vars[0]).bound();
524 for (int i = 1; i < int(vars.size()); i++) {
525 auto &loop = find_loop(vars[i]);
526 fused_bound *= loop.bound();
527 }
528 auto &fused_loop = create_loop(fused_var, fused_bound);
529 std::vector<std::reference_wrapper<loop_t>> loop_refs;
530 for (auto &v : vars) {
531 loop_refs.push_back(find_loop(v));
532 }
533 fused_loop.set_fuse(loop_refs);
534 set_bmnk_kind(fused_var, bmnk_kind(vars));
535 return fused_var;
536 }
537
538 // Sets unrolling factor for the given loop.
unroll(const expr_t & v,int factor)539 void unroll(const expr_t &v, int factor) {
540 auto &loop = find_loop(v);
541 loop.set_unroll_factor(factor);
542 }
543
544 // Marks the loop defined by `v` as tensorized.
tensorize(const expr_t & v)545 void tensorize(const expr_t &v) {
546 auto &loop = find_loop(v);
547 loop.set_kind(loop_kind_t::tensorized);
548 }
549
550 // Binds the loop defined by `v` to an external variable.
bind(const expr_t & v,const expr_t & bound_var)551 void bind(const expr_t &v, const expr_t &bound_var) {
552 auto &loop = find_loop(v);
553 ir_assert(loop.is_leaf()) << "Can't bind non-leaf loop: " << v;
554 loop.set_bound_var(bound_var);
555 loop.set_kind(bound_var_to_loop_kind(bound_var));
556
557 int var_dim = bound_var_to_dim(bound_var);
558 ir_assert(to_cpp<int>(loop.bound()) == var_dim)
559 << "Dimension size doesn't match.";
560 }
561
562 // Reorders loops defiend by given variables.
reorder(const std::vector<expr_t> & ordered_vars)563 void reorder(const std::vector<expr_t> &ordered_vars) {
564 for (auto &v : ordered_vars) {
565 auto &loop = find_loop(v);
566 ir_assert(loop.is_leaf()) << "Can't reorder non-leaf loop: " << v;
567 }
568 std::vector<bool> found(vars_.size());
569 for (size_t i = 0; i < vars_.size(); i++) {
570 for (size_t j = 0; j < ordered_vars.size(); j++) {
571 if (ordered_vars[j].is_same(vars_[i])) {
572 found[i] = true;
573 break;
574 }
575 }
576 }
577
578 for (size_t i = 0, j = 0; i < vars_.size(); i++) {
579 if (!found[i]) continue;
580 vars_[i] = ordered_vars[j++];
581 }
582 }
583
with_thread_group_k_slicing() const584 bool with_thread_group_k_slicing() const {
585 ir_assert(is_finalized_);
586 dim_t k_thr = 1;
587 dim_t k_tg = 1;
588 for (int i = 0; i < bmnk_mapper_.ndims(abc_kind_t::a); i++) {
589 if (bmnk_mapper_.bmnk_kind(abc_kind_t::a, i) != bmnk_kind_t::k)
590 continue;
591 k_thr *= a_thr_tile_(i);
592 k_tg *= a_tg_tile_(i);
593 }
594 ir_assert(k_tg % k_thr == 0);
595 return k_thr < k_tg;
596 }
597
finalize()598 void finalize() {
599 sort_vars();
600 init_problem_tiles();
601 init_constraint_set();
602 is_finalized_ = true;
603 }
604
605 // Returns a statement describing the loop nest of the schedule.
create_loop_nest(const stmt_t & _body=stmt_t ()) const606 stmt_t create_loop_nest(const stmt_t &_body = stmt_t()) const {
607 stmt_t body = _body;
608 for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
609 auto &var = *it;
610 auto &loop = find_loop(var);
611 if (!loop.is_leaf() || loop.is_tensorized() || loop.is_bound())
612 continue;
613 body = maybe_inject_let_for_fused_vars(body, loop);
614 body = for_t::make(
615 var, 0, loop.bound(), body, loop.unroll_factor());
616 }
617 return body;
618 }
619
create_bind_stmt(const stmt_t & _body=stmt_t ()) const620 stmt_t create_bind_stmt(const stmt_t &_body = stmt_t()) const {
621 stmt_t body = _body;
622 for (auto it = vars_.rbegin(); it != vars_.rend(); it++) {
623 auto &var = *it;
624 auto &loop = find_loop(var);
625 if (!loop.is_leaf() || !loop.is_bound()) continue;
626 body = maybe_inject_let_for_fused_vars(body, loop);
627 body = let_t::make(var, loop.bound_var(), body);
628 }
629 return body;
630 }
631
632 private:
633 // Describes split of a root loop into sub-loops.
634 class split_info_t {
635 public:
split_info_t(const loop_t * root_loop)636 split_info_t(const loop_t *root_loop) : root_loop_(root_loop) {}
637
nloops() const638 int nloops() const { return int(loops_.size()); }
639
add_sub_loop(const loop_t * loop,loop_kind_t loop_kind,int loop_level)640 void add_sub_loop(
641 const loop_t *loop, loop_kind_t loop_kind, int loop_level) {
642 loops_.push_back(loop);
643 loop_kinds_.push_back(loop_kind);
644 loop_levels_.push_back(loop_level);
645 }
646
647 // Verifies that sub-loops are ordered from outermost to innermost
648 // according to the schedule conventions. There are three set of loops:
649 // 1) Loops bound to kernel grid
650 // 2) Loops bound to thread group grid and serial loops
651 // 3) Tensorized loops
652 // Sets of loops must be ordered from outermost to innermost going from
653 // 1 to 3. Inside a set loops can be ordered arbitrarily.
is_valid() const654 bool is_valid() const {
655 auto get_loop_key = [&](int loop_idx) {
656 switch (loop_kinds_[loop_idx]) {
657 case loop_kind_t::kernel_grid: return -1;
658 case loop_kind_t::tg_grid:
659 case loop_kind_t::serial: return loop_levels_[loop_idx];
660 case loop_kind_t::tensorized:
661 return std::numeric_limits<int>::max();
662 default: ir_error_not_expected();
663 }
664 return -1;
665 };
666 int prev_key = -1;
667 for (int i = 0; i < nloops(); i++) {
668 int key = get_loop_key(i);
669 if (key < prev_key) return false;
670 prev_key = key;
671 }
672 return true;
673 }
674
675 // Returns total extent of all loops at a given tile level.
dim(tile_level_t tile_level) const676 dim_t dim(tile_level_t tile_level) const {
677 dim_t ret = 1;
678 for (int i = 0; i < nloops(); i++) {
679 switch (loop_kinds_[i]) {
680 case loop_kind_t::kernel_grid:
681 case loop_kind_t::serial: continue;
682 case loop_kind_t::tg_grid:
683 if (tile_level == tile_level_t::thread) continue;
684 break;
685 case loop_kind_t::tensorized: break;
686 default: ir_error_not_expected();
687 }
688 ret *= to_cpp<dim_t>(loops_[i]->bound());
689 }
690 return ret;
691 }
692
693 // Returns initial offset expressed in the outer variables at a given
694 // tile level.
start(const object_map_t<expr_t,loop_t> & all_loops,tile_level_t tile_level) const695 expr_t start(const object_map_t<expr_t, loop_t> &all_loops,
696 tile_level_t tile_level) const {
697 auto ret = root_loop_->expand_var(all_loops, /*skip_fused=*/true);
698 for (int i = 0; i < nloops(); i++) {
699 switch (loop_kinds_[i]) {
700 case loop_kind_t::kernel_grid:
701 case loop_kind_t::serial:
702 if (tile_level == tile_level_t::thread) break;
703 continue;
704 case loop_kind_t::tg_grid:
705 if (tile_level == tile_level_t::thread) continue;
706 break;
707 case loop_kind_t::tensorized: break;
708 default: ir_error_not_expected();
709 }
710 ret = substitute(ret, loops_[i]->var(), expr_t(0));
711 }
712 return simplify(ret);
713 }
714
715 private:
716 const loop_t *root_loop_;
717 std::vector<const loop_t *> loops_;
718 std::vector<loop_kind_t> loop_kinds_;
719 std::vector<int> loop_levels_;
720 };
721
bmnk_kind(const std::vector<expr_t> & vars) const722 bmnk_kind_t bmnk_kind(const std::vector<expr_t> &vars) const {
723 if (vars.empty()) return bmnk_kind_t::undef;
724 if (vars.size() == 1) return bmnk_mapper_.bmnk_kind(vars[0]);
725 bmnk_kind_t ret = bmnk_kind(vars[0]);
726 for (size_t i = 1; i < vars.size(); i++) {
727 if (bmnk_kind(vars[i]) != ret) return bmnk_kind_t::undef;
728 }
729 return ret;
730 }
731
set_bmnk_kind(const expr_t & var,bmnk_kind_t kind)732 void set_bmnk_kind(const expr_t &var, bmnk_kind_t kind) {
733 bmnk_mapper_.set_bmnk_kind(var, kind);
734 }
735
set_view(const view_t & view,view_t & this_view)736 void set_view(const view_t &view, view_t &this_view) {
737 this_view = view;
738 // Create missing loops.
739 for (int i = 0; i < view.nvdims(); i++) {
740 auto &v = view.vvars()[i];
741 dim_t bound = view.vdims()[i];
742 if (has_loop(v)) {
743 auto &loop = find_loop(v);
744 ir_assert(bound == to_cpp<dim_t>(loop.bound()))
745 << "Inconsistent sizes.";
746 continue;
747 }
748 create_loop(v, bound, /*is_root=*/true);
749 }
750 }
751
bound_var_to_loop_kind(const expr_t & v) const752 loop_kind_t bound_var_to_loop_kind(const expr_t &v) const {
753 for (int i = 0; i < kernel_grid_.ndims(); i++) {
754 if (kernel_grid_.idx(i).is_same(v)) return loop_kind_t::kernel_grid;
755 }
756 for (int i = 0; i < tg_grid_.ndims(); i++) {
757 if (tg_grid_.idx(i).is_same(v)) return loop_kind_t::tg_grid;
758 }
759 ir_error_not_expected() << "Unknown external variable: " << v;
760 return loop_kind_t::undef;
761 }
762
bound_var_to_dim(const expr_t & v) const763 int bound_var_to_dim(const expr_t &v) const {
764 for (int i = 0; i < kernel_grid_.ndims(); i++) {
765 if (kernel_grid_.idx(i).is_same(v)) return kernel_grid_.dim(i);
766 }
767 for (int i = 0; i < tg_grid_.ndims(); i++) {
768 if (tg_grid_.idx(i).is_same(v)) return tg_grid_.dim(i);
769 }
770 ir_error_not_expected() << "Unknown external variable: " << v;
771 return -1;
772 }
773
has_loop(const expr_t & var) const774 bool has_loop(const expr_t &var) const {
775 auto it = loops_.find(var);
776 return it != loops_.end();
777 }
778
find_loop(const expr_t & var) const779 const loop_t &find_loop(const expr_t &var) const {
780 ir_assert(has_loop(var)) << "Var not found: " << var;
781 return loops_.at(var);
782 }
783
find_loop(const expr_t & var)784 loop_t &find_loop(const expr_t &var) {
785 ir_assert(has_loop(var)) << "Var not found: " << var;
786 return loops_[var];
787 }
788
loop_level(const expr_t & var) const789 int loop_level(const expr_t &var) const {
790 for (int i = 0; i < int(vars_.size()); i++) {
791 if (vars_[i].is_same(var)) return i;
792 }
793 return -1;
794 }
795
create_loop(const expr_t & var,const expr_t & bound,bool is_root=false)796 loop_t &create_loop(
797 const expr_t &var, const expr_t &bound, bool is_root = false) {
798 loop_t loop(var, bound, is_root);
799 auto ret = loops_.insert({var, loop});
800 ir_assert(ret.second) << "Variable already exists: " << var;
801 vars_.push_back(var);
802 return ret.first->second;
803 }
804
strip_suffix(const std::string & s,const std::string & suffix)805 static std::string strip_suffix(
806 const std::string &s, const std::string &suffix) {
807 auto pos = s.find(suffix);
808 if (pos == std::string::npos) return s;
809 if (pos + suffix.length() != s.length()) return s;
810 return s.substr(0, pos);
811 }
812
create_var(const std::vector<expr_t> & vars,const std::string & suffix)813 static expr_t create_var(
814 const std::vector<expr_t> &vars, const std::string &suffix) {
815 std::string var_name;
816 for (auto &v : vars) {
817 auto name = strip_suffix(v.as<var_t>().name, "_idx");
818 var_name += name + "_";
819 }
820 var_name += suffix;
821 return var_t::make(type_t::s32(), var_name);
822 }
823
get_var_key(const expr_t & v) const824 int get_var_key(const expr_t &v) const {
825 int key_max = std::numeric_limits<int>::max();
826 auto &loop = find_loop(v);
827 if (!loop.is_leaf()) return key_max;
828 // Loops bound to the kernel grid.
829 if (loop.is_kernel_grid()) {
830 return kernel_grid_.ndims()
831 - kernel_grid_.dim_idx(loop.bound_var());
832 }
833 // Loops bound to the thread group grid or serial loop.
834 if (loop.is_tg_grid() || loop.is_serial()) return 10;
835
836 // Tensorized loops are the innermost.
837 if (loop.is_tensorized()) return key_max - 1;
838 ir_error_not_expected() << "Unknown loop";
839 return -1;
840 }
841
sort_vars()842 void sort_vars() {
843 std::stable_sort(vars_.end(), vars_.end(),
844 [&](const expr_t &a_var, const expr_t &b_var) {
845 int a_key = get_var_key(a_var);
846 int b_key = get_var_key(b_var);
847 return a_key < b_key;
848 });
849 }
850
init_problem_tiles()851 void init_problem_tiles() {
852 object_map_t<expr_t, split_info_t> split_infos;
853 for (auto *view : {&a_view_, &b_view_, &c_view_}) {
854 for (auto &v : view->vvars()) {
855 if (split_infos.count(v) > 0) continue;
856 split_infos.insert({v, get_split_info(v)});
857 }
858 }
859 a_tg_tile_ = compute_problem_tile(
860 a_view_.vvars(), split_infos, tile_level_t::thread_group);
861 b_tg_tile_ = compute_problem_tile(
862 b_view_.vvars(), split_infos, tile_level_t::thread_group);
863 c_tg_tile_ = compute_problem_tile(
864 c_view_.vvars(), split_infos, tile_level_t::thread_group);
865 a_thr_tile_ = compute_problem_tile(
866 a_view_.vvars(), split_infos, tile_level_t::thread);
867 b_thr_tile_ = compute_problem_tile(
868 b_view_.vvars(), split_infos, tile_level_t::thread);
869 c_thr_tile_ = compute_problem_tile(
870 c_view_.vvars(), split_infos, tile_level_t::thread);
871 }
872
init_constraint_set()873 void init_constraint_set() {
874 for (auto &v : vars_) {
875 auto &loop = find_loop(v);
876 if (loop.is_fused_parent()) {
877 cset_->add_constraint(v >= 0);
878 cset_->add_constraint(v < loop.bound());
879 continue;
880 }
881 if (!loop.is_leaf()) continue;
882
883 // Fused variables are used only to initialize fused parents.
884 if (loop.is_fused_child()) continue;
885
886 if (loop.is_bound()) {
887 cset_->add_constraint(v == loop.bound_var());
888 continue;
889 }
890
891 cset_->add_constraint(v >= 0);
892 cset_->add_constraint(v < loop.bound());
893 }
894 }
895
get_split_info(const expr_t & root_var) const896 split_info_t get_split_info(const expr_t &root_var) const {
897 split_info_t ret(&find_loop(root_var));
898 std::function<void(const expr_t &)> walk_down;
899 walk_down = [&](const expr_t &v) {
900 auto &loop = find_loop(v);
901 if (loop.is_leaf() || loop.is_fused_parent()) {
902 // Treat a fused var as leaf as it can't be split into other
903 // vars.
904 loop_kind_t kind = loop.kind();
905 int level;
906 if (loop.is_fused_parent()) {
907 auto &child_var = loop.child_vars()[0];
908 ir_assert(find_loop(child_var).is_leaf());
909 kind = find_loop(child_var).kind();
910 level = loop_level(child_var);
911 } else {
912 level = loop_level(v);
913 }
914 ret.add_sub_loop(&loop, kind, level);
915 } else if (loop.is_split_parent()) {
916 walk_down(loop.child_vars()[0]);
917 walk_down(loop.child_vars()[1]);
918 } else {
919 ir_error_not_expected();
920 }
921 };
922 walk_down(root_var);
923 ir_assert(ret.is_valid()) << "Invalid loop nest.";
924 return ret;
925 }
926
compute_problem_tile(const std::vector<expr_t> & vars,const object_map_t<expr_t,split_info_t> & split_infos,tile_level_t tile_level)927 tensor_t compute_problem_tile(const std::vector<expr_t> &vars,
928 const object_map_t<expr_t, split_info_t> &split_infos,
929 tile_level_t tile_level) {
930 std::vector<dim_t> tile_dims;
931 std::vector<expr_t> tile_start;
932 for (auto &v : vars) {
933 auto &split_info = split_infos.at(v);
934 tile_dims.push_back(split_info.dim(tile_level));
935 tile_start.push_back(split_info.start(loops_, tile_level));
936 }
937 return tensor_t(tile_dims, tile_start);
938 }
939
maybe_inject_let_for_fused_vars(const stmt_t & _body,const loop_t & loop) const940 stmt_t maybe_inject_let_for_fused_vars(
941 const stmt_t &_body, const loop_t &loop) const {
942 auto body = _body;
943 if (!loop.is_leaf() || !loop.is_fused_child()) return body;
944 auto &pvars = loop.parent_vars();
945 for (auto it = pvars.rbegin(); it != pvars.rend(); it++) {
946 auto &ploop = find_loop(*it);
947 body = let_t::make(*it, ploop.expand_var(loops_), body);
948 }
949 return body;
950 }
951
952 bool is_finalized_ = false;
953
954 constraint_set_t *cset_;
955 grid_info_t kernel_grid_;
956 grid_info_t tg_grid_;
957
958 // Loop indices, ordered from outermost to innermost.
959 std::vector<expr_t> vars_;
960
961 object_map_t<expr_t, loop_t> loops_;
962
963 bmnk_mapper_t bmnk_mapper_;
964
965 // Full views for A, B, C.
966 view_t a_view_;
967 view_t b_view_;
968 view_t c_view_;
969
970 // Thread group tiles for A, B, C.
971 tensor_t a_tg_tile_;
972 tensor_t b_tg_tile_;
973 tensor_t c_tg_tile_;
974
975 // Thread tiles for A, B, C (relative to thread group tiles).
976 tensor_t a_thr_tile_;
977 tensor_t b_thr_tile_;
978 tensor_t c_thr_tile_;
979 };
980
981 } // namespace jit
982 } // namespace gpu
983 } // namespace impl
984 } // namespace dnnl
985
986 #endif
987