1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_IR_HPP
18 #define GPU_JIT_CONV_IR_HPP
19
#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
24
25 #include "gpu/jit/conv/ir_core.hpp"
26
27 namespace dnnl {
28 namespace impl {
29 namespace gpu {
30 namespace jit {
31
32 // Helper class to walk through IR tree.
33 class ir_visitor_t {
34 public:
35 using dispatch_func_type = void (*)(ir_visitor_t *, const object_impl_t &);
36
37 virtual ~ir_visitor_t() = default;
38
visit(const object_t & obj)39 virtual void visit(const object_t &obj) { dispatch(obj.impl()); }
40
41 template <typename T>
visit(const std::vector<T> & v)42 void visit(const std::vector<T> &v) {
43 for (auto &e : v)
44 visit(e);
45 }
46
pre_visit(const object_impl_t & obj)47 virtual void pre_visit(const object_impl_t &obj) {}
post_visit(const object_impl_t & obj)48 virtual void post_visit(const object_impl_t &obj) {}
49
50 // To catch missing _visit() handlers in ir_visitor_t.
_visit(const object_impl_t & obj)51 virtual void _visit(const object_impl_t &obj) {
52 ir_error_not_expected() << "Can't handle type: " << object_t(obj);
53 }
54
55 #define DECL_VISIT_LEAF(name) \
56 virtual void _visit(const name &obj) {}
57
58 DECL_VISIT_LEAF(bool_imm_t)
DECL_VISIT_LEAF(float_imm_t)59 DECL_VISIT_LEAF(float_imm_t)
60 DECL_VISIT_LEAF(func_impl_t)
61 DECL_VISIT_LEAF(int_imm_t)
62 DECL_VISIT_LEAF(var_t)
63
64 #undef DECL_VISIT_LEAF
65
66 virtual void _visit(const alloc_t &obj) {
67 visit(obj.buf);
68 visit(obj.body);
69 }
70
_visit(const binary_op_t & obj)71 virtual void _visit(const binary_op_t &obj) {
72 visit(obj.a);
73 visit(obj.b);
74 }
75
_visit(const cast_t & obj)76 virtual void _visit(const cast_t &obj) { visit(obj.expr); }
77
_visit(const for_t & obj)78 virtual void _visit(const for_t &obj) {
79 visit(obj.var);
80 visit(obj.init);
81 visit(obj.bound);
82 visit(obj.body);
83 }
84
_visit(const func_call_t & obj)85 virtual void _visit(const func_call_t &obj) {
86 visit(obj.func);
87 visit(obj.args);
88 }
89
_visit(const if_t & obj)90 virtual void _visit(const if_t &obj) {
91 visit(obj.cond);
92 visit(obj.body);
93 visit(obj.else_body);
94 }
95
_visit(const iif_t & obj)96 virtual void _visit(const iif_t &obj) {
97 visit(obj.cond);
98 visit(obj.true_expr);
99 visit(obj.false_expr);
100 }
101
_visit(const let_t & obj)102 virtual void _visit(const let_t &obj) {
103 visit(obj.var);
104 visit(obj.value);
105 visit(obj.body);
106 }
107
_visit(const load_t & obj)108 virtual void _visit(const load_t &obj) {
109 visit(obj.buf);
110 visit(obj.off);
111 }
112
_visit(const ptr_t & obj)113 virtual void _visit(const ptr_t &obj) {
114 visit(obj.base);
115 visit(obj.off);
116 }
117
_visit(const shuffle_t & obj)118 virtual void _visit(const shuffle_t &obj) { visit(obj.vec); }
119
_visit(const stmt_group_t & obj)120 virtual void _visit(const stmt_group_t &obj) { visit(obj.body); }
121
_visit(const stmt_seq_t & obj)122 virtual void _visit(const stmt_seq_t &obj) {
123 visit(obj.head);
124 visit(obj.tail);
125 }
126
_visit(const store_t & obj)127 virtual void _visit(const store_t &obj) {
128 visit(obj.buf);
129 visit(obj.off);
130 visit(obj.value);
131 visit(obj.mask);
132 }
133
_visit(const ternary_op_t & obj)134 virtual void _visit(const ternary_op_t &obj) {
135 visit(obj.a);
136 visit(obj.b);
137 visit(obj.c);
138 }
139
_visit(const unary_op_t & obj)140 virtual void _visit(const unary_op_t &obj) { visit(obj.a); }
141
is_supported(const object_t & obj) const142 bool is_supported(const object_t &obj) const {
143 if (obj.is_empty()) return true;
144
145 auto *impl = obj.impl();
146 auto ti = impl->dispatch_type_id();
147 return ti < num_dispatch_funcs;
148 }
149
150 protected:
find_dispatch_func(int64_t ti) const151 virtual dispatch_func_type find_dispatch_func(int64_t ti) const {
152 return ti < num_dispatch_funcs ? dispatch_funcs()[ti] : nullptr;
153 }
154
155 private:
156 static const int64_t num_dispatch_funcs
157 = ir_type_id_t::end_visitable_ir_objects;
158 static std::array<dispatch_func_type, num_dispatch_funcs> &
dispatch_funcs()159 dispatch_funcs() {
160 static std::array<dispatch_func_type, num_dispatch_funcs>
161 _dispatch_funcs;
162 static std::once_flag initialized;
163 std::call_once(initialized, [&]() {
164 #define HANDLE_IR_OBJECT(type) \
165 _dispatch_funcs[type::_dispatch_type_id()] = &call<type>;
166 HANDLE_ALL_IR_OBJECTS()
167
168 #undef HANDLE_IR_OBJECT
169 });
170 return _dispatch_funcs;
171 }
172
173 template <typename T>
call(ir_visitor_t * visitor,const object_impl_t & obj)174 static void call(ir_visitor_t *visitor, const object_impl_t &obj) {
175 visitor->pre_visit(obj);
176 visitor->_visit((const T &)obj);
177 visitor->post_visit(obj);
178 }
179
dispatch(const object_impl_t * obj)180 void dispatch(const object_impl_t *obj) {
181 if (!obj) return;
182
183 auto ti = obj->dispatch_type_id();
184 auto f = find_dispatch_func(ti);
185 if (!f) {
186 ir_error_not_expected() << "Can't handle type: " << object_t(obj);
187 }
188 f(this, *obj);
189 }
190 };
191
192 class ir_context_t {
193 public:
create_tmp_var(const type_t & type,const std::string & prefix="tmp")194 expr_t create_tmp_var(
195 const type_t &type, const std::string &prefix = "tmp") {
196 int &id = prefix_ids_[prefix];
197 auto name = prefix + "_" + std::to_string(id);
198 id++;
199 return var_t::make(type, name);
200 }
201
202 private:
203 std::unordered_map<std::string, int> prefix_ids_;
204 };
205
206 class alloc_updater_t : public ir_mutator_t {
207 public:
resize(const expr_t & buf,int new_size)208 void resize(const expr_t &buf, int new_size) {
209 auto ret = resizes_.insert({buf, new_size});
210 ir_assert(ret.second) << buf;
211 MAYBE_UNUSED(ret);
212 }
213
remove(const expr_t & buf)214 void remove(const expr_t &buf) {
215 auto ret = removes_.insert(buf);
216 ir_assert(ret.second) << buf;
217 MAYBE_UNUSED(ret);
218 }
219
update(const stmt_t & stmt)220 stmt_t update(const stmt_t &stmt) { return mutate(stmt); }
221
_mutate(const alloc_t & obj)222 object_t _mutate(const alloc_t &obj) override {
223 auto new_obj = ir_mutator_t::_mutate(obj);
224
225 if (try_remove(new_obj)) return new_obj;
226 if (try_resize(new_obj)) return new_obj;
227
228 return new_obj;
229 }
230
231 private:
try_remove(object_t & obj)232 bool try_remove(object_t &obj) {
233 auto &alloc = obj.as<alloc_t>();
234 auto it = removes_.find(alloc.buf);
235 if (it == removes_.end()) return false;
236
237 obj = alloc.body;
238 removes_.erase(it);
239 return true;
240 }
241
try_resize(object_t & obj)242 bool try_resize(object_t &obj) {
243 auto &alloc = obj.as<alloc_t>();
244 auto it = resizes_.find(alloc.buf);
245 if (it == resizes_.end()) return false;
246
247 obj = alloc_t::make(
248 alloc.buf, it->second, alloc.kind, alloc.attr, alloc.body);
249 resizes_.erase(it);
250 return true;
251 }
252
253 object_set_t<expr_t> removes_;
254 object_map_t<expr_t, int> resizes_;
255 };
256
257 // Returns a new statement with injected buffer allocations from `allocs`.
258 // - If put_innermost is false, then `stmt` is nested to all allocations
259 // - If put_innermost is true, then every allocation is injected as innermost
260 // as possible
261 stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
262 bool put_innermost = false);
263
264 // Returns a new statement with injected let statements, `stmt` is nested to
265 // all let statements.
266 stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets);
267
268 template <typename T>
269 struct expr_cast_helper_t {
calldnnl::impl::gpu::jit::expr_cast_helper_t270 static T call(const expr_t &e) { return to_cpp<T>(e); }
271
calldnnl::impl::gpu::jit::expr_cast_helper_t272 static std::vector<T> call(const std::vector<expr_t> &exprs) {
273 std::vector<T> ret;
274 for (auto &e : exprs)
275 ret.push_back(to_cpp<T>(e));
276 return ret;
277 }
278 };
279
280 template <>
281 struct expr_cast_helper_t<expr_t> {
calldnnl::impl::gpu::jit::expr_cast_helper_t282 static expr_t call(const expr_t &e) { return e; }
283
calldnnl::impl::gpu::jit::expr_cast_helper_t284 static std::vector<expr_t> call(const std::vector<expr_t> &exprs) {
285 return exprs;
286 }
287
288 template <typename U,
289 typename
290 = typename std::enable_if<std::is_arithmetic<U>::value>::type>
calldnnl::impl::gpu::jit::expr_cast_helper_t291 static std::vector<expr_t> call(const std::vector<U> &vec) {
292 std::vector<expr_t> ret;
293 for (auto &v : vec)
294 ret.push_back(to_expr(v));
295 return ret;
296 }
297 };
298
299 template <typename DstT, typename SrcT>
300 DstT expr_cast(const SrcT &src) {
301 return expr_cast_helper_t<DstT>::call(src);
302 }
303
304 template <typename DstT, typename SrcT>
expr_cast(const std::vector<SrcT> & src)305 std::vector<DstT> expr_cast(const std::vector<SrcT> &src) {
306 return expr_cast_helper_t<DstT>::call(src);
307 }
308
309 // Performs constant folding recursively to an IR tree.
310 object_t const_fold(const object_t &obj);
311
312 // Performs constant folding non-recursively to an expression.
313 expr_t const_fold_non_recursive(const expr_t &e);
314
315 template <typename T>
316 std::vector<object_t> find_objects(const object_t &root);
317
318 template <typename T>
319 std::vector<object_t> find_objects_unique(const object_t &root);
320
321 class alloc_manager_t {
322 public:
alloc_manager_t(const stmt_t & root)323 alloc_manager_t(const stmt_t &root) {
324 auto allocs = find_objects<alloc_t>(root);
325 for (auto &_a : allocs) {
326 auto &a = _a.as<alloc_t>();
327 auto ret = buf2alloc_.insert({a.buf, _a});
328 buffers_.push_back(a.buf);
329 ir_assert(ret.second) << "Buffer already exists: " << a.buf;
330 MAYBE_UNUSED(ret);
331 }
332
333 // Sort buffers by name.
334 std::sort(buffers_.begin(), buffers_.end(),
335 [](const expr_t &a, const expr_t &b) {
336 return a.as<var_t>().name < b.as<var_t>().name;
337 });
338 }
339
buffers() const340 const std::vector<expr_t> &buffers() const { return buffers_; }
341
find_buffer(const std::string & name,bool allow_empty=false) const342 expr_t find_buffer(
343 const std::string &name, bool allow_empty = false) const {
344 for (auto &b : buffers())
345 if (b.as<var_t>().name == name) return b;
346
347 if (!allow_empty) ir_error_not_expected() << name;
348 return expr_t();
349 }
350
find_buffers(alloc_kind_t kind) const351 std::vector<expr_t> find_buffers(alloc_kind_t kind) const {
352 std::vector<expr_t> ret;
353 for (auto &b : buffers())
354 if (alloc_kind(b) == kind) ret.push_back(b);
355 return ret;
356 }
357
alloc_size(const expr_t & buf) const358 int alloc_size(const expr_t &buf) const {
359 auto *a = find_alloc(buf);
360 ir_assert(a) << buf;
361 return a->size;
362 }
363
alloc_kind(const expr_t & buf) const364 alloc_kind_t alloc_kind(const expr_t &buf) const {
365 auto *a = find_alloc(buf);
366 ir_assert(a) << buf;
367 return a->kind;
368 }
369
total_size(alloc_kind_t kind) const370 int total_size(alloc_kind_t kind) const {
371 int ret = 0;
372 for (auto &kv : buf2alloc_) {
373 auto &a = kv.second.as<alloc_t>();
374 if (a.kind == kind) ret += a.size;
375 }
376 return ret;
377 }
378
379 private:
find_alloc(const expr_t & buf) const380 const alloc_t *find_alloc(const expr_t &buf) const {
381 auto it = buf2alloc_.find(buf);
382 if (it == buf2alloc_.end()) return nullptr;
383 return it->second.as_ptr<alloc_t>();
384 }
385
386 object_map_t<expr_t, stmt_t> buf2alloc_;
387 std::vector<expr_t> buffers_;
388 object_map_t<expr_t, stmt_t> alloc_updates_;
389 };
390
391 // IR utility functions.
392 expr_t abs(const expr_t &e);
393
394 expr_t cast(const expr_t &e, const type_t &type, bool saturate = false);
395
396 bool is_zero(const expr_t &e);
397
398 bool is_one(const expr_t &e);
399
400 bool is_minus_one(const expr_t &e);
401
402 bool is_const_broadcast(const expr_t &e);
403
404 bool is_const_broadcast(const expr_t &e, const expr_t &value);
405
406 bool all_of(const expr_t &e, const expr_t &value);
407
408 expr_t make_buffer(const std::string &name);
409
410 // Utility functions for nary_op_t.
411 expr_t nary_op_back_transform(const expr_t &e);
412 expr_t nary_op_canonicalize(const expr_t &_e);
413 expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args);
414 std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e);
415
416 // Substitutes all occurrences of `from` to `to` in `root.
417 object_t substitute(const object_t &root, const object_t &from,
418 const object_t &to,
419 int max_substitutions = std::numeric_limits<int>::max());
420
421 // Returns leaf statements of `root`. Uses inorder traversal.
422 std::vector<stmt_t> flatten_statements(const stmt_t &root);
423
424 template <typename T, bool find_unique = false, bool save_objects = true>
425 class object_finder_t : public ir_visitor_t {
426 public:
_visit(const T & obj)427 void _visit(const T &obj) override {
428 ir_visitor_t::_visit(obj);
429 occurrences++;
430 if (!save_objects) return;
431 if (find_unique) {
432 found_unique.insert(obj);
433 } else {
434 found.push_back(obj);
435 }
436 }
437
438 std::vector<object_t> found;
439 object_set_t<object_t> found_unique;
440 int occurrences = 0;
441 };
442
443 // Returns all IR objects of type `T` found in `root`.
444 template <typename T>
find_objects(const object_t & root)445 std::vector<object_t> find_objects(const object_t &root) {
446 object_finder_t<T, /*find_unique=*/false> finder;
447 finder.visit(root);
448 return finder.found;
449 }
450
451 template <typename T>
count_objects(const object_t & root)452 int count_objects(const object_t &root) {
453 object_finder_t<T, /*find_unique=*/false, /*save_objects=*/false> finder;
454 finder.visit(root);
455 return finder.occurrences;
456 }
457
458 // Returns unique IR objects of type `T` found in `root`.
459 template <typename T>
find_unique_objects(const object_t & root)460 object_set_t<object_t> find_unique_objects(const object_t &root) {
461 object_finder_t<T, /*find_unique=*/true> finder;
462 finder.visit(root);
463 return finder.found_unique;
464 }
465
466 // Returns number of occurrences of `obj` in `root` (based on identity
467 // comparison).
468 int count_object(const object_t &root, const object_t &obj);
469
470 // Returns number of occurrences of `obj` in vector of root objects (based on
471 // identity comparison).
472 template <typename T>
count_object(const std::vector<T> & roots,const object_t & obj)473 int count_object(const std::vector<T> &roots, const object_t &obj) {
474 int ret = 0;
475 for (auto &root : roots)
476 ret += count_object(root, obj);
477 return ret;
478 }
479
480 // Checks if `root` contains `obj`.
481 bool contains_object(const object_t &root, const object_t &obj);
482
483 // Returns all statement groups matching the label.
484 std::vector<stmt_t> find_stmt_groups(
485 const object_t &root, const stmt_label_t &label);
486
487 // Returns a statement group matching the label. `root` must have exactly one
488 // occurrence.
489 stmt_t find_stmt_group(const object_t &root, const stmt_label_t &label);
490
491 class scope_visitor_t : public ir_visitor_t {
492 public:
is_expr_defined(const expr_t & e) const493 bool is_expr_defined(const expr_t &e) const {
494 auto vars = find_unique_objects<var_t>(e);
495 for (auto &v : vars) {
496 if (def_vars_.count(v) == 0) return false;
497 }
498 return true;
499 }
500
501 #define CASE(type, var_field, is_pre) \
502 if (obj.type_id() == type::_type_id()) { \
503 visit_scope((const type &)obj, ((const type &)obj).var_field, is_pre); \
504 return; \
505 }
506
pre_visit(const object_impl_t & obj)507 void pre_visit(const object_impl_t &obj) override {
508 CASE(alloc_t, buf, true);
509 CASE(let_t, var, true);
510 CASE(for_t, var, true);
511 }
512
post_visit(const object_impl_t & obj)513 void post_visit(const object_impl_t &obj) override {
514 CASE(alloc_t, buf, false);
515 CASE(let_t, var, false);
516 CASE(for_t, var, false);
517 }
518
519 #undef CASE
520
521 private:
522 template <typename T>
visit_scope(const T & obj,const expr_t & var,bool is_pre_visit)523 void visit_scope(const T &obj, const expr_t &var, bool is_pre_visit) {
524 if (is_pre_visit) {
525 def_vars_.insert(var);
526 return;
527 }
528 def_vars_.erase(var);
529 }
530
531 object_set_t<expr_t> def_vars_;
532 };
533
534 class ir_path_t {
535 public:
push(const object_impl_t * obj)536 void push(const object_impl_t *obj) { path_.push_back(obj); }
537
pop()538 void pop() { path_.pop_back(); }
539
back() const540 const object_impl_t *back() const {
541 ir_assert(!is_empty());
542 return path_.back();
543 }
544
is_empty() const545 bool is_empty() const { return path_.empty(); }
546
merge(const ir_path_t & other)547 void merge(const ir_path_t &other) {
548 size_t idx;
549 size_t min_size = std::min(path_.size(), other.path_.size());
550 for (idx = 0; idx < min_size; idx++) {
551 if (path_[idx] != other.path_[idx]) break;
552 }
553 path_.resize(idx);
554 }
555
556 private:
557 std::vector<const object_impl_t *> path_;
558 };
559
560 // Only for statements that create scope.
561 stmt_t get_stmt_body(const stmt_t &stmt);
562
563 stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body);
564
565 // Describes the linear transformation F(x) for variable x: F(x) = (a * x + b),
566 // where a and b are integer constants.
567 struct linear_transform_t {
568 expr_t x;
569 int a;
570 int b;
571
is_identitydnnl::impl::gpu::jit::linear_transform_t572 bool is_identity() const { return a == 1 && b == 0; }
573 };
574
575 // Relation: (lhs op rhs), where:
576 // - lhs is a variable
577 // - rhs is an integer constant
578 // - op is a comparison operation
579 class relation_t {
580 public:
relation_t(const expr_t & expr)581 relation_t(const expr_t &expr) : expr_(normalize(expr)) {}
582
expr() const583 const expr_t &expr() const { return expr_; }
584
var() const585 const expr_t &var() const { return expr_.as<binary_op_t>().a; }
586
rhs() const587 const expr_t &rhs() const { return expr_.as<binary_op_t>().b; }
588
op_kind() const589 op_kind_t op_kind() const { return expr_.as<binary_op_t>().op_kind; }
590
591 bool implies(const relation_t &other) const;
592
593 // Applies linear transformation to left and right hand sides of the relation.
594 relation_t transform(const linear_transform_t &t, const expr_t &new_var);
595
str() const596 std::string str() const {
597 std::ostringstream oss;
598 oss << expr_;
599 return oss.str();
600 }
601
is_relation_constraint(const expr_t & e)602 static bool is_relation_constraint(const expr_t &e) {
603 auto *binary_op = e.as_ptr<binary_op_t>();
604 if (!binary_op) return false;
605 if (!is_var(binary_op->a)) return false;
606 if (!is_const(binary_op->b)) return false;
607 if (!is_cmp_op(binary_op->op_kind)) return false;
608 return true;
609 }
610
611 private:
612 static expr_t normalize(const expr_t &e);
613
614 expr_t expr_;
615 };
616
operator <<(std::ostream & out,const relation_t & rel)617 inline std::ostream &operator<<(std::ostream &out, const relation_t &rel) {
618 out << rel.str();
619 return out;
620 }
621
622 // Equality for modulus: (var % mod) == 0, where:
623 // - var is a variable
624 // - mod is an integer constant
625 class modulus_info_t {
626 public:
modulus_info_t(const expr_t & expr)627 modulus_info_t(const expr_t &expr) : expr_(expr) {}
628
expr() const629 const expr_t &expr() const { return expr_; }
630
var() const631 const expr_t &var() const {
632 auto &mod_expr = expr_.as<binary_op_t>().a;
633 return mod_expr.as<binary_op_t>().a;
634 }
635
mod() const636 const expr_t &mod() const {
637 auto &mod_expr = expr_.as<binary_op_t>().a;
638 return mod_expr.as<binary_op_t>().b;
639 }
640
implies(const modulus_info_t & other) const641 bool implies(const modulus_info_t &other) const {
642 ir_assert(var().is_same(other.var()));
643
644 int64_t this_mod = to_cpp<int64_t>(mod());
645 int64_t other_mod = to_cpp<int64_t>(other.mod());
646
647 return this_mod % other_mod == 0;
648 }
649
str() const650 std::string str() const {
651 std::ostringstream oss;
652 oss << expr_;
653 return oss.str();
654 }
655
656 // Try to match (var % mod) == 0.
657 static bool is_modulus_constraint(const expr_t &e);
658
659 private:
660 expr_t expr_;
661 };
662
operator <<(std::ostream & out,const modulus_info_t & mod)663 inline std::ostream &operator<<(std::ostream &out, const modulus_info_t &mod) {
664 out << mod.str();
665 return out;
666 }
667
668 // Helper class to find constant bounds of integer expressions based on known
669 // relations.
670 class bound_finder_t {
671 public:
bound_finder_t(const object_map_t<expr_t,std::vector<relation_t>> & relations)672 bound_finder_t(
673 const object_map_t<expr_t, std::vector<relation_t>> &relations)
674 : relations_(relations) {}
675
find_low_bound(const expr_t & e) const676 int64_t find_low_bound(const expr_t &e) const {
677 return find_bound_impl(e, /*is_low=*/true);
678 }
679
find_high_bound(const expr_t & e) const680 int64_t find_high_bound(const expr_t &e) const {
681 return find_bound_impl(e, /*is_low=*/false);
682 }
683
is_good_bound(int64_t bound)684 static bool is_good_bound(int64_t bound) {
685 if (bound == unlimited_bound(true)) return false;
686 if (bound == unlimited_bound(false)) return false;
687 return true;
688 }
689
690 private:
691 // If is_low is true, searches for proven low bound, and high bound
692 // otherwise.
693 int64_t find_bound_impl(const expr_t &e, bool is_low) const;
694
unlimited_bound(bool is_low)695 static int64_t unlimited_bound(bool is_low) {
696 if (is_low) return std::numeric_limits<int64_t>::min();
697 return std::numeric_limits<int64_t>::max();
698 }
699
700 object_map_t<expr_t, std::vector<relation_t>> relations_;
701 };
702
703 // TODO: Add integers check (only integers can be constrained).
704 class constraint_set_t {
705 public:
706 void add_constraint(const expr_t &e);
707
can_prove(const expr_t & e,bool try_simplify=true) const708 bool can_prove(const expr_t &e, bool try_simplify = true) const {
709 auto ret = can_prove_impl(e, /*do_simplify=*/false);
710 if (ret || !try_simplify) return ret;
711
712 return can_prove_impl(e, /*do_simplify=*/true);
713 }
714
715 bool is_single_value(const expr_t &e, expr_t &value) const;
716
717 int max_proven_gcd(const expr_t &var) const;
718
719 private:
can_prove_modulus(const expr_t & e) const720 bool can_prove_modulus(const expr_t &e) const {
721 modulus_info_t unknown(e);
722 auto it = modulus_infos_.find(unknown.var());
723 if (it == modulus_infos_.end()) return false;
724
725 for (auto &known : it->second) {
726 if (known.implies(unknown)) return true;
727 }
728
729 return false;
730 }
731
can_prove_relation(const expr_t & e) const732 bool can_prove_relation(const expr_t &e) const {
733 relation_t unknown(e);
734 auto it = relations_.find(unknown.var());
735 if (it == relations_.end()) return false;
736
737 for (auto &known : it->second) {
738 if (known.implies(unknown)) return true;
739 }
740
741 return false;
742 }
743
try_prove_compound_relation(const expr_t & e) const744 bool try_prove_compound_relation(const expr_t &e) const {
745 auto *binary = e.as_ptr<binary_op_t>();
746 if (!binary) return false;
747
748 auto op_kind = binary->op_kind;
749 auto &a = binary->a;
750 auto &_b = binary->b;
751
752 if (!is_const(_b)) return false;
753
754 auto b = to_cpp<int64_t>(_b);
755
756 // Normalize operation kind.
757 switch (op_kind) {
758 case op_kind_t::_ge:
759 case op_kind_t::_le: break;
760 case op_kind_t::_gt:
761 op_kind = op_kind_t::_ge;
762 ir_assert(b < std::numeric_limits<int64_t>::max());
763 b += 1;
764 break;
765 case op_kind_t::_lt:
766 op_kind = op_kind_t::_le;
767 ir_assert(b > std::numeric_limits<int64_t>::min());
768 b -= 1;
769 break;
770 default: return false;
771 }
772
773 bound_finder_t finder(relations_);
774 if (op_kind == op_kind_t::_ge) {
775 auto lo = finder.find_low_bound(a);
776 if (!bound_finder_t::is_good_bound(lo)) return false;
777 return lo >= b;
778 }
779
780 if (op_kind == op_kind_t::_le) {
781 auto hi = finder.find_high_bound(a);
782 if (!bound_finder_t::is_good_bound(hi)) return false;
783 return hi <= b;
784 }
785
786 return false;
787 }
788
789 bool can_prove_impl(const expr_t &_e, bool do_simplify) const;
790
791 object_map_t<expr_t, std::vector<relation_t>> relations_;
792 object_map_t<expr_t, std::vector<modulus_info_t>> modulus_infos_;
793 };
794
795 // Simplifies expression or statement. An optional constraint set is used to
796 // pass known equalities and inequalities which may be used for simplification.
797 object_t simplify(const object_t &obj, const constraint_set_t &cset = {});
798
799 // Searches for expression patterns to reduce them to the equivalent ternary
800 // operations.
801 expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive = true);
802
803 // Moves constants to the right hand side of an expression.
804 // Example: (c0 + x) op c1 -> x op (c1 - c0)
805 expr_t simplify_cmp_move_const_to_rhs(const expr_t &e);
806
807 // Reduces left and right hand sides of an expression.
808 // Example: A * x < A * B -> x < B (if A > 0).
809 expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e);
810
811 // Propagates shuffle down the expression tree for more effective vectorization.
812 expr_t simplify_propagate_shuffle(const expr_t &e);
813
814 // Pre-defined functions.
815 namespace funcs {
816
barrier_func()817 inline func_t barrier_func() {
818 static auto f = builtin_t::make("barrier");
819 return f;
820 }
821
barrier()822 inline stmt_t barrier() {
823 return barrier_func().call();
824 }
825
slm_fence_func()826 inline func_t slm_fence_func() {
827 static auto f = builtin_t::make("slm_fence");
828 return f;
829 }
830
slm_fence()831 inline stmt_t slm_fence() {
832 return slm_fence_func().call();
833 }
834
signal_func()835 inline func_t signal_func() {
836 static auto f = builtin_t::make("signal");
837 return f;
838 }
839
signal()840 inline stmt_t signal() {
841 return signal_func().call();
842 }
843
barrier_wait_func()844 inline func_t barrier_wait_func() {
845 static auto f = builtin_t::make("barrier_wait");
846 return f;
847 }
848
barrier_wait()849 inline stmt_t barrier_wait() {
850 return barrier_wait_func().call();
851 }
852
853 } // namespace funcs
854
855 // Helper functionality to extract ND indices packed into 1D index.
856 // Example:
857 // i = [0; Bi, 2 * Bi, ... (I - 1) * Bi]
858 // i_info.dim = I; i_info.block = Bi
859 // j = [0; Bj, 2 * Bj, ... (J - 1) * Bj]
860 // j_info.dim = J; j_info.block = Bj
861 // 1D index: ij_idx
862 // 2D indices: [i; j]
863 // Unpacking:
864 // i = (ij_idx % I) * Bi
865 // j = (ij_idx / I) * Bj
866 struct unpack_dim_info_t {
867 const expr_t &var;
868 int dim;
869 int block;
870 };
871
cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &)872 inline void cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &) {}
873
874 template <typename... ArgsT>
cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> & infos,const expr_t & var,int dim,int block,const ArgsT &...args)875 void cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &infos,
876 const expr_t &var, int dim, int block, const ArgsT &... args) {
877 infos.push_back(unpack_dim_info_t {var, dim, block});
878 cvt_args_to_unpack_dim_info(infos, args...);
879 }
880
881 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
882 const expr_t &_e, const std::vector<unpack_dim_info_t> &infos);
883
884 template <typename... ArgsT>
unpack(std::vector<stmt_t> & init_stmts,constraint_set_t & cset,const expr_t & e,const ArgsT &...args)885 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
886 const expr_t &e, const ArgsT &... args) {
887 std::vector<unpack_dim_info_t> infos;
888 cvt_args_to_unpack_dim_info(infos, args...);
889 unpack(init_stmts, cset, e, infos);
890 }
891
892 } // namespace jit
893 } // namespace gpu
894 } // namespace impl
895 } // namespace dnnl
896
897 #endif
898