1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <sstream>
18 
19 #include "common/math_utils.hpp"
20 #include "gpu/jit/conv/ir.hpp"
21 #include "gpu/jit/conv/ir_core.hpp"
22 
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace jit {
27 
28 using namespace ir_utils;
29 
30 namespace {
31 
32 // Helper class to print IR objects.
33 class ir_printer_t : public ir_visitor_t {
34 public:
ir_printer_t(std::ostream & out)35     ir_printer_t(std::ostream &out) : out_(out) {}
36 
visit(const object_t & obj)37     void visit(const object_t &obj) override {
38         if (is_supported(obj)) {
39             ir_visitor_t::visit(obj);
40             return;
41         }
42         // Only expressions/functions are expected here.
43         out_ << obj.str();
44     }
45 
_visit(const alloc_t & obj)46     void _visit(const alloc_t &obj) override {
47         print_indent();
48         out_ << "alloc " << obj.buf.as<var_t>().name << "[" << obj.size
49              << "]\n";
50         visit(obj.body);
51     }
52 
_visit(const binary_op_t & obj)53     void _visit(const binary_op_t &obj) override {
54         if (utils::one_of(obj.op_kind, op_kind_t::_min, op_kind_t::_max)) {
55             out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b
56                  << ")";
57             return;
58         }
59         out_ << "(";
60         visit(obj.a);
61         out_ << " " << to_string(obj.op_kind) << " ";
62         visit(obj.b);
63         out_ << ")";
64     }
65 
_visit(const bool_imm_t & obj)66     void _visit(const bool_imm_t &obj) override {
67         out_ << (obj.value ? "true" : "false");
68     }
69 
_visit(const cast_t & obj)70     void _visit(const cast_t &obj) override {
71         out_ << obj.type;
72         if (obj.saturate) out_ << ".sat";
73         out_ << "(" << obj.expr << ")";
74     }
75 
_visit(const float_imm_t & obj)76     void _visit(const float_imm_t &obj) override { out_ << obj.value; }
77 
_visit(const for_t & obj)78     void _visit(const for_t &obj) override {
79         print_indent();
80         out_ << "for (" << obj.var << " = " << obj.init << "; " << obj.var
81              << " < " << obj.bound << "; " << obj.var << "++) ";
82         if (obj.unroll != 1) out_ << "[unroll: " << obj.unroll << "] ";
83         out_ << "{\n";
84         add_indent();
85         visit(obj.body);
86         remove_indent();
87         print_indent();
88         out_ << "}\n";
89     }
90 
_visit(const func_call_t & obj)91     void _visit(const func_call_t &obj) override {
92         print_indent();
93         out_ << obj.func << "(" << make_seq_print_helper(obj.args) << ")";
94         if (!obj.attr.is_empty()) out_ << " " << obj.attr;
95         out_ << "\n";
96     }
97 
_visit(const func_impl_t & obj)98     void _visit(const func_impl_t &obj) override { out_ << obj.str(); }
99 
_visit(const if_t & obj)100     void _visit(const if_t &obj) override {
101         print_indent();
102         out_ << "if (" << strip_parens(obj.cond.str()) << ") {\n";
103         add_indent();
104         visit(obj.body);
105         remove_indent();
106         print_indent();
107         if (obj.else_body.is_empty()) {
108             out_ << "}\n";
109             return;
110         }
111         out_ << "} else {\n";
112         add_indent();
113         visit(obj.else_body);
114         remove_indent();
115         print_indent();
116         out_ << "}\n";
117     }
118 
_visit(const iif_t & obj)119     void _visit(const iif_t &obj) override {
120         out_ << "(" << obj.cond << " ? " << obj.true_expr << " : "
121              << obj.false_expr << ")";
122     }
123 
_visit(const int_imm_t & obj)124     void _visit(const int_imm_t &obj) override {
125         out_ << std::to_string(obj.value);
126     }
127 
_visit(const let_t & obj)128     void _visit(const let_t &obj) override {
129         print_indent();
130         out_ << obj.var << "." << obj.var.type() << " = " << obj.value << "\n";
131         visit(obj.body);
132     }
133 
_visit(const load_t & obj)134     void _visit(const load_t &obj) override {
135         out_ << obj.buf;
136         if (obj.has_default_stride()) {
137             out_ << "." << obj.type << "(" << obj.off / obj.type.size() << ")";
138         } else {
139             out_ << "[" << obj.off << "]." << obj.type;
140             out_ << "<" << obj.stride << ">";
141         }
142     }
143 
_visit(const ptr_t & obj)144     void _visit(const ptr_t &obj) override {
145         out_ << obj.base << "[" << obj.off << "]";
146     }
147 
_visit(const shuffle_t & obj)148     void _visit(const shuffle_t &obj) override {
149         if (obj.is_broadcast()) {
150             out_ << "bcast" << obj.elems() << "(" << obj.vec[0] << ")";
151             return;
152         }
153         std::vector<expr_t> vec_all;
154         for (auto &v : obj.vec) {
155             for (int i = 0; i < v.type().elems(); i++)
156                 vec_all.push_back(v);
157         }
158         int elems = obj.type.elems();
159         out_ << "(";
160         for (int i = 0; i < elems; i++) {
161             int idx = obj.idx[i];
162             auto &v = vec_all[idx];
163             int v_elems = v.type().elems();
164             out_ << v;
165             if (v_elems != 1) out_ << "[" << idx << "]";
166             if (i != elems - 1) out_ << ", ";
167         }
168         out_ << ")";
169     }
170 
_visit(const stmt_group_t & obj)171     void _visit(const stmt_group_t &obj) override {
172         print_indent();
173         out_ << obj.label << " {\n";
174         add_indent();
175         visit(obj.body);
176         remove_indent();
177         print_indent();
178         out_ << "}\n";
179         return;
180     }
181 
_visit(const stmt_seq_t & obj)182     void _visit(const stmt_seq_t &obj) override {
183         visit(obj.head);
184         visit(obj.tail);
185     }
186 
_visit(const store_t & obj)187     void _visit(const store_t &obj) override {
188         print_indent();
189         out_ << load_t::make(obj.value.type(), obj.buf, obj.off, obj.stride);
190         out_ << " = " << obj.value;
191         if (!obj.mask.is_empty()) out_ << " [masked]";
192         out_ << "\n";
193     }
194 
_visit(const ternary_op_t & obj)195     void _visit(const ternary_op_t &obj) override {
196         out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b << ", "
197              << obj.c << ")";
198         return;
199     }
200 
_visit(const unary_op_t & obj)201     void _visit(const unary_op_t &obj) override {
202         out_ << to_string(obj.op_kind);
203         visit(obj.a);
204     }
205 
_visit(const var_t & obj)206     void _visit(const var_t &obj) override { out_ << obj.name; }
207 
208 private:
strip_parens(const std::string & s)209     static std::string strip_parens(const std::string &s) {
210         if (s.size() < 2 || s[0] != '(' || s[s.size() - 1] != ')') return s;
211         auto ret = s;
212         ret.resize(s.size() - 1);
213         return ret.substr(1);
214     }
215 
print_indent()216     void print_indent() {
217         for (int i = 0; i < indent_; i++)
218             out_ << prefix_;
219     }
220 
add_indent()221     void add_indent() { indent_++; }
remove_indent()222     void remove_indent() { indent_--; }
223 
224     std::ostream &out_;
225     int indent_ = 0;
226 
227     std::string prefix_ = "  ";
228 };
229 
230 class substitute_mutator_t : public ir_mutator_t {
231 public:
substitute_mutator_t(const object_t & from,const object_t & to)232     substitute_mutator_t(const object_t &from, const object_t &to)
233         : from_(from), to_(to) {}
234 
substitutions() const235     int substitutions() const { return substitutions_; }
236 
237 #define HANDLE_IR_OBJECT(type) \
238     object_t _mutate(const type &obj) override { \
239         auto *this_mutator = (substitute_mutator_t *)this; \
240         if (this_mutator->from_.impl() == (const object_impl_t *)&obj) { \
241             this_mutator->substitutions_++; \
242             return this_mutator->to_; \
243         } \
244         return ir_mutator_t::_mutate(obj); \
245     };
246 
247     HANDLE_MUTATE_TARGETS()
248 
249 #undef HANDLE_IR_OBJECT
250 
251 private:
252     object_t from_;
253     object_t to_;
254 
255     int substitutions_ = 0;
256 };
257 
258 class stmt_flattener_t : public ir_visitor_t {
259 public:
260 #define HANDLE_IR_OBJECT(type) \
261     void _visit(const type &obj) { \
262         size_t old_size = stmts.size(); \
263         ir_visitor_t::_visit(obj); \
264         if (stmts.size() > old_size) return; \
265         if (obj.is_stmt()) stmts.push_back(obj); \
266     }
267 
268     HANDLE_ALL_IR_OBJECTS()
269 
270 #undef HANDLE_IR_OBJECT
271 
272     std::vector<stmt_t> stmts;
273 };
274 
275 class alloc_injector_t : public ir_mutator_t {
276 public:
alloc_injector_t(const stmt_t & root,const std::vector<stmt_t> & allocs,bool put_innermost)277     alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs,
278             bool put_innermost)
279         : root_(root), put_innermost_(put_innermost), allocs_(allocs) {
280         for (auto &_a : allocs) {
281             auto &a = _a.as<alloc_t>();
282             if (a.kind != alloc_kind_t::global) ir_assert(a.size > 0) << _a;
283             alloc_map_.insert({a.buf, _a});
284         }
285         mutate(root_);
286         buf_total_refs_ = buf_cur_refs_;
287         for (auto &kv : buf_cur_refs_)
288             kv.second = 0;
289         in_ctor_ = false;
290     }
291 
292 #define HANDLE_IR_OBJECT(type) \
293     object_t _mutate(const type &obj) override { return mutate_stmt(obj); }
294 
HANDLE_STMT_IR_OBJECTS()295     HANDLE_STMT_IR_OBJECTS()
296 
297 #undef HANDLE_IR_OBJECT
298     object_t _mutate(const var_t &obj) override {
299         if (alloc_map_.find(obj) != alloc_map_.end()) buf_cur_refs_[obj]++;
300         return obj;
301     }
302 
303 private:
304     template <typename T>
mutate_stmt(const T & obj)305     object_t mutate_stmt(const T &obj) {
306         if (in_ctor_) return ir_mutator_t::_mutate(obj);
307         object_t new_obj = obj;
308         object_set_t<expr_t> undef_bufs;
309         if (put_innermost_) {
310             for (auto &kv : buf_cur_refs_)
311                 if (kv.second == 0) undef_bufs.insert(kv.first);
312             new_obj = ir_mutator_t::_mutate(obj);
313         }
314         for (auto &a : allocs_) {
315             auto it = alloc_map_.find(a.as<alloc_t>().buf);
316             auto &buf = it->first;
317             if (it->second.is_empty()) continue; // Already injected.
318             bool do_inject = false;
319             if (put_innermost_) {
320                 int cur_refs = buf_cur_refs_[buf];
321                 int total_refs = buf_total_refs_[buf];
322                 bool was_undef = (undef_bufs.count(buf) != 0);
323                 do_inject = was_undef && (cur_refs == total_refs);
324             } else {
325                 do_inject = root_.is_same(obj);
326             }
327             if (do_inject) {
328                 auto &a = it->second.as<alloc_t>();
329                 new_obj = alloc_t::make(a.buf, a.size, a.kind, a.attr, new_obj);
330                 it->second = stmt_t();
331             }
332         }
333         return new_obj;
334     }
335 
336     bool in_ctor_ = true;
337     const stmt_t &root_;
338     bool put_innermost_;
339     std::vector<stmt_t> allocs_;
340     object_map_t<expr_t, stmt_t> alloc_map_;
341     object_map_t<expr_t, int> buf_total_refs_;
342     object_map_t<expr_t, int> buf_cur_refs_;
343 };
344 
345 } // namespace
346 
str() const347 std::string object_impl_t::str() const {
348     std::ostringstream oss;
349     ir_printer_t printer(oss);
350     ir_assert(printer.is_supported(this));
351     printer.visit(this);
352     return oss.str();
353 }
354 
// Replaces `from` with `to` inside `root` and returns the result. `root` is
// returned as is when `to` and `from` are the same object. Asserts that the
// number of replacements does not exceed `max_substitutions`.
object_t substitute(const object_t &root, const object_t &from,
        const object_t &to, int max_substitutions) {
    const bool is_noop = to.is_same(from);
    if (is_noop) return root;

    substitute_mutator_t mutator(from, to);
    auto mutated = mutator.mutate(root);
    ir_assert(mutator.substitutions() <= max_substitutions)
            << "Unexpected number of substitutions.";
    MAYBE_UNUSED(&substitute_mutator_t::substitutions);
    MAYBE_UNUSED(max_substitutions);
    return mutated;
}
366 
flatten_statements(const stmt_t & root)367 std::vector<stmt_t> flatten_statements(const stmt_t &root) {
368     stmt_flattener_t f;
369     f.visit(root);
370     return f.stmts;
371 }
372 
// Injects the allocation statements `allocs` into `stmt` using
// alloc_injector_t and returns the mutated statement.
stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
        bool put_innermost) {
    alloc_injector_t alloc_injector(stmt, allocs, put_innermost);
    auto with_allocs = alloc_injector.mutate(stmt);
    return with_allocs;
}
378 
// Wraps `stmt` into the let statements from `lets`. The list is walked from
// the back, so the first element of `lets` becomes the outermost let.
stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets) {
    stmt_t wrapped = stmt;
    for (size_t i = lets.size(); i > 0; i--) {
        auto &let = lets[i - 1].as<let_t>();
        wrapped = let_t::make(let.var, let.value, wrapped);
    }
    return wrapped;
}
387 
// Returns the absolute value of a constant expression (asserts on
// non-constant input).
expr_t abs(const expr_t &e) {
    ir_assert(is_const(e)) << e;
    const bool is_non_negative = to_cpp<bool>(e >= 0);
    if (!is_non_negative) return -e;
    return e;
}
393 
// Casts `e` to `type`. Returns `e` unchanged when the types already match,
// otherwise returns the constant-folded cast expression.
expr_t cast(const expr_t &e, const type_t &type, bool saturate) {
    const bool already_has_type = (e.type() == type);
    if (already_has_type) return e;
    auto cast_expr = cast_t::make(type, e, saturate);
    return const_fold(cast_expr);
}
398 
is_zero(const expr_t & e)399 bool is_zero(const expr_t &e) {
400     if (!e.type().is_scalar()) return false;
401     return e.is_equal(to_expr(0, e.type()));
402 }
403 
is_one(const expr_t & e)404 bool is_one(const expr_t &e) {
405     if (!e.type().is_scalar()) return false;
406     return e.is_equal(to_expr(1, e.type()));
407 }
408 
is_minus_one(const expr_t & e)409 bool is_minus_one(const expr_t &e) {
410     if (!e.type().is_scalar()) return false;
411     return e.is_equal(to_expr(-1, e.type()));
412 }
413 
is_const_broadcast(const expr_t & e)414 bool is_const_broadcast(const expr_t &e) {
415     auto *shuffle = e.as_ptr<shuffle_t>();
416     if (!shuffle) return false;
417     if (!shuffle->is_broadcast()) return false;
418     return is_const(shuffle->vec[0]);
419 }
420 
is_const_broadcast(const expr_t & e,const expr_t & value)421 bool is_const_broadcast(const expr_t &e, const expr_t &value) {
422     if (!is_const_broadcast(e)) return false;
423     return e.as<shuffle_t>().vec[0].is_equal(value);
424 }
425 
all_of(const expr_t & e,const expr_t & value)426 bool all_of(const expr_t &e, const expr_t &value) {
427     auto *shuffle = e.as_ptr<shuffle_t>();
428     if (!shuffle) return e.is_equal(value);
429     for (auto &i : shuffle->idx) {
430         if (!shuffle->vec[i].is_equal(value)) return false;
431     }
432     return true;
433 }
434 
make_buffer(const std::string & name)435 expr_t make_buffer(const std::string &name) {
436     return var_t::make(type_t::byte_ptr(), name);
437 }
438 
439 // Returns number of occurrences of `obj` in `root` (based on identity equality).
count_object(const object_t & root,const object_t & obj)440 int count_object(const object_t &root, const object_t &obj) {
441     ir_assert(!obj.is_empty());
442 
443     std::vector<object_t> found;
444     do {
445 #define HANDLE_IR_OBJECT(type) \
446     if (obj.dispatch_type_id() == type::_dispatch_type_id()) { \
447         found = find_objects<type>(root); \
448         break; \
449     }
450 
451         HANDLE_ALL_IR_OBJECTS()
452 
453 #undef HANDLE_IR_OBJECT
454 
455         ir_error_not_expected() << obj;
456     } while (false);
457 
458     int ret = 0;
459     for (auto &f : found)
460         if (f.is_equal(obj)) ret++;
461     return ret;
462 }
463 
contains_object(const object_t & root,const object_t & obj)464 bool contains_object(const object_t &root, const object_t &obj) {
465     ir_assert(is_var(obj)) << obj;
466     return count_object(root, obj) > 0;
467 }
468 
find_stmt_groups(const object_t & root,const stmt_label_t & label)469 std::vector<stmt_t> find_stmt_groups(
470         const object_t &root, const stmt_label_t &label) {
471     auto groups = find_objects<stmt_group_t>(root);
472     std::vector<stmt_t> ret;
473     for (auto &g : groups) {
474         if (g.as<stmt_group_t>().label == label) ret.push_back(g);
475     }
476     return ret;
477 }
478 
find_stmt_group(const object_t & root,const stmt_label_t & label)479 stmt_t find_stmt_group(const object_t &root, const stmt_label_t &label) {
480     auto groups = find_stmt_groups(root, label);
481     ir_assert(groups.size() == 1);
482     return groups[0];
483 }
484 
// Returns the body of an alloc, for, let or statement group statement. Any
// other statement is returned unchanged.
stmt_t get_stmt_body(const stmt_t &stmt) {
    if (auto *alloc = stmt.as_ptr<alloc_t>()) return alloc->body;
    if (auto *loop = stmt.as_ptr<for_t>()) return loop->body;
    if (auto *let = stmt.as_ptr<let_t>()) return let->body;
    if (auto *group = stmt.as_ptr<stmt_group_t>()) return group->body;
    return stmt;
}
500 
// Returns a copy of an alloc, for, let or statement group statement with its
// body replaced by `new_body`. For any other statement `new_body` itself is
// returned.
stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body) {
    if (auto *alloc = stmt.as_ptr<alloc_t>()) {
        return alloc_t::make(
                alloc->buf, alloc->size, alloc->kind, alloc->attr, new_body);
    }
    if (auto *loop = stmt.as_ptr<for_t>()) {
        return for_t::make(
                loop->var, loop->init, loop->bound, new_body, loop->unroll);
    }
    if (auto *let = stmt.as_ptr<let_t>()) {
        return let_t::make(let->var, let->value, new_body);
    }
    if (auto *group = stmt.as_ptr<stmt_group_t>()) {
        return stmt_group_t::make(group->label, new_body);
    }
    return new_body;
}
522 
implies(const relation_t & other) const523 bool relation_t::implies(const relation_t &other) const {
524     ir_assert(var().is_same(other.var()));
525 
526     if (op_kind() != other.op_kind()) return false;
527 
528     auto A = to_cpp<int64_t>(rhs());
529     auto B = to_cpp<int64_t>(other.rhs());
530 
531     switch (op_kind()) {
532         // (x > A) && (A >= B) => (x > B)
533         // (x >= A) && (A >= B) => (x >= B)
534         case op_kind_t::_gt:
535         case op_kind_t::_ge: return A >= B;
536         // (x < A) && (A <= B) => (x < B)
537         // (x <= A) && (A <= B) => (x <= B)
538         case op_kind_t::_lt:
539         case op_kind_t::_le: return A <= B;
540         default: ir_error_not_expected() << "Not implemented: " << expr_;
541     }
542     return false;
543 }
544 
transform(const linear_transform_t & t,const expr_t & new_var)545 relation_t relation_t::transform(
546         const linear_transform_t &t, const expr_t &new_var) {
547     ir_assert(t.a == 1) << "Not implemented.";
548     return relation_t(binary_op_t::make(op_kind(), new_var, rhs() + t.b));
549 }
550 
// Converts strict comparisons to non-strict ones:
//   (a < b) -> (a <= b - 1)
//   (a > b) -> (a >= b + 1)
// Other relations are returned unchanged.
expr_t relation_t::normalize(const expr_t &e) {
    ir_assert(is_relation_constraint(e)) << e;
    auto &op = e.as<binary_op_t>();

    const bool is_lt = (op.op_kind == op_kind_t::_lt);
    const bool is_gt = (op.op_kind == op_kind_t::_gt);
    if (!is_lt && !is_gt) return e;

    auto new_rhs = op.b;
    if (is_lt) {
        new_rhs -= 1;
    } else {
        new_rhs += 1;
    }
    const auto new_kind = is_lt ? op_kind_t::_le : op_kind_t::_ge;
    auto lhs = op.a;
    return binary_op_t::make(new_kind, lhs, new_rhs);
}
572 
is_modulus_constraint(const expr_t & e)573 bool modulus_info_t::is_modulus_constraint(const expr_t &e) {
574     auto *binary_op = e.as_ptr<binary_op_t>();
575     if (!binary_op) return false;
576     if (!is_zero(binary_op->b)) return false;
577     if (binary_op->op_kind != op_kind_t::_eq) return false;
578 
579     auto *mod_op = binary_op->a.as_ptr<binary_op_t>();
580     if (!mod_op) return false;
581     if (mod_op->op_kind != op_kind_t::_mod) return false;
582     if (!is_var(mod_op->a)) return false;
583     if (!is_const(mod_op->b)) return false;
584 
585     return true;
586 }
587 
find_bound_impl(const expr_t & e,bool is_low) const588 int64_t bound_finder_t::find_bound_impl(const expr_t &e, bool is_low) const {
589     int64_t def_bound = unlimited_bound(is_low);
590     if (is_const(e)) return to_cpp<int64_t>(e);
591     if (is_var(e)) {
592         auto it = relations_.find(e);
593         if (it == relations_.end()) return def_bound;
594 
595         int64_t ret = def_bound;
596         for (auto &rel : it->second) {
597             bool is_ge = (rel.op_kind() == op_kind_t::_ge);
598             if (is_ge != is_low) continue;
599             if (is_ge) {
600                 ret = std::max(to_cpp<int64_t>(rel.rhs()), ret);
601             } else {
602                 ret = std::min(to_cpp<int64_t>(rel.rhs()), ret);
603             }
604         }
605         return ret;
606     }
607 
608     auto *unary = e.as_ptr<unary_op_t>();
609     if (unary) {
610         ir_assert(unary->op_kind == op_kind_t::_minus) << e;
611         auto a = find_bound_impl(unary->a, !is_low);
612         if (!is_good_bound(a)) return 0;
613         return -a;
614     }
615 
616     auto *binary = e.as_ptr<binary_op_t>();
617     if (binary) {
618         switch (binary->op_kind) {
619             case op_kind_t::_add: {
620                 auto a = find_bound_impl(binary->a, is_low);
621                 auto b = find_bound_impl(binary->b, is_low);
622                 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
623                 return a + b;
624             }
625             case op_kind_t::_sub: {
626                 auto a = find_bound_impl(binary->a, is_low);
627                 auto b = find_bound_impl(binary->b, !is_low);
628                 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
629                 return a - b;
630             }
631             case op_kind_t::_mul: {
632                 auto a = binary->a;
633                 auto b = binary->b;
634                 if (!is_const(a) && is_const(b)) std::swap(a, b);
635                 if (!is_const(a)) return def_bound;
636 
637                 auto a_const = to_cpp<int64_t>(a);
638 
639                 auto b_lo = find_low_bound(b);
640                 auto b_hi = find_high_bound(b);
641                 auto b_lo_ok = is_good_bound(b_lo);
642                 auto b_hi_ok = is_good_bound(b_hi);
643 
644                 bool b_ge_0 = b_lo_ok && (b_lo >= 0);
645                 bool b_le_0 = b_hi_ok && (b_hi <= 0);
646                 bool b_same_sign = (b_ge_0 || b_le_0);
647 
648                 if (a_const >= 0 && b_same_sign) {
649                     if (is_low && b_lo_ok) return a_const * b_lo;
650                     if (b_hi_ok) return a_const * b_hi;
651                 }
652 
653                 if (a_const <= 0 && b_same_sign) {
654                     if (is_low && b_hi_ok) return a_const * b_hi;
655                     if (b_lo_ok) return a_const * b_lo;
656                 }
657                 break;
658             }
659             case op_kind_t::_div: {
660                 if (!is_const(binary->b)) return def_bound;
661 
662                 auto b = to_cpp<int64_t>(binary->b);
663                 ir_assert(b != 0);
664 
665                 auto a = find_bound_impl(binary->a, b > 0 ? is_low : !is_low);
666                 if (!is_good_bound(a)) return def_bound;
667 
668                 bool is_neg = ((a > 0) && (b < 0)) || ((a < 0) && (b > 0));
669 
670                 int64_t div_bound;
671                 if (is_low != is_neg) {
672                     // Truncate away from zero.
673                     div_bound = utils::div_up(std::abs(a), std::abs(b));
674                 } else {
675                     // Truncate towards zero.
676                     div_bound = std::abs(a) / std::abs(b);
677                 }
678                 if (is_neg) div_bound *= -1;
679                 return div_bound;
680             }
681             case op_kind_t::_mod: {
682                 if (is_low) return 0;
683                 auto max_mod = find_bound_impl(binary->b, /*is_low=*/false);
684                 if (!is_good_bound(max_mod)) return def_bound;
685                 return max_mod - 1;
686             }
687             default: break;
688         }
689     }
690 
691     return def_bound;
692 }
693 
is_linear_var_transform(const expr_t & e,linear_transform_t & t)694 bool is_linear_var_transform(const expr_t &e, linear_transform_t &t) {
695     if (is_var(e)) {
696         t.x = e;
697         t.a = 1;
698         t.b = 0;
699         return true;
700     }
701 
702     auto *binary_op = e.as_ptr<binary_op_t>();
703     if (!binary_op) return false;
704 
705     auto vars = find_objects<var_t>(e);
706     if (vars.size() != 1) return false;
707 
708     auto &var = vars[0];
709 
710     // TODO: Extend to match multiplication: (a * var).
711     if (!utils::one_of(binary_op->op_kind, op_kind_t::_add, op_kind_t::_sub))
712         return false;
713 
714     auto &a = binary_op->a;
715     auto &b = binary_op->b;
716 
717     bool is_sub = (binary_op->op_kind == op_kind_t::_sub);
718 
719     // var op b -> (t.a = 1, t.b = +/-b)
720     if (a.is_same(var) && is_const(b)) {
721         t.x = var;
722         t.a = 1;
723         t.b = (is_sub ? -1 : 1) * to_cpp<int>(b);
724         return true;
725     }
726 
727     // a op var -> (t.a = +/-1, t.b = a)
728     if (is_const(a) && b.is_same(var)) {
729         t.x = var;
730         t.a = (is_sub ? -1 : 1);
731         t.b = to_cpp<int>(a);
732         return true;
733     }
734 
735     return false;
736 }
737 
add_constraint(const expr_t & e)738 void constraint_set_t::add_constraint(const expr_t &e) {
739     auto *shuffle = e.as_ptr<shuffle_t>();
740     if (shuffle) {
741         if (shuffle->is_broadcast()) add_constraint(shuffle->vec[0]);
742         return;
743     }
744 
745     if (modulus_info_t::is_modulus_constraint(e)) {
746         modulus_info_t mi(e);
747         modulus_infos_[mi.var()].push_back(mi);
748         return;
749     }
750 
751     if (relation_t::is_relation_constraint(e)) {
752         relation_t rel(e);
753         relations_[rel.var()].push_back(rel);
754         return;
755     }
756 
757     // Propagate constraints from y for (x == y) equalities.
758     auto *binary_op = e.as_ptr<binary_op_t>();
759     if (binary_op && binary_op->op_kind == op_kind_t::_eq) {
760         auto &a = binary_op->a;
761         auto &b = binary_op->b;
762         linear_transform_t t;
763         if (is_var(a) && is_linear_var_transform(b, t)) {
764             // Relations.
765             auto r_it = relations_.find(t.x);
766             if (r_it != relations_.end()) {
767                 for (auto &c : r_it->second) {
768                     add_constraint(c.transform(t, a).expr());
769                 }
770             }
771             // Modulus.
772             if (t.is_identity()) {
773                 auto m_it = modulus_infos_.find(t.x);
774                 if (m_it != modulus_infos_.end()) {
775                     for (auto &c : m_it->second) {
776                         add_constraint(substitute(c.expr(), b, a));
777                     }
778                 }
779             }
780             return;
781         }
782     }
783 }
784 
is_single_value(const expr_t & e,expr_t & value) const785 bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const {
786     ir_assert(is_var(e)) << e;
787     auto it = relations_.find(e);
788     if (it == relations_.end()) return false;
789 
790     expr_t lo;
791     expr_t hi;
792     for (auto &rel : it->second) {
793         ir_assert(is_const(rel.rhs())) << rel;
794         bool do_break = false;
795         switch (rel.op_kind()) {
796             case op_kind_t::_eq:
797                 lo = hi = rel.rhs();
798                 do_break = true;
799                 break;
800             case op_kind_t::_ge:
801             case op_kind_t::_gt: {
802                 auto cur_lo = (rel.op_kind() == op_kind_t::_ge ? rel.rhs()
803                                                                : rel.rhs() + 1);
804                 if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) { lo = cur_lo; }
805                 break;
806             }
807             case op_kind_t::_le:
808             case op_kind_t::_lt: {
809                 auto cur_hi = (rel.op_kind() == op_kind_t::_le ? rel.rhs()
810                                                                : rel.rhs() - 1);
811                 if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) { hi = cur_hi; }
812                 break;
813             }
814             default: ir_error_not_expected() << rel;
815         }
816         if (do_break) break;
817     }
818     bool ret = !lo.is_empty() && lo.is_equal(hi);
819     if (ret) value = lo;
820     return ret;
821 }
822 
can_prove_impl(const expr_t & _e,bool do_simplify) const823 bool constraint_set_t::can_prove_impl(
824         const expr_t &_e, bool do_simplify) const {
825     auto e = _e;
826     if (is_const(e)) {
827         ir_assert(e.type() == type_t::_bool()) << e;
828         return to_cpp<bool>(e);
829     }
830 
831     if (do_simplify) {
832         // These passes for comparison help to prove more inequalities.
833         e = simplify_cmp_move_const_to_rhs(e);
834         e = simplify_cmp_reduce_lhs_rhs(e);
835         e = simplify(e);
836         if (is_const(e)) {
837             ir_assert(e.type() == type_t::_bool()) << e;
838             return to_cpp<bool>(e);
839         }
840     }
841 
842     if (modulus_info_t::is_modulus_constraint(e)) return can_prove_modulus(e);
843     if (relation_t::is_relation_constraint(e)) return can_prove_relation(e);
844 
845     // Try to estimate bounds for compound relation.
846     if (try_prove_compound_relation(e)) return true;
847 
848     // Can't prove.
849     return false;
850 }
851 
max_proven_gcd(const expr_t & var) const852 int constraint_set_t::max_proven_gcd(const expr_t &var) const {
853     auto it = modulus_infos_.find(var);
854     if (it == modulus_infos_.end()) return 1;
855     int ret = 1;
856     for (auto &c : it->second) {
857         ret = math::lcm(ret, to_cpp<int>(c.mod()));
858     }
859     return ret;
860 }
861 
unpack(std::vector<stmt_t> & init_stmts,constraint_set_t & cset,const expr_t & _e,const std::vector<unpack_dim_info_t> & infos)862 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
863         const expr_t &_e, const std::vector<unpack_dim_info_t> &infos) {
864     int elems = 1;
865     for (auto &info : infos)
866         elems *= info.dim;
867     ir_assert(elems >= 1);
868 
869     expr_t e = _e;
870     int rem_elems = elems;
871     for (auto &info : infos) {
872         auto &var = info.var;
873         int dim = info.dim;
874         int block = info.block;
875         expr_t value;
876         if (dim == 1) {
877             value = expr_t(0);
878             cset.add_constraint(var == 0);
879         } else {
880             value = block * (rem_elems > dim ? e % dim : e);
881             e = e / dim;
882         }
883         init_stmts.emplace_back(let_t::make(var, value));
884         if (dim > 1 && block > 1) cset.add_constraint(var % block == 0);
885         rem_elems /= dim;
886     }
887 }
888 
889 } // namespace jit
890 } // namespace gpu
891 } // namespace impl
892 } // namespace dnnl
893