/******************************************************************************* * Copyright 2021 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #ifndef GPU_JIT_CONV_IR_HPP #define GPU_JIT_CONV_IR_HPP #include #include #include #include #include "gpu/jit/conv/ir_core.hpp" namespace dnnl { namespace impl { namespace gpu { namespace jit { // Helper class to walk through IR tree. class ir_visitor_t { public: using dispatch_func_type = void (*)(ir_visitor_t *, const object_impl_t &); virtual ~ir_visitor_t() = default; virtual void visit(const object_t &obj) { dispatch(obj.impl()); } template void visit(const std::vector &v) { for (auto &e : v) visit(e); } virtual void pre_visit(const object_impl_t &obj) {} virtual void post_visit(const object_impl_t &obj) {} // To catch missing _visit() handlers in ir_visitor_t. virtual void _visit(const object_impl_t &obj) { ir_error_not_expected() << "Can't handle type: " << object_t(obj); } #define DECL_VISIT_LEAF(name) \ virtual void _visit(const name &obj) {} DECL_VISIT_LEAF(bool_imm_t) DECL_VISIT_LEAF(float_imm_t) DECL_VISIT_LEAF(func_impl_t) DECL_VISIT_LEAF(int_imm_t) DECL_VISIT_LEAF(var_t) #undef DECL_VISIT_LEAF virtual void _visit(const alloc_t &obj) { visit(obj.buf); visit(obj.body); } virtual void _visit(const binary_op_t &obj) { visit(obj.a); visit(obj.b); } virtual void _visit(const cast_t &obj) { visit(obj.expr); } virtual void _visit(const for_t &obj) { visit(obj.var); visit(obj.init); visit(obj.bound); visit(obj.body); } virtual void _visit(const func_call_t &obj) { visit(obj.func); visit(obj.args); } virtual void _visit(const if_t &obj) { visit(obj.cond); visit(obj.body); visit(obj.else_body); } virtual void _visit(const iif_t &obj) { visit(obj.cond); visit(obj.true_expr); visit(obj.false_expr); } virtual void _visit(const let_t &obj) { visit(obj.var); visit(obj.value); visit(obj.body); } virtual void _visit(const load_t &obj) { visit(obj.buf); visit(obj.off); } virtual void _visit(const ptr_t &obj) { visit(obj.base); visit(obj.off); } virtual void _visit(const shuffle_t &obj) { visit(obj.vec); } virtual void _visit(const stmt_group_t &obj) { visit(obj.body); } virtual void _visit(const stmt_seq_t &obj) { visit(obj.head); visit(obj.tail); } virtual void _visit(const store_t &obj) { visit(obj.buf); visit(obj.off); visit(obj.value); visit(obj.mask); } virtual void _visit(const ternary_op_t &obj) { visit(obj.a); visit(obj.b); visit(obj.c); } virtual void _visit(const unary_op_t &obj) { visit(obj.a); } bool is_supported(const object_t &obj) const { if (obj.is_empty()) return true; auto *impl = obj.impl(); auto ti = impl->dispatch_type_id(); return ti < num_dispatch_funcs; } protected: virtual dispatch_func_type find_dispatch_func(int64_t ti) const { return ti < num_dispatch_funcs ? dispatch_funcs()[ti] : nullptr; } private: static const int64_t num_dispatch_funcs = ir_type_id_t::end_visitable_ir_objects; static std::array & dispatch_funcs() { static std::array _dispatch_funcs; static std::once_flag initialized; std::call_once(initialized, [&]() { #define HANDLE_IR_OBJECT(type) \ _dispatch_funcs[type::_dispatch_type_id()] = &call; HANDLE_ALL_IR_OBJECTS() #undef HANDLE_IR_OBJECT }); return _dispatch_funcs; } template static void call(ir_visitor_t *visitor, const object_impl_t &obj) { visitor->pre_visit(obj); visitor->_visit((const T &)obj); visitor->post_visit(obj); } void dispatch(const object_impl_t *obj) { if (!obj) return; auto ti = obj->dispatch_type_id(); auto f = find_dispatch_func(ti); if (!f) { ir_error_not_expected() << "Can't handle type: " << object_t(obj); } f(this, *obj); } }; class ir_context_t { public: expr_t create_tmp_var( const type_t &type, const std::string &prefix = "tmp") { int &id = prefix_ids_[prefix]; auto name = prefix + "_" + std::to_string(id); id++; return var_t::make(type, name); } private: std::unordered_map prefix_ids_; }; class alloc_updater_t : public ir_mutator_t { public: void resize(const expr_t &buf, int new_size) { auto ret = resizes_.insert({buf, new_size}); ir_assert(ret.second) << buf; MAYBE_UNUSED(ret); } void remove(const expr_t &buf) { auto ret = removes_.insert(buf); ir_assert(ret.second) << buf; MAYBE_UNUSED(ret); } stmt_t update(const stmt_t &stmt) { return mutate(stmt); } object_t _mutate(const alloc_t &obj) override { auto new_obj = ir_mutator_t::_mutate(obj); if (try_remove(new_obj)) return new_obj; if (try_resize(new_obj)) return new_obj; return new_obj; } private: bool try_remove(object_t &obj) { auto &alloc = obj.as(); auto it = removes_.find(alloc.buf); if (it == removes_.end()) return false; obj = alloc.body; removes_.erase(it); return true; } bool try_resize(object_t &obj) { auto &alloc = obj.as(); auto it = resizes_.find(alloc.buf); if (it == resizes_.end()) return false; obj = alloc_t::make( alloc.buf, it->second, alloc.kind, alloc.attr, alloc.body); resizes_.erase(it); return true; } object_set_t removes_; object_map_t resizes_; }; // Returns a new statement with injected buffer allocations from `allocs`. // - If put_innermost is false, then `stmt` is nested to all allocations // - If put_innermost is true, then every allocation is injected as innermost // as possible stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector &allocs, bool put_innermost = false); // Returns a new statement with injected let statements, `stmt` is nested to // all let statements. stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector &lets); template struct expr_cast_helper_t { static T call(const expr_t &e) { return to_cpp(e); } static std::vector call(const std::vector &exprs) { std::vector ret; for (auto &e : exprs) ret.push_back(to_cpp(e)); return ret; } }; template <> struct expr_cast_helper_t { static expr_t call(const expr_t &e) { return e; } static std::vector call(const std::vector &exprs) { return exprs; } template ::value>::type> static std::vector call(const std::vector &vec) { std::vector ret; for (auto &v : vec) ret.push_back(to_expr(v)); return ret; } }; template DstT expr_cast(const SrcT &src) { return expr_cast_helper_t::call(src); } template std::vector expr_cast(const std::vector &src) { return expr_cast_helper_t::call(src); } // Performs constant folding recursively to an IR tree. object_t const_fold(const object_t &obj); // Performs constant folding non-recursively to an expression. expr_t const_fold_non_recursive(const expr_t &e); template std::vector find_objects(const object_t &root); template std::vector find_objects_unique(const object_t &root); class alloc_manager_t { public: alloc_manager_t(const stmt_t &root) { auto allocs = find_objects(root); for (auto &_a : allocs) { auto &a = _a.as(); auto ret = buf2alloc_.insert({a.buf, _a}); buffers_.push_back(a.buf); ir_assert(ret.second) << "Buffer already exists: " << a.buf; MAYBE_UNUSED(ret); } // Sort buffers by name. std::sort(buffers_.begin(), buffers_.end(), [](const expr_t &a, const expr_t &b) { return a.as().name < b.as().name; }); } const std::vector &buffers() const { return buffers_; } expr_t find_buffer( const std::string &name, bool allow_empty = false) const { for (auto &b : buffers()) if (b.as().name == name) return b; if (!allow_empty) ir_error_not_expected() << name; return expr_t(); } std::vector find_buffers(alloc_kind_t kind) const { std::vector ret; for (auto &b : buffers()) if (alloc_kind(b) == kind) ret.push_back(b); return ret; } int alloc_size(const expr_t &buf) const { auto *a = find_alloc(buf); ir_assert(a) << buf; return a->size; } alloc_kind_t alloc_kind(const expr_t &buf) const { auto *a = find_alloc(buf); ir_assert(a) << buf; return a->kind; } int total_size(alloc_kind_t kind) const { int ret = 0; for (auto &kv : buf2alloc_) { auto &a = kv.second.as(); if (a.kind == kind) ret += a.size; } return ret; } private: const alloc_t *find_alloc(const expr_t &buf) const { auto it = buf2alloc_.find(buf); if (it == buf2alloc_.end()) return nullptr; return it->second.as_ptr(); } object_map_t buf2alloc_; std::vector buffers_; object_map_t alloc_updates_; }; // IR utility functions. expr_t abs(const expr_t &e); expr_t cast(const expr_t &e, const type_t &type, bool saturate = false); bool is_zero(const expr_t &e); bool is_one(const expr_t &e); bool is_minus_one(const expr_t &e); bool is_const_broadcast(const expr_t &e); bool is_const_broadcast(const expr_t &e, const expr_t &value); bool all_of(const expr_t &e, const expr_t &value); expr_t make_buffer(const std::string &name); // Utility functions for nary_op_t. expr_t nary_op_back_transform(const expr_t &e); expr_t nary_op_canonicalize(const expr_t &_e); expr_t make_nary_op(op_kind_t op_kind, const std::vector &args); std::vector cvt_expr_to_nary_op_args(const expr_t &e); // Substitutes all occurrences of `from` to `to` in `root. object_t substitute(const object_t &root, const object_t &from, const object_t &to, int max_substitutions = std::numeric_limits::max()); // Returns leaf statements of `root`. Uses inorder traversal. std::vector flatten_statements(const stmt_t &root); template class object_finder_t : public ir_visitor_t { public: void _visit(const T &obj) override { ir_visitor_t::_visit(obj); occurrences++; if (!save_objects) return; if (find_unique) { found_unique.insert(obj); } else { found.push_back(obj); } } std::vector found; object_set_t found_unique; int occurrences = 0; }; // Returns all IR objects of type `T` found in `root`. template std::vector find_objects(const object_t &root) { object_finder_t finder; finder.visit(root); return finder.found; } template int count_objects(const object_t &root) { object_finder_t finder; finder.visit(root); return finder.occurrences; } // Returns unique IR objects of type `T` found in `root`. template object_set_t find_unique_objects(const object_t &root) { object_finder_t finder; finder.visit(root); return finder.found_unique; } // Returns number of occurrences of `obj` in `root` (based on identity // comparison). int count_object(const object_t &root, const object_t &obj); // Returns number of occurrences of `obj` in vector of root objects (based on // identity comparison). template int count_object(const std::vector &roots, const object_t &obj) { int ret = 0; for (auto &root : roots) ret += count_object(root, obj); return ret; } // Checks if `root` contains `obj`. bool contains_object(const object_t &root, const object_t &obj); // Returns all statement groups matching the label. std::vector find_stmt_groups( const object_t &root, const stmt_label_t &label); // Returns a statement group matching the label. `root` must have exactly one // occurrence. stmt_t find_stmt_group(const object_t &root, const stmt_label_t &label); class scope_visitor_t : public ir_visitor_t { public: bool is_expr_defined(const expr_t &e) const { auto vars = find_unique_objects(e); for (auto &v : vars) { if (def_vars_.count(v) == 0) return false; } return true; } #define CASE(type, var_field, is_pre) \ if (obj.type_id() == type::_type_id()) { \ visit_scope((const type &)obj, ((const type &)obj).var_field, is_pre); \ return; \ } void pre_visit(const object_impl_t &obj) override { CASE(alloc_t, buf, true); CASE(let_t, var, true); CASE(for_t, var, true); } void post_visit(const object_impl_t &obj) override { CASE(alloc_t, buf, false); CASE(let_t, var, false); CASE(for_t, var, false); } #undef CASE private: template void visit_scope(const T &obj, const expr_t &var, bool is_pre_visit) { if (is_pre_visit) { def_vars_.insert(var); return; } def_vars_.erase(var); } object_set_t def_vars_; }; class ir_path_t { public: void push(const object_impl_t *obj) { path_.push_back(obj); } void pop() { path_.pop_back(); } const object_impl_t *back() const { ir_assert(!is_empty()); return path_.back(); } bool is_empty() const { return path_.empty(); } void merge(const ir_path_t &other) { size_t idx; size_t min_size = std::min(path_.size(), other.path_.size()); for (idx = 0; idx < min_size; idx++) { if (path_[idx] != other.path_[idx]) break; } path_.resize(idx); } private: std::vector path_; }; // Only for statements that create scope. stmt_t get_stmt_body(const stmt_t &stmt); stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body); // Describes the linear transformation F(x) for variable x: F(x) = (a * x + b), // where a and b are integer constants. struct linear_transform_t { expr_t x; int a; int b; bool is_identity() const { return a == 1 && b == 0; } }; // Relation: (lhs op rhs), where: // - lhs is a variable // - rhs is an integer constant // - op is a comparison operation class relation_t { public: relation_t(const expr_t &expr) : expr_(normalize(expr)) {} const expr_t &expr() const { return expr_; } const expr_t &var() const { return expr_.as().a; } const expr_t &rhs() const { return expr_.as().b; } op_kind_t op_kind() const { return expr_.as().op_kind; } bool implies(const relation_t &other) const; // Applies linear transformation to left and right hand sides of the relation. relation_t transform(const linear_transform_t &t, const expr_t &new_var); std::string str() const { std::ostringstream oss; oss << expr_; return oss.str(); } static bool is_relation_constraint(const expr_t &e) { auto *binary_op = e.as_ptr(); if (!binary_op) return false; if (!is_var(binary_op->a)) return false; if (!is_const(binary_op->b)) return false; if (!is_cmp_op(binary_op->op_kind)) return false; return true; } private: static expr_t normalize(const expr_t &e); expr_t expr_; }; inline std::ostream &operator<<(std::ostream &out, const relation_t &rel) { out << rel.str(); return out; } // Equality for modulus: (var % mod) == 0, where: // - var is a variable // - mod is an integer constant class modulus_info_t { public: modulus_info_t(const expr_t &expr) : expr_(expr) {} const expr_t &expr() const { return expr_; } const expr_t &var() const { auto &mod_expr = expr_.as().a; return mod_expr.as().a; } const expr_t &mod() const { auto &mod_expr = expr_.as().a; return mod_expr.as().b; } bool implies(const modulus_info_t &other) const { ir_assert(var().is_same(other.var())); int64_t this_mod = to_cpp(mod()); int64_t other_mod = to_cpp(other.mod()); return this_mod % other_mod == 0; } std::string str() const { std::ostringstream oss; oss << expr_; return oss.str(); } // Try to match (var % mod) == 0. static bool is_modulus_constraint(const expr_t &e); private: expr_t expr_; }; inline std::ostream &operator<<(std::ostream &out, const modulus_info_t &mod) { out << mod.str(); return out; } // Helper class to find constant bounds of integer expressions based on known // relations. class bound_finder_t { public: bound_finder_t( const object_map_t> &relations) : relations_(relations) {} int64_t find_low_bound(const expr_t &e) const { return find_bound_impl(e, /*is_low=*/true); } int64_t find_high_bound(const expr_t &e) const { return find_bound_impl(e, /*is_low=*/false); } static bool is_good_bound(int64_t bound) { if (bound == unlimited_bound(true)) return false; if (bound == unlimited_bound(false)) return false; return true; } private: // If is_low is true, searches for proven low bound, and high bound // otherwise. int64_t find_bound_impl(const expr_t &e, bool is_low) const; static int64_t unlimited_bound(bool is_low) { if (is_low) return std::numeric_limits::min(); return std::numeric_limits::max(); } object_map_t> relations_; }; // TODO: Add integers check (only integers can be constrained). class constraint_set_t { public: void add_constraint(const expr_t &e); bool can_prove(const expr_t &e, bool try_simplify = true) const { auto ret = can_prove_impl(e, /*do_simplify=*/false); if (ret || !try_simplify) return ret; return can_prove_impl(e, /*do_simplify=*/true); } bool is_single_value(const expr_t &e, expr_t &value) const; int max_proven_gcd(const expr_t &var) const; private: bool can_prove_modulus(const expr_t &e) const { modulus_info_t unknown(e); auto it = modulus_infos_.find(unknown.var()); if (it == modulus_infos_.end()) return false; for (auto &known : it->second) { if (known.implies(unknown)) return true; } return false; } bool can_prove_relation(const expr_t &e) const { relation_t unknown(e); auto it = relations_.find(unknown.var()); if (it == relations_.end()) return false; for (auto &known : it->second) { if (known.implies(unknown)) return true; } return false; } bool try_prove_compound_relation(const expr_t &e) const { auto *binary = e.as_ptr(); if (!binary) return false; auto op_kind = binary->op_kind; auto &a = binary->a; auto &_b = binary->b; if (!is_const(_b)) return false; auto b = to_cpp(_b); // Normalize operation kind. switch (op_kind) { case op_kind_t::_ge: case op_kind_t::_le: break; case op_kind_t::_gt: op_kind = op_kind_t::_ge; ir_assert(b < std::numeric_limits::max()); b += 1; break; case op_kind_t::_lt: op_kind = op_kind_t::_le; ir_assert(b > std::numeric_limits::min()); b -= 1; break; default: return false; } bound_finder_t finder(relations_); if (op_kind == op_kind_t::_ge) { auto lo = finder.find_low_bound(a); if (!bound_finder_t::is_good_bound(lo)) return false; return lo >= b; } if (op_kind == op_kind_t::_le) { auto hi = finder.find_high_bound(a); if (!bound_finder_t::is_good_bound(hi)) return false; return hi <= b; } return false; } bool can_prove_impl(const expr_t &_e, bool do_simplify) const; object_map_t> relations_; object_map_t> modulus_infos_; }; // Simplifies expression or statement. An optional constraint set is used to // pass known equalities and inequalities which may be used for simplification. object_t simplify(const object_t &obj, const constraint_set_t &cset = {}); // Searches for expression patterns to reduce them to the equivalent ternary // operations. expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive = true); // Moves constants to the right hand side of an expression. // Example: (c0 + x) op c1 -> x op (c1 - c0) expr_t simplify_cmp_move_const_to_rhs(const expr_t &e); // Reduces left and right hand sides of an expression. // Example: A * x < A * B -> x < B (if A > 0). expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e); // Propagates shuffle down the expression tree for more effective vectorization. expr_t simplify_propagate_shuffle(const expr_t &e); // Pre-defined functions. namespace funcs { inline func_t barrier_func() { static auto f = builtin_t::make("barrier"); return f; } inline stmt_t barrier() { return barrier_func().call(); } inline func_t slm_fence_func() { static auto f = builtin_t::make("slm_fence"); return f; } inline stmt_t slm_fence() { return slm_fence_func().call(); } inline func_t signal_func() { static auto f = builtin_t::make("signal"); return f; } inline stmt_t signal() { return signal_func().call(); } inline func_t barrier_wait_func() { static auto f = builtin_t::make("barrier_wait"); return f; } inline stmt_t barrier_wait() { return barrier_wait_func().call(); } } // namespace funcs // Helper functionality to extract ND indices packed into 1D index. // Example: // i = [0; Bi, 2 * Bi, ... (I - 1) * Bi] // i_info.dim = I; i_info.block = Bi // j = [0; Bj, 2 * Bj, ... (J - 1) * Bj] // j_info.dim = J; j_info.block = Bj // 1D index: ij_idx // 2D indices: [i; j] // Unpacking: // i = (ij_idx % I) * Bi // j = (ij_idx / I) * Bj struct unpack_dim_info_t { const expr_t &var; int dim; int block; }; inline void cvt_args_to_unpack_dim_info(std::vector &) {} template void cvt_args_to_unpack_dim_info(std::vector &infos, const expr_t &var, int dim, int block, const ArgsT &... args) { infos.push_back(unpack_dim_info_t {var, dim, block}); cvt_args_to_unpack_dim_info(infos, args...); } void unpack(std::vector &init_stmts, constraint_set_t &cset, const expr_t &_e, const std::vector &infos); template void unpack(std::vector &init_stmts, constraint_set_t &cset, const expr_t &e, const ArgsT &... args) { std::vector infos; cvt_args_to_unpack_dim_info(infos, args...); unpack(init_stmts, cset, e, infos); } } // namespace jit } // namespace gpu } // namespace impl } // namespace dnnl #endif