1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <sstream>
18
19 #include "common/math_utils.hpp"
20 #include "gpu/jit/conv/ir.hpp"
21 #include "gpu/jit/conv/ir_core.hpp"
22
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace jit {
27
28 using namespace ir_utils;
29
30 namespace {
31
32 // Helper class to print IR objects.
33 class ir_printer_t : public ir_visitor_t {
34 public:
ir_printer_t(std::ostream & out)35 ir_printer_t(std::ostream &out) : out_(out) {}
36
visit(const object_t & obj)37 void visit(const object_t &obj) override {
38 if (is_supported(obj)) {
39 ir_visitor_t::visit(obj);
40 return;
41 }
42 // Only expressions/functions are expected here.
43 out_ << obj.str();
44 }
45
_visit(const alloc_t & obj)46 void _visit(const alloc_t &obj) override {
47 print_indent();
48 out_ << "alloc " << obj.buf.as<var_t>().name << "[" << obj.size
49 << "]\n";
50 visit(obj.body);
51 }
52
_visit(const binary_op_t & obj)53 void _visit(const binary_op_t &obj) override {
54 if (utils::one_of(obj.op_kind, op_kind_t::_min, op_kind_t::_max)) {
55 out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b
56 << ")";
57 return;
58 }
59 out_ << "(";
60 visit(obj.a);
61 out_ << " " << to_string(obj.op_kind) << " ";
62 visit(obj.b);
63 out_ << ")";
64 }
65
_visit(const bool_imm_t & obj)66 void _visit(const bool_imm_t &obj) override {
67 out_ << (obj.value ? "true" : "false");
68 }
69
_visit(const cast_t & obj)70 void _visit(const cast_t &obj) override {
71 out_ << obj.type;
72 if (obj.saturate) out_ << ".sat";
73 out_ << "(" << obj.expr << ")";
74 }
75
_visit(const float_imm_t & obj)76 void _visit(const float_imm_t &obj) override { out_ << obj.value; }
77
_visit(const for_t & obj)78 void _visit(const for_t &obj) override {
79 print_indent();
80 out_ << "for (" << obj.var << " = " << obj.init << "; " << obj.var
81 << " < " << obj.bound << "; " << obj.var << "++) ";
82 if (obj.unroll != 1) out_ << "[unroll: " << obj.unroll << "] ";
83 out_ << "{\n";
84 add_indent();
85 visit(obj.body);
86 remove_indent();
87 print_indent();
88 out_ << "}\n";
89 }
90
_visit(const func_call_t & obj)91 void _visit(const func_call_t &obj) override {
92 print_indent();
93 out_ << obj.func << "(" << make_seq_print_helper(obj.args) << ")";
94 if (!obj.attr.is_empty()) out_ << " " << obj.attr;
95 out_ << "\n";
96 }
97
_visit(const func_impl_t & obj)98 void _visit(const func_impl_t &obj) override { out_ << obj.str(); }
99
_visit(const if_t & obj)100 void _visit(const if_t &obj) override {
101 print_indent();
102 out_ << "if (" << strip_parens(obj.cond.str()) << ") {\n";
103 add_indent();
104 visit(obj.body);
105 remove_indent();
106 print_indent();
107 if (obj.else_body.is_empty()) {
108 out_ << "}\n";
109 return;
110 }
111 out_ << "} else {\n";
112 add_indent();
113 visit(obj.else_body);
114 remove_indent();
115 print_indent();
116 out_ << "}\n";
117 }
118
_visit(const iif_t & obj)119 void _visit(const iif_t &obj) override {
120 out_ << "(" << obj.cond << " ? " << obj.true_expr << " : "
121 << obj.false_expr << ")";
122 }
123
_visit(const int_imm_t & obj)124 void _visit(const int_imm_t &obj) override {
125 out_ << std::to_string(obj.value);
126 }
127
_visit(const let_t & obj)128 void _visit(const let_t &obj) override {
129 print_indent();
130 out_ << obj.var << "." << obj.var.type() << " = " << obj.value << "\n";
131 visit(obj.body);
132 }
133
_visit(const load_t & obj)134 void _visit(const load_t &obj) override {
135 out_ << obj.buf;
136 if (obj.has_default_stride()) {
137 out_ << "." << obj.type << "(" << obj.off / obj.type.size() << ")";
138 } else {
139 out_ << "[" << obj.off << "]." << obj.type;
140 out_ << "<" << obj.stride << ">";
141 }
142 }
143
_visit(const ptr_t & obj)144 void _visit(const ptr_t &obj) override {
145 out_ << obj.base << "[" << obj.off << "]";
146 }
147
_visit(const shuffle_t & obj)148 void _visit(const shuffle_t &obj) override {
149 if (obj.is_broadcast()) {
150 out_ << "bcast" << obj.elems() << "(" << obj.vec[0] << ")";
151 return;
152 }
153 std::vector<expr_t> vec_all;
154 for (auto &v : obj.vec) {
155 for (int i = 0; i < v.type().elems(); i++)
156 vec_all.push_back(v);
157 }
158 int elems = obj.type.elems();
159 out_ << "(";
160 for (int i = 0; i < elems; i++) {
161 int idx = obj.idx[i];
162 auto &v = vec_all[idx];
163 int v_elems = v.type().elems();
164 out_ << v;
165 if (v_elems != 1) out_ << "[" << idx << "]";
166 if (i != elems - 1) out_ << ", ";
167 }
168 out_ << ")";
169 }
170
_visit(const stmt_group_t & obj)171 void _visit(const stmt_group_t &obj) override {
172 print_indent();
173 out_ << obj.label << " {\n";
174 add_indent();
175 visit(obj.body);
176 remove_indent();
177 print_indent();
178 out_ << "}\n";
179 return;
180 }
181
_visit(const stmt_seq_t & obj)182 void _visit(const stmt_seq_t &obj) override {
183 visit(obj.head);
184 visit(obj.tail);
185 }
186
_visit(const store_t & obj)187 void _visit(const store_t &obj) override {
188 print_indent();
189 out_ << load_t::make(obj.value.type(), obj.buf, obj.off, obj.stride);
190 out_ << " = " << obj.value;
191 if (!obj.mask.is_empty()) out_ << " [masked]";
192 out_ << "\n";
193 }
194
_visit(const ternary_op_t & obj)195 void _visit(const ternary_op_t &obj) override {
196 out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b << ", "
197 << obj.c << ")";
198 return;
199 }
200
_visit(const unary_op_t & obj)201 void _visit(const unary_op_t &obj) override {
202 out_ << to_string(obj.op_kind);
203 visit(obj.a);
204 }
205
_visit(const var_t & obj)206 void _visit(const var_t &obj) override { out_ << obj.name; }
207
208 private:
strip_parens(const std::string & s)209 static std::string strip_parens(const std::string &s) {
210 if (s.size() < 2 || s[0] != '(' || s[s.size() - 1] != ')') return s;
211 auto ret = s;
212 ret.resize(s.size() - 1);
213 return ret.substr(1);
214 }
215
print_indent()216 void print_indent() {
217 for (int i = 0; i < indent_; i++)
218 out_ << prefix_;
219 }
220
add_indent()221 void add_indent() { indent_++; }
remove_indent()222 void remove_indent() { indent_--; }
223
224 std::ostream &out_;
225 int indent_ = 0;
226
227 std::string prefix_ = " ";
228 };
229
// Mutator that replaces objects identical to `from` (compared by
// implementation pointer) with `to`, counting the replacements it makes.
class substitute_mutator_t : public ir_mutator_t {
public:
    substitute_mutator_t(const object_t &from, const object_t &to)
        : from_(from), to_(to) {}

    // Number of replacements performed so far.
    int substitutions() const { return substitutions_; }

// Generates one _mutate() override per type listed in HANDLE_MUTATE_TARGETS().
// Each override returns `to_` (and bumps the counter) when the visited object
// is the very object held by `from_`; otherwise it defers to the base mutator.
#define HANDLE_IR_OBJECT(type) \
    object_t _mutate(const type &obj) override { \
        auto *this_mutator = (substitute_mutator_t *)this; \
        if (this_mutator->from_.impl() == (const object_impl_t *)&obj) { \
            this_mutator->substitutions_++; \
            return this_mutator->to_; \
        } \
        return ir_mutator_t::_mutate(obj); \
    };

    HANDLE_MUTATE_TARGETS()

#undef HANDLE_IR_OBJECT

private:
    object_t from_; // Object to look for.
    object_t to_; // Replacement object.

    int substitutions_ = 0;
};
257
// Visitor that collects statements into `stmts`. An object is recorded only
// if it reports is_stmt() and visiting its children recorded nothing.
class stmt_flattener_t : public ir_visitor_t {
public:
// Generates one _visit() per type listed in HANDLE_ALL_IR_OBJECTS(). The
// children are visited first; if that added any statement the current object
// is skipped, otherwise it is appended when it is a statement.
#define HANDLE_IR_OBJECT(type) \
    void _visit(const type &obj) { \
        size_t old_size = stmts.size(); \
        ir_visitor_t::_visit(obj); \
        if (stmts.size() > old_size) return; \
        if (obj.is_stmt()) stmts.push_back(obj); \
    }

    HANDLE_ALL_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

    // Collected statements, in visiting order.
    std::vector<stmt_t> stmts;
};
274
// Mutator that wraps statements of `root` into the given alloc statements.
//
// It works in two phases:
// 1. The constructor runs one pass over `root` (with in_ctor_ == true) that
//    only counts references to every buffer listed in `allocs`.
// 2. A later mutate() call injects the allocations: either around `root`
//    itself, or (put_innermost == true) around the first visited statement
//    that covers all references to the buffer.
class alloc_injector_t : public ir_mutator_t {
public:
    alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs,
            bool put_innermost)
        : root_(root), put_innermost_(put_innermost), allocs_(allocs) {
        for (auto &_a : allocs) {
            auto &a = _a.as<alloc_t>();
            // Only global allocations may have a non-positive size.
            if (a.kind != alloc_kind_t::global) ir_assert(a.size > 0) << _a;
            alloc_map_.insert({a.buf, _a});
        }
        // Counting pass: fills buf_cur_refs_ via _mutate(const var_t &).
        mutate(root_);
        buf_total_refs_ = buf_cur_refs_;
        // Reset the running counters for the injection pass.
        for (auto &kv : buf_cur_refs_)
            kv.second = 0;
        in_ctor_ = false;
    }

// Routes every statement type listed in HANDLE_STMT_IR_OBJECTS() through
// mutate_stmt().
#define HANDLE_IR_OBJECT(type) \
    object_t _mutate(const type &obj) override { return mutate_stmt(obj); }

    HANDLE_STMT_IR_OBJECTS()

#undef HANDLE_IR_OBJECT
    // Counts a reference when the variable is one of the tracked buffers.
    object_t _mutate(const var_t &obj) override {
        if (alloc_map_.find(obj) != alloc_map_.end()) buf_cur_refs_[obj]++;
        return obj;
    }

private:
    // Mutates one statement and wraps it into every allocation that must be
    // injected at this point.
    template <typename T>
    object_t mutate_stmt(const T &obj) {
        // During the counting pass only recurse; nothing is injected.
        if (in_ctor_) return ir_mutator_t::_mutate(obj);
        object_t new_obj = obj;
        object_set_t<expr_t> undef_bufs;
        if (put_innermost_) {
            // Remember buffers with no references seen before this statement.
            for (auto &kv : buf_cur_refs_)
                if (kv.second == 0) undef_bufs.insert(kv.first);
            new_obj = ir_mutator_t::_mutate(obj);
        }
        for (auto &a : allocs_) {
            auto it = alloc_map_.find(a.as<alloc_t>().buf);
            auto &buf = it->first;
            if (it->second.is_empty()) continue; // Already injected.
            bool do_inject = false;
            if (put_innermost_) {
                int cur_refs = buf_cur_refs_[buf];
                int total_refs = buf_total_refs_[buf];
                bool was_undef = (undef_bufs.count(buf) != 0);
                // Inject when this statement contains all references: none
                // were seen before it and all have been seen after it.
                do_inject = was_undef && (cur_refs == total_refs);
            } else {
                // Otherwise allocations are only placed around the root.
                do_inject = root_.is_same(obj);
            }
            if (do_inject) {
                auto &a = it->second.as<alloc_t>();
                new_obj = alloc_t::make(a.buf, a.size, a.kind, a.attr, new_obj);
                // An empty statement marks the allocation as injected.
                it->second = stmt_t();
            }
        }
        return new_obj;
    }

    bool in_ctor_ = true; // True while the constructor's counting pass runs.
    const stmt_t &root_; // Reference to the caller's root statement.
    bool put_innermost_;
    std::vector<stmt_t> allocs_;
    object_map_t<expr_t, stmt_t> alloc_map_; // Buffer -> pending alloc.
    object_map_t<expr_t, int> buf_total_refs_; // Counts from the first pass.
    object_map_t<expr_t, int> buf_cur_refs_; // Running counts.
};
344
345 } // namespace
346
str() const347 std::string object_impl_t::str() const {
348 std::ostringstream oss;
349 ir_printer_t printer(oss);
350 ir_assert(printer.is_supported(this));
351 printer.visit(this);
352 return oss.str();
353 }
354
// Returns `root` with every occurrence of the object `from` replaced by
// `to`. Asserts that at most `max_substitutions` replacements were made.
object_t substitute(const object_t &root, const object_t &from,
        const object_t &to, int max_substitutions) {
    // Replacing an object with itself leaves the tree unchanged.
    const bool is_noop = to.is_same(from);
    if (is_noop) return root;

    substitute_mutator_t mutator(from, to);
    auto result = mutator.mutate(root);
    ir_assert(mutator.substitutions() <= max_substitutions)
            << "Unexpected number of substitutions.";
    MAYBE_UNUSED(&substitute_mutator_t::substitutions);
    MAYBE_UNUSED(max_substitutions);
    return result;
}
366
flatten_statements(const stmt_t & root)367 std::vector<stmt_t> flatten_statements(const stmt_t &root) {
368 stmt_flattener_t f;
369 f.visit(root);
370 return f.stmts;
371 }
372
// Wraps `stmt` (or parts of it, when `put_innermost` is set) into the given
// alloc statements.
stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
        bool put_innermost) {
    alloc_injector_t alloc_injector(stmt, allocs, put_innermost);
    auto injected = alloc_injector.mutate(stmt);
    return injected;
}
378
// Nests `stmt` inside the given let statements; the first element of `lets`
// becomes the outermost let.
stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets) {
    stmt_t nested = stmt;
    // Walk backwards so that earlier lets end up enclosing later ones.
    for (size_t i = lets.size(); i > 0; i--) {
        auto &cur = lets[i - 1].as<let_t>();
        nested = let_t::make(cur.var, cur.value, nested);
    }
    return nested;
}
387
// Absolute value of a constant expression.
expr_t abs(const expr_t &e) {
    ir_assert(is_const(e)) << e;
    const bool is_non_negative = to_cpp<bool>(e >= 0);
    if (!is_non_negative) return -e;
    return e;
}
393
// Casts `e` to `type` (constant-folded); returns `e` itself when it already
// has that type.
expr_t cast(const expr_t &e, const type_t &type, bool saturate) {
    const bool same_type = (e.type() == type);
    if (same_type) return e;
    auto casted = cast_t::make(type, e, saturate);
    return const_fold(casted);
}
398
is_zero(const expr_t & e)399 bool is_zero(const expr_t &e) {
400 if (!e.type().is_scalar()) return false;
401 return e.is_equal(to_expr(0, e.type()));
402 }
403
is_one(const expr_t & e)404 bool is_one(const expr_t &e) {
405 if (!e.type().is_scalar()) return false;
406 return e.is_equal(to_expr(1, e.type()));
407 }
408
is_minus_one(const expr_t & e)409 bool is_minus_one(const expr_t &e) {
410 if (!e.type().is_scalar()) return false;
411 return e.is_equal(to_expr(-1, e.type()));
412 }
413
is_const_broadcast(const expr_t & e)414 bool is_const_broadcast(const expr_t &e) {
415 auto *shuffle = e.as_ptr<shuffle_t>();
416 if (!shuffle) return false;
417 if (!shuffle->is_broadcast()) return false;
418 return is_const(shuffle->vec[0]);
419 }
420
is_const_broadcast(const expr_t & e,const expr_t & value)421 bool is_const_broadcast(const expr_t &e, const expr_t &value) {
422 if (!is_const_broadcast(e)) return false;
423 return e.as<shuffle_t>().vec[0].is_equal(value);
424 }
425
all_of(const expr_t & e,const expr_t & value)426 bool all_of(const expr_t &e, const expr_t &value) {
427 auto *shuffle = e.as_ptr<shuffle_t>();
428 if (!shuffle) return e.is_equal(value);
429 for (auto &i : shuffle->idx) {
430 if (!shuffle->vec[i].is_equal(value)) return false;
431 }
432 return true;
433 }
434
make_buffer(const std::string & name)435 expr_t make_buffer(const std::string &name) {
436 return var_t::make(type_t::byte_ptr(), name);
437 }
438
439 // Returns number of occurrences of `obj` in `root` (based on identity equality).
count_object(const object_t & root,const object_t & obj)440 int count_object(const object_t &root, const object_t &obj) {
441 ir_assert(!obj.is_empty());
442
443 std::vector<object_t> found;
444 do {
445 #define HANDLE_IR_OBJECT(type) \
446 if (obj.dispatch_type_id() == type::_dispatch_type_id()) { \
447 found = find_objects<type>(root); \
448 break; \
449 }
450
451 HANDLE_ALL_IR_OBJECTS()
452
453 #undef HANDLE_IR_OBJECT
454
455 ir_error_not_expected() << obj;
456 } while (false);
457
458 int ret = 0;
459 for (auto &f : found)
460 if (f.is_equal(obj)) ret++;
461 return ret;
462 }
463
contains_object(const object_t & root,const object_t & obj)464 bool contains_object(const object_t &root, const object_t &obj) {
465 ir_assert(is_var(obj)) << obj;
466 return count_object(root, obj) > 0;
467 }
468
find_stmt_groups(const object_t & root,const stmt_label_t & label)469 std::vector<stmt_t> find_stmt_groups(
470 const object_t &root, const stmt_label_t &label) {
471 auto groups = find_objects<stmt_group_t>(root);
472 std::vector<stmt_t> ret;
473 for (auto &g : groups) {
474 if (g.as<stmt_group_t>().label == label) ret.push_back(g);
475 }
476 return ret;
477 }
478
find_stmt_group(const object_t & root,const stmt_label_t & label)479 stmt_t find_stmt_group(const object_t &root, const stmt_label_t &label) {
480 auto groups = find_stmt_groups(root, label);
481 ir_assert(groups.size() == 1);
482 return groups[0];
483 }
484
// Returns the body of an alloc/for/let/group statement. Any other statement
// is returned unchanged.
stmt_t get_stmt_body(const stmt_t &stmt) {
    if (auto *alloc = stmt.as_ptr<alloc_t>()) return alloc->body;
    if (auto *loop = stmt.as_ptr<for_t>()) return loop->body;
    if (auto *let = stmt.as_ptr<let_t>()) return let->body;
    if (auto *group = stmt.as_ptr<stmt_group_t>()) return group->body;
    return stmt;
}
500
// Rebuilds an alloc/for/let/group statement around `new_body`, keeping all
// other fields. For any other statement `new_body` itself is returned.
stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body) {
    if (auto *alloc = stmt.as_ptr<alloc_t>()) {
        return alloc_t::make(
                alloc->buf, alloc->size, alloc->kind, alloc->attr, new_body);
    }

    if (auto *loop = stmt.as_ptr<for_t>()) {
        return for_t::make(
                loop->var, loop->init, loop->bound, new_body, loop->unroll);
    }

    if (auto *let = stmt.as_ptr<let_t>()) {
        return let_t::make(let->var, let->value, new_body);
    }

    if (auto *group = stmt.as_ptr<stmt_group_t>()) {
        return stmt_group_t::make(group->label, new_body);
    }

    return new_body;
}
522
implies(const relation_t & other) const523 bool relation_t::implies(const relation_t &other) const {
524 ir_assert(var().is_same(other.var()));
525
526 if (op_kind() != other.op_kind()) return false;
527
528 auto A = to_cpp<int64_t>(rhs());
529 auto B = to_cpp<int64_t>(other.rhs());
530
531 switch (op_kind()) {
532 // (x > A) && (A >= B) => (x > B)
533 // (x >= A) && (A >= B) => (x >= B)
534 case op_kind_t::_gt:
535 case op_kind_t::_ge: return A >= B;
536 // (x < A) && (A <= B) => (x < B)
537 // (x <= A) && (A <= B) => (x <= B)
538 case op_kind_t::_lt:
539 case op_kind_t::_le: return A <= B;
540 default: ir_error_not_expected() << "Not implemented: " << expr_;
541 }
542 return false;
543 }
544
transform(const linear_transform_t & t,const expr_t & new_var)545 relation_t relation_t::transform(
546 const linear_transform_t &t, const expr_t &new_var) {
547 ir_assert(t.a == 1) << "Not implemented.";
548 return relation_t(binary_op_t::make(op_kind(), new_var, rhs() + t.b));
549 }
550
// Rewrites strict comparisons as non-strict ones: (a < b) -> (a <= b - 1) and
// (a > b) -> (a >= b + 1). Other relations are returned unchanged.
expr_t relation_t::normalize(const expr_t &e) {
    ir_assert(is_relation_constraint(e)) << e;
    auto &cmp = e.as<binary_op_t>();

    const bool is_lt = (cmp.op_kind == op_kind_t::_lt);
    const bool is_gt = (cmp.op_kind == op_kind_t::_gt);
    if (!is_lt && !is_gt) return e;

    auto lhs = cmp.a;
    auto rhs_value = cmp.b;
    if (is_lt) {
        rhs_value -= 1;
        return binary_op_t::make(op_kind_t::_le, lhs, rhs_value);
    }
    rhs_value += 1;
    return binary_op_t::make(op_kind_t::_ge, lhs, rhs_value);
}
572
is_modulus_constraint(const expr_t & e)573 bool modulus_info_t::is_modulus_constraint(const expr_t &e) {
574 auto *binary_op = e.as_ptr<binary_op_t>();
575 if (!binary_op) return false;
576 if (!is_zero(binary_op->b)) return false;
577 if (binary_op->op_kind != op_kind_t::_eq) return false;
578
579 auto *mod_op = binary_op->a.as_ptr<binary_op_t>();
580 if (!mod_op) return false;
581 if (mod_op->op_kind != op_kind_t::_mod) return false;
582 if (!is_var(mod_op->a)) return false;
583 if (!is_const(mod_op->b)) return false;
584
585 return true;
586 }
587
find_bound_impl(const expr_t & e,bool is_low) const588 int64_t bound_finder_t::find_bound_impl(const expr_t &e, bool is_low) const {
589 int64_t def_bound = unlimited_bound(is_low);
590 if (is_const(e)) return to_cpp<int64_t>(e);
591 if (is_var(e)) {
592 auto it = relations_.find(e);
593 if (it == relations_.end()) return def_bound;
594
595 int64_t ret = def_bound;
596 for (auto &rel : it->second) {
597 bool is_ge = (rel.op_kind() == op_kind_t::_ge);
598 if (is_ge != is_low) continue;
599 if (is_ge) {
600 ret = std::max(to_cpp<int64_t>(rel.rhs()), ret);
601 } else {
602 ret = std::min(to_cpp<int64_t>(rel.rhs()), ret);
603 }
604 }
605 return ret;
606 }
607
608 auto *unary = e.as_ptr<unary_op_t>();
609 if (unary) {
610 ir_assert(unary->op_kind == op_kind_t::_minus) << e;
611 auto a = find_bound_impl(unary->a, !is_low);
612 if (!is_good_bound(a)) return 0;
613 return -a;
614 }
615
616 auto *binary = e.as_ptr<binary_op_t>();
617 if (binary) {
618 switch (binary->op_kind) {
619 case op_kind_t::_add: {
620 auto a = find_bound_impl(binary->a, is_low);
621 auto b = find_bound_impl(binary->b, is_low);
622 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
623 return a + b;
624 }
625 case op_kind_t::_sub: {
626 auto a = find_bound_impl(binary->a, is_low);
627 auto b = find_bound_impl(binary->b, !is_low);
628 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
629 return a - b;
630 }
631 case op_kind_t::_mul: {
632 auto a = binary->a;
633 auto b = binary->b;
634 if (!is_const(a) && is_const(b)) std::swap(a, b);
635 if (!is_const(a)) return def_bound;
636
637 auto a_const = to_cpp<int64_t>(a);
638
639 auto b_lo = find_low_bound(b);
640 auto b_hi = find_high_bound(b);
641 auto b_lo_ok = is_good_bound(b_lo);
642 auto b_hi_ok = is_good_bound(b_hi);
643
644 bool b_ge_0 = b_lo_ok && (b_lo >= 0);
645 bool b_le_0 = b_hi_ok && (b_hi <= 0);
646 bool b_same_sign = (b_ge_0 || b_le_0);
647
648 if (a_const >= 0 && b_same_sign) {
649 if (is_low && b_lo_ok) return a_const * b_lo;
650 if (b_hi_ok) return a_const * b_hi;
651 }
652
653 if (a_const <= 0 && b_same_sign) {
654 if (is_low && b_hi_ok) return a_const * b_hi;
655 if (b_lo_ok) return a_const * b_lo;
656 }
657 break;
658 }
659 case op_kind_t::_div: {
660 if (!is_const(binary->b)) return def_bound;
661
662 auto b = to_cpp<int64_t>(binary->b);
663 ir_assert(b != 0);
664
665 auto a = find_bound_impl(binary->a, b > 0 ? is_low : !is_low);
666 if (!is_good_bound(a)) return def_bound;
667
668 bool is_neg = ((a > 0) && (b < 0)) || ((a < 0) && (b > 0));
669
670 int64_t div_bound;
671 if (is_low != is_neg) {
672 // Truncate away from zero.
673 div_bound = utils::div_up(std::abs(a), std::abs(b));
674 } else {
675 // Truncate towards zero.
676 div_bound = std::abs(a) / std::abs(b);
677 }
678 if (is_neg) div_bound *= -1;
679 return div_bound;
680 }
681 case op_kind_t::_mod: {
682 if (is_low) return 0;
683 auto max_mod = find_bound_impl(binary->b, /*is_low=*/false);
684 if (!is_good_bound(max_mod)) return def_bound;
685 return max_mod - 1;
686 }
687 default: break;
688 }
689 }
690
691 return def_bound;
692 }
693
is_linear_var_transform(const expr_t & e,linear_transform_t & t)694 bool is_linear_var_transform(const expr_t &e, linear_transform_t &t) {
695 if (is_var(e)) {
696 t.x = e;
697 t.a = 1;
698 t.b = 0;
699 return true;
700 }
701
702 auto *binary_op = e.as_ptr<binary_op_t>();
703 if (!binary_op) return false;
704
705 auto vars = find_objects<var_t>(e);
706 if (vars.size() != 1) return false;
707
708 auto &var = vars[0];
709
710 // TODO: Extend to match multiplication: (a * var).
711 if (!utils::one_of(binary_op->op_kind, op_kind_t::_add, op_kind_t::_sub))
712 return false;
713
714 auto &a = binary_op->a;
715 auto &b = binary_op->b;
716
717 bool is_sub = (binary_op->op_kind == op_kind_t::_sub);
718
719 // var op b -> (t.a = 1, t.b = +/-b)
720 if (a.is_same(var) && is_const(b)) {
721 t.x = var;
722 t.a = 1;
723 t.b = (is_sub ? -1 : 1) * to_cpp<int>(b);
724 return true;
725 }
726
727 // a op var -> (t.a = +/-1, t.b = a)
728 if (is_const(a) && b.is_same(var)) {
729 t.x = var;
730 t.a = (is_sub ? -1 : 1);
731 t.b = to_cpp<int>(a);
732 return true;
733 }
734
735 return false;
736 }
737
add_constraint(const expr_t & e)738 void constraint_set_t::add_constraint(const expr_t &e) {
739 auto *shuffle = e.as_ptr<shuffle_t>();
740 if (shuffle) {
741 if (shuffle->is_broadcast()) add_constraint(shuffle->vec[0]);
742 return;
743 }
744
745 if (modulus_info_t::is_modulus_constraint(e)) {
746 modulus_info_t mi(e);
747 modulus_infos_[mi.var()].push_back(mi);
748 return;
749 }
750
751 if (relation_t::is_relation_constraint(e)) {
752 relation_t rel(e);
753 relations_[rel.var()].push_back(rel);
754 return;
755 }
756
757 // Propagate constraints from y for (x == y) equalities.
758 auto *binary_op = e.as_ptr<binary_op_t>();
759 if (binary_op && binary_op->op_kind == op_kind_t::_eq) {
760 auto &a = binary_op->a;
761 auto &b = binary_op->b;
762 linear_transform_t t;
763 if (is_var(a) && is_linear_var_transform(b, t)) {
764 // Relations.
765 auto r_it = relations_.find(t.x);
766 if (r_it != relations_.end()) {
767 for (auto &c : r_it->second) {
768 add_constraint(c.transform(t, a).expr());
769 }
770 }
771 // Modulus.
772 if (t.is_identity()) {
773 auto m_it = modulus_infos_.find(t.x);
774 if (m_it != modulus_infos_.end()) {
775 for (auto &c : m_it->second) {
776 add_constraint(substitute(c.expr(), b, a));
777 }
778 }
779 }
780 return;
781 }
782 }
783 }
784
is_single_value(const expr_t & e,expr_t & value) const785 bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const {
786 ir_assert(is_var(e)) << e;
787 auto it = relations_.find(e);
788 if (it == relations_.end()) return false;
789
790 expr_t lo;
791 expr_t hi;
792 for (auto &rel : it->second) {
793 ir_assert(is_const(rel.rhs())) << rel;
794 bool do_break = false;
795 switch (rel.op_kind()) {
796 case op_kind_t::_eq:
797 lo = hi = rel.rhs();
798 do_break = true;
799 break;
800 case op_kind_t::_ge:
801 case op_kind_t::_gt: {
802 auto cur_lo = (rel.op_kind() == op_kind_t::_ge ? rel.rhs()
803 : rel.rhs() + 1);
804 if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) { lo = cur_lo; }
805 break;
806 }
807 case op_kind_t::_le:
808 case op_kind_t::_lt: {
809 auto cur_hi = (rel.op_kind() == op_kind_t::_le ? rel.rhs()
810 : rel.rhs() - 1);
811 if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) { hi = cur_hi; }
812 break;
813 }
814 default: ir_error_not_expected() << rel;
815 }
816 if (do_break) break;
817 }
818 bool ret = !lo.is_empty() && lo.is_equal(hi);
819 if (ret) value = lo;
820 return ret;
821 }
822
can_prove_impl(const expr_t & _e,bool do_simplify) const823 bool constraint_set_t::can_prove_impl(
824 const expr_t &_e, bool do_simplify) const {
825 auto e = _e;
826 if (is_const(e)) {
827 ir_assert(e.type() == type_t::_bool()) << e;
828 return to_cpp<bool>(e);
829 }
830
831 if (do_simplify) {
832 // These passes for comparison help to prove more inequalities.
833 e = simplify_cmp_move_const_to_rhs(e);
834 e = simplify_cmp_reduce_lhs_rhs(e);
835 e = simplify(e);
836 if (is_const(e)) {
837 ir_assert(e.type() == type_t::_bool()) << e;
838 return to_cpp<bool>(e);
839 }
840 }
841
842 if (modulus_info_t::is_modulus_constraint(e)) return can_prove_modulus(e);
843 if (relation_t::is_relation_constraint(e)) return can_prove_relation(e);
844
845 // Try to estimate bounds for compound relation.
846 if (try_prove_compound_relation(e)) return true;
847
848 // Can't prove.
849 return false;
850 }
851
max_proven_gcd(const expr_t & var) const852 int constraint_set_t::max_proven_gcd(const expr_t &var) const {
853 auto it = modulus_infos_.find(var);
854 if (it == modulus_infos_.end()) return 1;
855 int ret = 1;
856 for (auto &c : it->second) {
857 ret = math::lcm(ret, to_cpp<int>(c.mod()));
858 }
859 return ret;
860 }
861
unpack(std::vector<stmt_t> & init_stmts,constraint_set_t & cset,const expr_t & _e,const std::vector<unpack_dim_info_t> & infos)862 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
863 const expr_t &_e, const std::vector<unpack_dim_info_t> &infos) {
864 int elems = 1;
865 for (auto &info : infos)
866 elems *= info.dim;
867 ir_assert(elems >= 1);
868
869 expr_t e = _e;
870 int rem_elems = elems;
871 for (auto &info : infos) {
872 auto &var = info.var;
873 int dim = info.dim;
874 int block = info.block;
875 expr_t value;
876 if (dim == 1) {
877 value = expr_t(0);
878 cset.add_constraint(var == 0);
879 } else {
880 value = block * (rem_elems > dim ? e % dim : e);
881 e = e / dim;
882 }
883 init_stmts.emplace_back(let_t::make(var, value));
884 if (dim > 1 && block > 1) cset.add_constraint(var % block == 0);
885 rem_elems /= dim;
886 }
887 }
888
889 } // namespace jit
890 } // namespace gpu
891 } // namespace impl
892 } // namespace dnnl
893