1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
 * \brief Lift AttrStmt scopes with a specified attr key to an outer
 *   level, merging neighboring scopes that share the same node and value.
24  * \file lift_attr_scope.cc
25  */
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <string>
#include <utility>
#include <vector>
#include "ir_util.h"
29 
30 namespace tvm {
31 namespace ir {
32 
33 // NOTE: this optimization can only be applied
34 // to a few specified attr keys
35 class AttrScopeLifter : public IRMutator {
36  public:
AttrScopeLifter(std::string attr_key)37   explicit AttrScopeLifter(std::string attr_key)
38       : attr_key_(attr_key) {}
39 
Lift(Stmt stmt)40   Stmt Lift(Stmt stmt) {
41     stmt = Mutate(stmt);
42     if (attr_node_.defined()) {
43       stmt = AttrStmt::make(
44           attr_node_, attr_key_, attr_value_, stmt);
45     }
46     return stmt;
47   }
48 
49   // do not go beyond
Mutate_(const Allocate * op,const Stmt & s)50   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
51     Stmt stmt = IRMutator::Mutate_(op, s);
52     op = stmt.as<Allocate>();
53     if (attr_node_.defined()) {
54       Stmt body = AttrStmt::make(
55           attr_node_, attr_key_, attr_value_, op->body);
56       // undefine them
57       attr_node_ = NodeRef();
58       attr_value_ = Expr();
59       return Allocate::make(
60         op->buffer_var, op->type,
61         op->extents, op->condition, body,
62         op->new_expr, op->free_function);
63     } else {
64       return stmt;
65     }
66   }
67 
Mutate_(const AttrStmt * op,const Stmt & s)68   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
69     if (op->attr_key == attr_key_) {
70       attr_node_ = op->node;
71       attr_value_ = op->value;
72       return op->body;
73     } else {
74       return IRMutator::Mutate_(op, s);
75     }
76   }
77 
Mutate_(const Block * op,const Stmt & s)78   Stmt Mutate_(const Block* op, const Stmt& s) final {
79     std::vector<Stmt> seq;
80     FlattenSeq(op->first, &seq);
81     FlattenSeq(op->rest, &seq);
82     seq = MutateSeq(seq);
83     if (seq.size() == 2 &&
84         seq[0].same_as(op->first) &&
85         seq[1].same_as(op->rest)) {
86       return s;
87     }
88     return MergeSeq(seq);
89   }
90 
Mutate_(const IfThenElse * op,const Stmt & s)91   Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
92     if (!op->else_case.defined()) {
93       return IRMutator::Mutate_(op, s);
94     }
95     Stmt then_case = this->Mutate(op->then_case);
96     NodeRef first_node;
97     Expr first_value;
98     std::swap(first_node, attr_node_);
99     std::swap(first_value, attr_value_);
100     Stmt else_case = this->Mutate(op->else_case);
101     if (attr_node_.defined() &&
102         attr_value_.defined() &&
103         first_node.defined() &&
104         first_value.defined() &&
105         attr_node_.same_as(first_node) &&
106         ValueSame(attr_value_, first_value)) {
107       if (then_case.same_as(op->then_case) &&
108           else_case.same_as(op->else_case)) {
109         return s;
110       } else {
111         return IfThenElse::make(op->condition, then_case, else_case);
112       }
113     } else {
114       if (first_node.defined()) {
115         then_case = AttrStmt::make(
116             first_node, attr_key_, first_value, then_case);
117       }
118       if (attr_node_.defined()) {
119         else_case = AttrStmt::make(
120             attr_node_, attr_key_, attr_value_, else_case);
121         // undefine them
122         attr_node_ = NodeRef();
123         attr_value_ = Expr();
124       }
125       if (then_case.same_as(op->then_case) &&
126           else_case.same_as(op->else_case)) {
127         return s;
128       } else {
129         return IfThenElse::make(op->condition, then_case, else_case);
130       }
131     }
132   }
133 
134  private:
FlattenSeq(Stmt s,std::vector<Stmt> * res)135   void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
136     if (const Block* op = s.as<Block>()) {
137       FlattenSeq(op->first, res);
138       FlattenSeq(op->rest, res);
139     } else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
140       if (!op->is_producer) {
141         FlattenSeq(op->body, res);
142       } else {
143         res->emplace_back(s);
144       }
145     } else {
146       res->emplace_back(s);
147     }
148   }
149 
MutateSeq(const std::vector<Stmt> & seq)150   std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
151     std::vector<Stmt> res_seq;
152     NodeRef curr_node;
153     Expr curr_value;
154     Stmt curr_stmt;
155     for (const Stmt & stmt : seq) {
156       attr_node_ = NodeRef();
157       attr_value_ = Expr();
158       Stmt rest = this->Mutate(stmt);
159       if (attr_node_.defined() &&
160           attr_value_.defined() &&
161           curr_node.defined() &&
162           curr_value.defined() &&
163           attr_node_.same_as(curr_node) &&
164           ValueSame(attr_value_, curr_value)) {
165         curr_stmt = Block::make(curr_stmt, rest);
166       } else {
167         if (curr_stmt.defined()) {
168           if (curr_node.defined()) {
169             curr_stmt = AttrStmt::make(
170                 curr_node, attr_key_, curr_value, curr_stmt);
171           }
172           res_seq.push_back(curr_stmt);
173         }
174         curr_stmt = rest;
175         curr_node = attr_node_;
176         curr_value = attr_value_;
177       }
178     }
179 
180     if (curr_stmt.defined()) {
181       // keep attr_node_, attr_node_
182       if (res_seq.size() == 0) {
183         return {curr_stmt};
184       }
185       if (curr_node.defined()) {
186         curr_stmt = AttrStmt::make(
187             curr_node, attr_key_, curr_value, curr_stmt);
188       }
189       res_seq.push_back(curr_stmt);
190       // reset
191       attr_node_ = NodeRef();
192       attr_value_ = Expr();
193     }
194     return res_seq;
195   }
196 
197   // value comparison that also compares content of int constant
ValueSame(const Expr & a,const Expr & b)198   static bool ValueSame(const Expr& a, const Expr& b) {
199     if (a.same_as(b)) return true;
200     if (a->type_index() != b->type_index()) return false;
201     if (a.type() != b.type()) return false;
202     if (const IntImm* op = a.as<IntImm>()) {
203       return op->value == b.as<IntImm>()->value;
204     }
205     if (const UIntImm* op = a.as<UIntImm>()) {
206       return op->value == b.as<UIntImm>()->value;
207     }
208     return false;
209   }
210 
211   std::string attr_key_;
212   NodeRef attr_node_;
213   Expr attr_value_;
214 };
215 
LiftAttrScope(Stmt stmt,std::string attr_key)216 Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
217   return AttrScopeLifter(attr_key).Lift(stmt);
218 }
219 
220 }  // namespace ir
221 }  // namespace tvm
222