1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 *
22 * \brief Lift specified AttrStmt scope to outer if
23 * the body contains the same scope.
24 * \file lift_attr_scope.cc
25 */
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <string>
#include <utility>
#include <vector>
#include "ir_util.h"
29
30 namespace tvm {
31 namespace ir {
32
33 // NOTE: this optimization can only be applied
34 // to a few specified attr keys
35 class AttrScopeLifter : public IRMutator {
36 public:
AttrScopeLifter(std::string attr_key)37 explicit AttrScopeLifter(std::string attr_key)
38 : attr_key_(attr_key) {}
39
Lift(Stmt stmt)40 Stmt Lift(Stmt stmt) {
41 stmt = Mutate(stmt);
42 if (attr_node_.defined()) {
43 stmt = AttrStmt::make(
44 attr_node_, attr_key_, attr_value_, stmt);
45 }
46 return stmt;
47 }
48
49 // do not go beyond
Mutate_(const Allocate * op,const Stmt & s)50 Stmt Mutate_(const Allocate* op, const Stmt& s) final {
51 Stmt stmt = IRMutator::Mutate_(op, s);
52 op = stmt.as<Allocate>();
53 if (attr_node_.defined()) {
54 Stmt body = AttrStmt::make(
55 attr_node_, attr_key_, attr_value_, op->body);
56 // undefine them
57 attr_node_ = NodeRef();
58 attr_value_ = Expr();
59 return Allocate::make(
60 op->buffer_var, op->type,
61 op->extents, op->condition, body,
62 op->new_expr, op->free_function);
63 } else {
64 return stmt;
65 }
66 }
67
Mutate_(const AttrStmt * op,const Stmt & s)68 Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
69 if (op->attr_key == attr_key_) {
70 attr_node_ = op->node;
71 attr_value_ = op->value;
72 return op->body;
73 } else {
74 return IRMutator::Mutate_(op, s);
75 }
76 }
77
Mutate_(const Block * op,const Stmt & s)78 Stmt Mutate_(const Block* op, const Stmt& s) final {
79 std::vector<Stmt> seq;
80 FlattenSeq(op->first, &seq);
81 FlattenSeq(op->rest, &seq);
82 seq = MutateSeq(seq);
83 if (seq.size() == 2 &&
84 seq[0].same_as(op->first) &&
85 seq[1].same_as(op->rest)) {
86 return s;
87 }
88 return MergeSeq(seq);
89 }
90
Mutate_(const IfThenElse * op,const Stmt & s)91 Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
92 if (!op->else_case.defined()) {
93 return IRMutator::Mutate_(op, s);
94 }
95 Stmt then_case = this->Mutate(op->then_case);
96 NodeRef first_node;
97 Expr first_value;
98 std::swap(first_node, attr_node_);
99 std::swap(first_value, attr_value_);
100 Stmt else_case = this->Mutate(op->else_case);
101 if (attr_node_.defined() &&
102 attr_value_.defined() &&
103 first_node.defined() &&
104 first_value.defined() &&
105 attr_node_.same_as(first_node) &&
106 ValueSame(attr_value_, first_value)) {
107 if (then_case.same_as(op->then_case) &&
108 else_case.same_as(op->else_case)) {
109 return s;
110 } else {
111 return IfThenElse::make(op->condition, then_case, else_case);
112 }
113 } else {
114 if (first_node.defined()) {
115 then_case = AttrStmt::make(
116 first_node, attr_key_, first_value, then_case);
117 }
118 if (attr_node_.defined()) {
119 else_case = AttrStmt::make(
120 attr_node_, attr_key_, attr_value_, else_case);
121 // undefine them
122 attr_node_ = NodeRef();
123 attr_value_ = Expr();
124 }
125 if (then_case.same_as(op->then_case) &&
126 else_case.same_as(op->else_case)) {
127 return s;
128 } else {
129 return IfThenElse::make(op->condition, then_case, else_case);
130 }
131 }
132 }
133
134 private:
FlattenSeq(Stmt s,std::vector<Stmt> * res)135 void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
136 if (const Block* op = s.as<Block>()) {
137 FlattenSeq(op->first, res);
138 FlattenSeq(op->rest, res);
139 } else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
140 if (!op->is_producer) {
141 FlattenSeq(op->body, res);
142 } else {
143 res->emplace_back(s);
144 }
145 } else {
146 res->emplace_back(s);
147 }
148 }
149
MutateSeq(const std::vector<Stmt> & seq)150 std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
151 std::vector<Stmt> res_seq;
152 NodeRef curr_node;
153 Expr curr_value;
154 Stmt curr_stmt;
155 for (const Stmt & stmt : seq) {
156 attr_node_ = NodeRef();
157 attr_value_ = Expr();
158 Stmt rest = this->Mutate(stmt);
159 if (attr_node_.defined() &&
160 attr_value_.defined() &&
161 curr_node.defined() &&
162 curr_value.defined() &&
163 attr_node_.same_as(curr_node) &&
164 ValueSame(attr_value_, curr_value)) {
165 curr_stmt = Block::make(curr_stmt, rest);
166 } else {
167 if (curr_stmt.defined()) {
168 if (curr_node.defined()) {
169 curr_stmt = AttrStmt::make(
170 curr_node, attr_key_, curr_value, curr_stmt);
171 }
172 res_seq.push_back(curr_stmt);
173 }
174 curr_stmt = rest;
175 curr_node = attr_node_;
176 curr_value = attr_value_;
177 }
178 }
179
180 if (curr_stmt.defined()) {
181 // keep attr_node_, attr_node_
182 if (res_seq.size() == 0) {
183 return {curr_stmt};
184 }
185 if (curr_node.defined()) {
186 curr_stmt = AttrStmt::make(
187 curr_node, attr_key_, curr_value, curr_stmt);
188 }
189 res_seq.push_back(curr_stmt);
190 // reset
191 attr_node_ = NodeRef();
192 attr_value_ = Expr();
193 }
194 return res_seq;
195 }
196
197 // value comparison that also compares content of int constant
ValueSame(const Expr & a,const Expr & b)198 static bool ValueSame(const Expr& a, const Expr& b) {
199 if (a.same_as(b)) return true;
200 if (a->type_index() != b->type_index()) return false;
201 if (a.type() != b.type()) return false;
202 if (const IntImm* op = a.as<IntImm>()) {
203 return op->value == b.as<IntImm>()->value;
204 }
205 if (const UIntImm* op = a.as<UIntImm>()) {
206 return op->value == b.as<UIntImm>()->value;
207 }
208 return false;
209 }
210
211 std::string attr_key_;
212 NodeRef attr_node_;
213 Expr attr_value_;
214 };
215
LiftAttrScope(Stmt stmt,std::string attr_key)216 Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
217 return AttrScopeLifter(attr_key).Lift(stmt);
218 }
219
220 } // namespace ir
221 } // namespace tvm
222