1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_JIT_CONV_IR_HPP
18 #define GPU_JIT_CONV_IR_HPP
19 
20 #include <algorithm>
21 #include <mutex>
22 #include <thread>
23 #include <vector>
24 
25 #include "gpu/jit/conv/ir_core.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace gpu {
30 namespace jit {
31 
32 // Helper class to walk through IR tree.
33 class ir_visitor_t {
34 public:
35     using dispatch_func_type = void (*)(ir_visitor_t *, const object_impl_t &);
36 
37     virtual ~ir_visitor_t() = default;
38 
visit(const object_t & obj)39     virtual void visit(const object_t &obj) { dispatch(obj.impl()); }
40 
41     template <typename T>
visit(const std::vector<T> & v)42     void visit(const std::vector<T> &v) {
43         for (auto &e : v)
44             visit(e);
45     }
46 
pre_visit(const object_impl_t & obj)47     virtual void pre_visit(const object_impl_t &obj) {}
post_visit(const object_impl_t & obj)48     virtual void post_visit(const object_impl_t &obj) {}
49 
50     // To catch missing _visit() handlers in ir_visitor_t.
_visit(const object_impl_t & obj)51     virtual void _visit(const object_impl_t &obj) {
52         ir_error_not_expected() << "Can't handle type: " << object_t(obj);
53     }
54 
55 #define DECL_VISIT_LEAF(name) \
56     virtual void _visit(const name &obj) {}
57 
58     DECL_VISIT_LEAF(bool_imm_t)
DECL_VISIT_LEAF(float_imm_t)59     DECL_VISIT_LEAF(float_imm_t)
60     DECL_VISIT_LEAF(func_impl_t)
61     DECL_VISIT_LEAF(int_imm_t)
62     DECL_VISIT_LEAF(var_t)
63 
64 #undef DECL_VISIT_LEAF
65 
66     virtual void _visit(const alloc_t &obj) {
67         visit(obj.buf);
68         visit(obj.body);
69     }
70 
_visit(const binary_op_t & obj)71     virtual void _visit(const binary_op_t &obj) {
72         visit(obj.a);
73         visit(obj.b);
74     }
75 
_visit(const cast_t & obj)76     virtual void _visit(const cast_t &obj) { visit(obj.expr); }
77 
_visit(const for_t & obj)78     virtual void _visit(const for_t &obj) {
79         visit(obj.var);
80         visit(obj.init);
81         visit(obj.bound);
82         visit(obj.body);
83     }
84 
_visit(const func_call_t & obj)85     virtual void _visit(const func_call_t &obj) {
86         visit(obj.func);
87         visit(obj.args);
88     }
89 
_visit(const if_t & obj)90     virtual void _visit(const if_t &obj) {
91         visit(obj.cond);
92         visit(obj.body);
93         visit(obj.else_body);
94     }
95 
_visit(const iif_t & obj)96     virtual void _visit(const iif_t &obj) {
97         visit(obj.cond);
98         visit(obj.true_expr);
99         visit(obj.false_expr);
100     }
101 
_visit(const let_t & obj)102     virtual void _visit(const let_t &obj) {
103         visit(obj.var);
104         visit(obj.value);
105         visit(obj.body);
106     }
107 
_visit(const load_t & obj)108     virtual void _visit(const load_t &obj) {
109         visit(obj.buf);
110         visit(obj.off);
111     }
112 
_visit(const ptr_t & obj)113     virtual void _visit(const ptr_t &obj) {
114         visit(obj.base);
115         visit(obj.off);
116     }
117 
_visit(const shuffle_t & obj)118     virtual void _visit(const shuffle_t &obj) { visit(obj.vec); }
119 
_visit(const stmt_group_t & obj)120     virtual void _visit(const stmt_group_t &obj) { visit(obj.body); }
121 
_visit(const stmt_seq_t & obj)122     virtual void _visit(const stmt_seq_t &obj) {
123         visit(obj.head);
124         visit(obj.tail);
125     }
126 
_visit(const store_t & obj)127     virtual void _visit(const store_t &obj) {
128         visit(obj.buf);
129         visit(obj.off);
130         visit(obj.value);
131         visit(obj.mask);
132     }
133 
_visit(const ternary_op_t & obj)134     virtual void _visit(const ternary_op_t &obj) {
135         visit(obj.a);
136         visit(obj.b);
137         visit(obj.c);
138     }
139 
_visit(const unary_op_t & obj)140     virtual void _visit(const unary_op_t &obj) { visit(obj.a); }
141 
is_supported(const object_t & obj) const142     bool is_supported(const object_t &obj) const {
143         if (obj.is_empty()) return true;
144 
145         auto *impl = obj.impl();
146         auto ti = impl->dispatch_type_id();
147         return ti < num_dispatch_funcs;
148     }
149 
150 protected:
find_dispatch_func(int64_t ti) const151     virtual dispatch_func_type find_dispatch_func(int64_t ti) const {
152         return ti < num_dispatch_funcs ? dispatch_funcs()[ti] : nullptr;
153     }
154 
155 private:
156     static const int64_t num_dispatch_funcs
157             = ir_type_id_t::end_visitable_ir_objects;
158     static std::array<dispatch_func_type, num_dispatch_funcs> &
dispatch_funcs()159     dispatch_funcs() {
160         static std::array<dispatch_func_type, num_dispatch_funcs>
161                 _dispatch_funcs;
162         static std::once_flag initialized;
163         std::call_once(initialized, [&]() {
164 #define HANDLE_IR_OBJECT(type) \
165     _dispatch_funcs[type::_dispatch_type_id()] = &call<type>;
166             HANDLE_ALL_IR_OBJECTS()
167 
168 #undef HANDLE_IR_OBJECT
169         });
170         return _dispatch_funcs;
171     }
172 
173     template <typename T>
call(ir_visitor_t * visitor,const object_impl_t & obj)174     static void call(ir_visitor_t *visitor, const object_impl_t &obj) {
175         visitor->pre_visit(obj);
176         visitor->_visit((const T &)obj);
177         visitor->post_visit(obj);
178     }
179 
dispatch(const object_impl_t * obj)180     void dispatch(const object_impl_t *obj) {
181         if (!obj) return;
182 
183         auto ti = obj->dispatch_type_id();
184         auto f = find_dispatch_func(ti);
185         if (!f) {
186             ir_error_not_expected() << "Can't handle type: " << object_t(obj);
187         }
188         f(this, *obj);
189     }
190 };
191 
192 class ir_context_t {
193 public:
create_tmp_var(const type_t & type,const std::string & prefix="tmp")194     expr_t create_tmp_var(
195             const type_t &type, const std::string &prefix = "tmp") {
196         int &id = prefix_ids_[prefix];
197         auto name = prefix + "_" + std::to_string(id);
198         id++;
199         return var_t::make(type, name);
200     }
201 
202 private:
203     std::unordered_map<std::string, int> prefix_ids_;
204 };
205 
206 class alloc_updater_t : public ir_mutator_t {
207 public:
resize(const expr_t & buf,int new_size)208     void resize(const expr_t &buf, int new_size) {
209         auto ret = resizes_.insert({buf, new_size});
210         ir_assert(ret.second) << buf;
211         MAYBE_UNUSED(ret);
212     }
213 
remove(const expr_t & buf)214     void remove(const expr_t &buf) {
215         auto ret = removes_.insert(buf);
216         ir_assert(ret.second) << buf;
217         MAYBE_UNUSED(ret);
218     }
219 
update(const stmt_t & stmt)220     stmt_t update(const stmt_t &stmt) { return mutate(stmt); }
221 
_mutate(const alloc_t & obj)222     object_t _mutate(const alloc_t &obj) override {
223         auto new_obj = ir_mutator_t::_mutate(obj);
224 
225         if (try_remove(new_obj)) return new_obj;
226         if (try_resize(new_obj)) return new_obj;
227 
228         return new_obj;
229     }
230 
231 private:
try_remove(object_t & obj)232     bool try_remove(object_t &obj) {
233         auto &alloc = obj.as<alloc_t>();
234         auto it = removes_.find(alloc.buf);
235         if (it == removes_.end()) return false;
236 
237         obj = alloc.body;
238         removes_.erase(it);
239         return true;
240     }
241 
try_resize(object_t & obj)242     bool try_resize(object_t &obj) {
243         auto &alloc = obj.as<alloc_t>();
244         auto it = resizes_.find(alloc.buf);
245         if (it == resizes_.end()) return false;
246 
247         obj = alloc_t::make(
248                 alloc.buf, it->second, alloc.kind, alloc.attr, alloc.body);
249         resizes_.erase(it);
250         return true;
251     }
252 
253     object_set_t<expr_t> removes_;
254     object_map_t<expr_t, int> resizes_;
255 };
256 
// Returns a new statement with the buffer allocations from `allocs` injected.
// - If put_innermost is false, `stmt` is nested inside all allocations.
// - If put_innermost is true, every allocation is injected as deep
//   (innermost) as possible.
261 stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
262         bool put_innermost = false);
263 
// Returns a new statement with the let statements from `lets` injected;
// `stmt` is nested inside all of them.
266 stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets);
267 
268 template <typename T>
269 struct expr_cast_helper_t {
calldnnl::impl::gpu::jit::expr_cast_helper_t270     static T call(const expr_t &e) { return to_cpp<T>(e); }
271 
calldnnl::impl::gpu::jit::expr_cast_helper_t272     static std::vector<T> call(const std::vector<expr_t> &exprs) {
273         std::vector<T> ret;
274         for (auto &e : exprs)
275             ret.push_back(to_cpp<T>(e));
276         return ret;
277     }
278 };
279 
280 template <>
281 struct expr_cast_helper_t<expr_t> {
calldnnl::impl::gpu::jit::expr_cast_helper_t282     static expr_t call(const expr_t &e) { return e; }
283 
calldnnl::impl::gpu::jit::expr_cast_helper_t284     static std::vector<expr_t> call(const std::vector<expr_t> &exprs) {
285         return exprs;
286     }
287 
288     template <typename U,
289             typename
290             = typename std::enable_if<std::is_arithmetic<U>::value>::type>
calldnnl::impl::gpu::jit::expr_cast_helper_t291     static std::vector<expr_t> call(const std::vector<U> &vec) {
292         std::vector<expr_t> ret;
293         for (auto &v : vec)
294             ret.push_back(to_expr(v));
295         return ret;
296     }
297 };
298 
299 template <typename DstT, typename SrcT>
300 DstT expr_cast(const SrcT &src) {
301     return expr_cast_helper_t<DstT>::call(src);
302 }
303 
304 template <typename DstT, typename SrcT>
expr_cast(const std::vector<SrcT> & src)305 std::vector<DstT> expr_cast(const std::vector<SrcT> &src) {
306     return expr_cast_helper_t<DstT>::call(src);
307 }
308 
// Performs constant folding recursively on an IR tree.
310 object_t const_fold(const object_t &obj);
311 
// Performs non-recursive constant folding on an expression.
313 expr_t const_fold_non_recursive(const expr_t &e);
314 
315 template <typename T>
316 std::vector<object_t> find_objects(const object_t &root);
317 
318 template <typename T>
319 std::vector<object_t> find_objects_unique(const object_t &root);
320 
321 class alloc_manager_t {
322 public:
alloc_manager_t(const stmt_t & root)323     alloc_manager_t(const stmt_t &root) {
324         auto allocs = find_objects<alloc_t>(root);
325         for (auto &_a : allocs) {
326             auto &a = _a.as<alloc_t>();
327             auto ret = buf2alloc_.insert({a.buf, _a});
328             buffers_.push_back(a.buf);
329             ir_assert(ret.second) << "Buffer already exists: " << a.buf;
330             MAYBE_UNUSED(ret);
331         }
332 
333         // Sort buffers by name.
334         std::sort(buffers_.begin(), buffers_.end(),
335                 [](const expr_t &a, const expr_t &b) {
336                     return a.as<var_t>().name < b.as<var_t>().name;
337                 });
338     }
339 
buffers() const340     const std::vector<expr_t> &buffers() const { return buffers_; }
341 
find_buffer(const std::string & name,bool allow_empty=false) const342     expr_t find_buffer(
343             const std::string &name, bool allow_empty = false) const {
344         for (auto &b : buffers())
345             if (b.as<var_t>().name == name) return b;
346 
347         if (!allow_empty) ir_error_not_expected() << name;
348         return expr_t();
349     }
350 
find_buffers(alloc_kind_t kind) const351     std::vector<expr_t> find_buffers(alloc_kind_t kind) const {
352         std::vector<expr_t> ret;
353         for (auto &b : buffers())
354             if (alloc_kind(b) == kind) ret.push_back(b);
355         return ret;
356     }
357 
alloc_size(const expr_t & buf) const358     int alloc_size(const expr_t &buf) const {
359         auto *a = find_alloc(buf);
360         ir_assert(a) << buf;
361         return a->size;
362     }
363 
alloc_kind(const expr_t & buf) const364     alloc_kind_t alloc_kind(const expr_t &buf) const {
365         auto *a = find_alloc(buf);
366         ir_assert(a) << buf;
367         return a->kind;
368     }
369 
total_size(alloc_kind_t kind) const370     int total_size(alloc_kind_t kind) const {
371         int ret = 0;
372         for (auto &kv : buf2alloc_) {
373             auto &a = kv.second.as<alloc_t>();
374             if (a.kind == kind) ret += a.size;
375         }
376         return ret;
377     }
378 
379 private:
find_alloc(const expr_t & buf) const380     const alloc_t *find_alloc(const expr_t &buf) const {
381         auto it = buf2alloc_.find(buf);
382         if (it == buf2alloc_.end()) return nullptr;
383         return it->second.as_ptr<alloc_t>();
384     }
385 
386     object_map_t<expr_t, stmt_t> buf2alloc_;
387     std::vector<expr_t> buffers_;
388     object_map_t<expr_t, stmt_t> alloc_updates_;
389 };
390 
391 // IR utility functions.
392 expr_t abs(const expr_t &e);
393 
394 expr_t cast(const expr_t &e, const type_t &type, bool saturate = false);
395 
396 bool is_zero(const expr_t &e);
397 
398 bool is_one(const expr_t &e);
399 
400 bool is_minus_one(const expr_t &e);
401 
402 bool is_const_broadcast(const expr_t &e);
403 
404 bool is_const_broadcast(const expr_t &e, const expr_t &value);
405 
406 bool all_of(const expr_t &e, const expr_t &value);
407 
408 expr_t make_buffer(const std::string &name);
409 
410 // Utility functions for nary_op_t.
411 expr_t nary_op_back_transform(const expr_t &e);
412 expr_t nary_op_canonicalize(const expr_t &_e);
413 expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args);
414 std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e);
415 
// Substitutes all occurrences of `from` with `to` in `root`.
417 object_t substitute(const object_t &root, const object_t &from,
418         const object_t &to,
419         int max_substitutions = std::numeric_limits<int>::max());
420 
421 // Returns leaf statements of `root`. Uses inorder traversal.
422 std::vector<stmt_t> flatten_statements(const stmt_t &root);
423 
424 template <typename T, bool find_unique = false, bool save_objects = true>
425 class object_finder_t : public ir_visitor_t {
426 public:
_visit(const T & obj)427     void _visit(const T &obj) override {
428         ir_visitor_t::_visit(obj);
429         occurrences++;
430         if (!save_objects) return;
431         if (find_unique) {
432             found_unique.insert(obj);
433         } else {
434             found.push_back(obj);
435         }
436     }
437 
438     std::vector<object_t> found;
439     object_set_t<object_t> found_unique;
440     int occurrences = 0;
441 };
442 
443 // Returns all IR objects of type `T` found in `root`.
444 template <typename T>
find_objects(const object_t & root)445 std::vector<object_t> find_objects(const object_t &root) {
446     object_finder_t<T, /*find_unique=*/false> finder;
447     finder.visit(root);
448     return finder.found;
449 }
450 
451 template <typename T>
count_objects(const object_t & root)452 int count_objects(const object_t &root) {
453     object_finder_t<T, /*find_unique=*/false, /*save_objects=*/false> finder;
454     finder.visit(root);
455     return finder.occurrences;
456 }
457 
458 // Returns unique IR objects of type `T` found in `root`.
459 template <typename T>
find_unique_objects(const object_t & root)460 object_set_t<object_t> find_unique_objects(const object_t &root) {
461     object_finder_t<T, /*find_unique=*/true> finder;
462     finder.visit(root);
463     return finder.found_unique;
464 }
465 
466 // Returns number of occurrences of `obj` in `root` (based on identity
467 // comparison).
468 int count_object(const object_t &root, const object_t &obj);
469 
470 // Returns number of occurrences of `obj` in vector of root objects (based on
471 // identity comparison).
472 template <typename T>
count_object(const std::vector<T> & roots,const object_t & obj)473 int count_object(const std::vector<T> &roots, const object_t &obj) {
474     int ret = 0;
475     for (auto &root : roots)
476         ret += count_object(root, obj);
477     return ret;
478 }
479 
480 // Checks if `root` contains `obj`.
481 bool contains_object(const object_t &root, const object_t &obj);
482 
483 // Returns all statement groups matching the label.
484 std::vector<stmt_t> find_stmt_groups(
485         const object_t &root, const stmt_label_t &label);
486 
487 // Returns a statement group matching the label. `root` must have exactly one
488 // occurrence.
489 stmt_t find_stmt_group(const object_t &root, const stmt_label_t &label);
490 
491 class scope_visitor_t : public ir_visitor_t {
492 public:
is_expr_defined(const expr_t & e) const493     bool is_expr_defined(const expr_t &e) const {
494         auto vars = find_unique_objects<var_t>(e);
495         for (auto &v : vars) {
496             if (def_vars_.count(v) == 0) return false;
497         }
498         return true;
499     }
500 
501 #define CASE(type, var_field, is_pre) \
502     if (obj.type_id() == type::_type_id()) { \
503         visit_scope((const type &)obj, ((const type &)obj).var_field, is_pre); \
504         return; \
505     }
506 
pre_visit(const object_impl_t & obj)507     void pre_visit(const object_impl_t &obj) override {
508         CASE(alloc_t, buf, true);
509         CASE(let_t, var, true);
510         CASE(for_t, var, true);
511     }
512 
post_visit(const object_impl_t & obj)513     void post_visit(const object_impl_t &obj) override {
514         CASE(alloc_t, buf, false);
515         CASE(let_t, var, false);
516         CASE(for_t, var, false);
517     }
518 
519 #undef CASE
520 
521 private:
522     template <typename T>
visit_scope(const T & obj,const expr_t & var,bool is_pre_visit)523     void visit_scope(const T &obj, const expr_t &var, bool is_pre_visit) {
524         if (is_pre_visit) {
525             def_vars_.insert(var);
526             return;
527         }
528         def_vars_.erase(var);
529     }
530 
531     object_set_t<expr_t> def_vars_;
532 };
533 
534 class ir_path_t {
535 public:
push(const object_impl_t * obj)536     void push(const object_impl_t *obj) { path_.push_back(obj); }
537 
pop()538     void pop() { path_.pop_back(); }
539 
back() const540     const object_impl_t *back() const {
541         ir_assert(!is_empty());
542         return path_.back();
543     }
544 
is_empty() const545     bool is_empty() const { return path_.empty(); }
546 
merge(const ir_path_t & other)547     void merge(const ir_path_t &other) {
548         size_t idx;
549         size_t min_size = std::min(path_.size(), other.path_.size());
550         for (idx = 0; idx < min_size; idx++) {
551             if (path_[idx] != other.path_[idx]) break;
552         }
553         path_.resize(idx);
554     }
555 
556 private:
557     std::vector<const object_impl_t *> path_;
558 };
559 
560 // Only for statements that create scope.
561 stmt_t get_stmt_body(const stmt_t &stmt);
562 
563 stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body);
564 
565 // Describes the linear transformation F(x) for variable x: F(x) = (a * x + b),
566 // where a and b are integer constants.
567 struct linear_transform_t {
568     expr_t x;
569     int a;
570     int b;
571 
is_identitydnnl::impl::gpu::jit::linear_transform_t572     bool is_identity() const { return a == 1 && b == 0; }
573 };
574 
575 // Relation: (lhs op rhs), where:
576 // - lhs is a variable
577 // - rhs is an integer constant
578 // - op is a comparison operation
579 class relation_t {
580 public:
relation_t(const expr_t & expr)581     relation_t(const expr_t &expr) : expr_(normalize(expr)) {}
582 
expr() const583     const expr_t &expr() const { return expr_; }
584 
var() const585     const expr_t &var() const { return expr_.as<binary_op_t>().a; }
586 
rhs() const587     const expr_t &rhs() const { return expr_.as<binary_op_t>().b; }
588 
op_kind() const589     op_kind_t op_kind() const { return expr_.as<binary_op_t>().op_kind; }
590 
591     bool implies(const relation_t &other) const;
592 
593     // Applies linear transformation to left and right hand sides of the relation.
594     relation_t transform(const linear_transform_t &t, const expr_t &new_var);
595 
str() const596     std::string str() const {
597         std::ostringstream oss;
598         oss << expr_;
599         return oss.str();
600     }
601 
is_relation_constraint(const expr_t & e)602     static bool is_relation_constraint(const expr_t &e) {
603         auto *binary_op = e.as_ptr<binary_op_t>();
604         if (!binary_op) return false;
605         if (!is_var(binary_op->a)) return false;
606         if (!is_const(binary_op->b)) return false;
607         if (!is_cmp_op(binary_op->op_kind)) return false;
608         return true;
609     }
610 
611 private:
612     static expr_t normalize(const expr_t &e);
613 
614     expr_t expr_;
615 };
616 
operator <<(std::ostream & out,const relation_t & rel)617 inline std::ostream &operator<<(std::ostream &out, const relation_t &rel) {
618     out << rel.str();
619     return out;
620 }
621 
622 // Equality for modulus: (var % mod) == 0, where:
623 // - var is a variable
624 // - mod is an integer constant
625 class modulus_info_t {
626 public:
modulus_info_t(const expr_t & expr)627     modulus_info_t(const expr_t &expr) : expr_(expr) {}
628 
expr() const629     const expr_t &expr() const { return expr_; }
630 
var() const631     const expr_t &var() const {
632         auto &mod_expr = expr_.as<binary_op_t>().a;
633         return mod_expr.as<binary_op_t>().a;
634     }
635 
mod() const636     const expr_t &mod() const {
637         auto &mod_expr = expr_.as<binary_op_t>().a;
638         return mod_expr.as<binary_op_t>().b;
639     }
640 
implies(const modulus_info_t & other) const641     bool implies(const modulus_info_t &other) const {
642         ir_assert(var().is_same(other.var()));
643 
644         int64_t this_mod = to_cpp<int64_t>(mod());
645         int64_t other_mod = to_cpp<int64_t>(other.mod());
646 
647         return this_mod % other_mod == 0;
648     }
649 
str() const650     std::string str() const {
651         std::ostringstream oss;
652         oss << expr_;
653         return oss.str();
654     }
655 
656     // Try to match (var % mod) == 0.
657     static bool is_modulus_constraint(const expr_t &e);
658 
659 private:
660     expr_t expr_;
661 };
662 
operator <<(std::ostream & out,const modulus_info_t & mod)663 inline std::ostream &operator<<(std::ostream &out, const modulus_info_t &mod) {
664     out << mod.str();
665     return out;
666 }
667 
668 // Helper class to find constant bounds of integer expressions based on known
669 // relations.
670 class bound_finder_t {
671 public:
bound_finder_t(const object_map_t<expr_t,std::vector<relation_t>> & relations)672     bound_finder_t(
673             const object_map_t<expr_t, std::vector<relation_t>> &relations)
674         : relations_(relations) {}
675 
find_low_bound(const expr_t & e) const676     int64_t find_low_bound(const expr_t &e) const {
677         return find_bound_impl(e, /*is_low=*/true);
678     }
679 
find_high_bound(const expr_t & e) const680     int64_t find_high_bound(const expr_t &e) const {
681         return find_bound_impl(e, /*is_low=*/false);
682     }
683 
is_good_bound(int64_t bound)684     static bool is_good_bound(int64_t bound) {
685         if (bound == unlimited_bound(true)) return false;
686         if (bound == unlimited_bound(false)) return false;
687         return true;
688     }
689 
690 private:
691     // If is_low is true, searches for proven low bound, and high bound
692     // otherwise.
693     int64_t find_bound_impl(const expr_t &e, bool is_low) const;
694 
unlimited_bound(bool is_low)695     static int64_t unlimited_bound(bool is_low) {
696         if (is_low) return std::numeric_limits<int64_t>::min();
697         return std::numeric_limits<int64_t>::max();
698     }
699 
700     object_map_t<expr_t, std::vector<relation_t>> relations_;
701 };
702 
703 // TODO: Add integers check (only integers can be constrained).
704 class constraint_set_t {
705 public:
706     void add_constraint(const expr_t &e);
707 
can_prove(const expr_t & e,bool try_simplify=true) const708     bool can_prove(const expr_t &e, bool try_simplify = true) const {
709         auto ret = can_prove_impl(e, /*do_simplify=*/false);
710         if (ret || !try_simplify) return ret;
711 
712         return can_prove_impl(e, /*do_simplify=*/true);
713     }
714 
715     bool is_single_value(const expr_t &e, expr_t &value) const;
716 
717     int max_proven_gcd(const expr_t &var) const;
718 
719 private:
can_prove_modulus(const expr_t & e) const720     bool can_prove_modulus(const expr_t &e) const {
721         modulus_info_t unknown(e);
722         auto it = modulus_infos_.find(unknown.var());
723         if (it == modulus_infos_.end()) return false;
724 
725         for (auto &known : it->second) {
726             if (known.implies(unknown)) return true;
727         }
728 
729         return false;
730     }
731 
can_prove_relation(const expr_t & e) const732     bool can_prove_relation(const expr_t &e) const {
733         relation_t unknown(e);
734         auto it = relations_.find(unknown.var());
735         if (it == relations_.end()) return false;
736 
737         for (auto &known : it->second) {
738             if (known.implies(unknown)) return true;
739         }
740 
741         return false;
742     }
743 
try_prove_compound_relation(const expr_t & e) const744     bool try_prove_compound_relation(const expr_t &e) const {
745         auto *binary = e.as_ptr<binary_op_t>();
746         if (!binary) return false;
747 
748         auto op_kind = binary->op_kind;
749         auto &a = binary->a;
750         auto &_b = binary->b;
751 
752         if (!is_const(_b)) return false;
753 
754         auto b = to_cpp<int64_t>(_b);
755 
756         // Normalize operation kind.
757         switch (op_kind) {
758             case op_kind_t::_ge:
759             case op_kind_t::_le: break;
760             case op_kind_t::_gt:
761                 op_kind = op_kind_t::_ge;
762                 ir_assert(b < std::numeric_limits<int64_t>::max());
763                 b += 1;
764                 break;
765             case op_kind_t::_lt:
766                 op_kind = op_kind_t::_le;
767                 ir_assert(b > std::numeric_limits<int64_t>::min());
768                 b -= 1;
769                 break;
770             default: return false;
771         }
772 
773         bound_finder_t finder(relations_);
774         if (op_kind == op_kind_t::_ge) {
775             auto lo = finder.find_low_bound(a);
776             if (!bound_finder_t::is_good_bound(lo)) return false;
777             return lo >= b;
778         }
779 
780         if (op_kind == op_kind_t::_le) {
781             auto hi = finder.find_high_bound(a);
782             if (!bound_finder_t::is_good_bound(hi)) return false;
783             return hi <= b;
784         }
785 
786         return false;
787     }
788 
789     bool can_prove_impl(const expr_t &_e, bool do_simplify) const;
790 
791     object_map_t<expr_t, std::vector<relation_t>> relations_;
792     object_map_t<expr_t, std::vector<modulus_info_t>> modulus_infos_;
793 };
794 
795 // Simplifies expression or statement. An optional constraint set is used to
796 // pass known equalities and inequalities which may be used for simplification.
797 object_t simplify(const object_t &obj, const constraint_set_t &cset = {});
798 
799 // Searches for expression patterns to reduce them to the equivalent ternary
800 // operations.
801 expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive = true);
802 
803 // Moves constants to the right hand side of an expression.
804 // Example: (c0 + x) op c1 -> x op (c1 - c0)
805 expr_t simplify_cmp_move_const_to_rhs(const expr_t &e);
806 
807 // Reduces left and right hand sides of an expression.
808 // Example: A * x < A * B -> x < B (if A > 0).
809 expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e);
810 
811 // Propagates shuffle down the expression tree for more effective vectorization.
812 expr_t simplify_propagate_shuffle(const expr_t &e);
813 
814 // Pre-defined functions.
815 namespace funcs {
816 
barrier_func()817 inline func_t barrier_func() {
818     static auto f = builtin_t::make("barrier");
819     return f;
820 }
821 
barrier()822 inline stmt_t barrier() {
823     return barrier_func().call();
824 }
825 
slm_fence_func()826 inline func_t slm_fence_func() {
827     static auto f = builtin_t::make("slm_fence");
828     return f;
829 }
830 
slm_fence()831 inline stmt_t slm_fence() {
832     return slm_fence_func().call();
833 }
834 
signal_func()835 inline func_t signal_func() {
836     static auto f = builtin_t::make("signal");
837     return f;
838 }
839 
signal()840 inline stmt_t signal() {
841     return signal_func().call();
842 }
843 
barrier_wait_func()844 inline func_t barrier_wait_func() {
845     static auto f = builtin_t::make("barrier_wait");
846     return f;
847 }
848 
barrier_wait()849 inline stmt_t barrier_wait() {
850     return barrier_wait_func().call();
851 }
852 
853 } // namespace funcs
854 
855 // Helper functionality to extract ND indices packed into 1D index.
856 // Example:
857 //     i = [0; Bi, 2 * Bi, ... (I - 1) * Bi]
858 //     i_info.dim = I; i_info.block = Bi
859 //     j = [0; Bj, 2 * Bj, ... (J - 1) * Bj]
860 //     j_info.dim = J; j_info.block = Bj
861 // 1D index: ij_idx
862 // 2D indices: [i; j]
863 // Unpacking:
864 //     i = (ij_idx % I) * Bi
865 //     j = (ij_idx / I) * Bj
866 struct unpack_dim_info_t {
867     const expr_t &var;
868     int dim;
869     int block;
870 };
871 
cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &)872 inline void cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &) {}
873 
874 template <typename... ArgsT>
cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> & infos,const expr_t & var,int dim,int block,const ArgsT &...args)875 void cvt_args_to_unpack_dim_info(std::vector<unpack_dim_info_t> &infos,
876         const expr_t &var, int dim, int block, const ArgsT &... args) {
877     infos.push_back(unpack_dim_info_t {var, dim, block});
878     cvt_args_to_unpack_dim_info(infos, args...);
879 }
880 
881 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
882         const expr_t &_e, const std::vector<unpack_dim_info_t> &infos);
883 
884 template <typename... ArgsT>
unpack(std::vector<stmt_t> & init_stmts,constraint_set_t & cset,const expr_t & e,const ArgsT &...args)885 void unpack(std::vector<stmt_t> &init_stmts, constraint_set_t &cset,
886         const expr_t &e, const ArgsT &... args) {
887     std::vector<unpack_dim_info_t> infos;
888     cvt_args_to_unpack_dim_info(infos, args...);
889     unpack(init_stmts, cset, e, infos);
890 }
891 
892 } // namespace jit
893 } // namespace gpu
894 } // namespace impl
895 } // namespace dnnl
896 
897 #endif
898