1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *  SSA related checks and pass.
22  *
 *  SSA (static single assignment) requires each variable to be defined only once.
24  * \file ssa.cc
25  */
26 #include <tvm/ir.h>
27 #include <tvm/ir_visitor.h>
28 #include <tvm/ir_mutator.h>
29 #include <tvm/ir_pass.h>
30 #include <unordered_set>
31 #include <unordered_map>
32 #include <vector>
33 
34 namespace tvm {
35 namespace ir {
36 namespace {
37 class IRVerifySSA final : public IRVisitor {
38  public:
39   bool is_ssa{true};
40 
Visit(const NodeRef & n)41   void Visit(const NodeRef& n) final {
42     if (!is_ssa) return;
43     IRVisitor::Visit(n);
44   }
Visit_(const Let * op)45   void Visit_(const Let* op) final {
46     MarkDef(op->var.get());
47     IRVisitor::Visit_(op);
48   }
Visit_(const LetStmt * op)49   void Visit_(const LetStmt* op) final {
50     MarkDef(op->var.get());
51     IRVisitor::Visit_(op);
52   }
Visit_(const For * op)53   void Visit_(const For* op) final {
54     MarkDef(op->loop_var.get());
55     IRVisitor::Visit_(op);
56   }
Visit_(const Allocate * op)57   void Visit_(const Allocate* op) final {
58     MarkDef(op->buffer_var.get());
59     IRVisitor::Visit_(op);
60   }
61 
62  private:
MarkDef(const Variable * v)63   void MarkDef(const Variable* v) {
64     if (defined_.count(v) != 0) {
65       is_ssa = false; return;
66     } else {
67       defined_[v] = 1;
68     }
69   }
70   std::unordered_map<const Variable*, int> defined_;
71 };
72 
73 class IRConvertSSA final : public IRMutator {
74  public:
Mutate_(const Variable * op,const Expr & e)75   Expr Mutate_(const Variable* op, const Expr& e) final {
76     if (scope_.count(op)) {
77       return scope_[op].back();
78     } else {
79       return e;
80     }
81   }
Mutate_(const Let * op,const Expr & e)82   Expr Mutate_(const Let* op, const Expr& e) final {
83     const VarExpr& v = op->var;
84     if (defined_.count(v.get())) {
85       Expr value = IRMutator::Mutate(op->value);
86       VarExpr new_var = Variable::make(v.type(), v->name_hint);
87       scope_[v.get()].push_back(new_var);
88       Expr body = IRMutator::Mutate(op->body);
89       scope_[v.get()].pop_back();
90       return Let::make(new_var, value, body);
91     } else {
92       defined_.insert(v.get());
93       return IRMutator::Mutate_(op, e);
94     }
95   }
Mutate_(const Load * op,const Expr & e)96   Expr Mutate_(const Load* op, const Expr& e) final {
97     Expr expr = IRMutator::Mutate_(op, e);
98     op = expr.as<Load>();
99     if (scope_.count(op->buffer_var.get())) {
100       return Load::make(
101           op->type, scope_[op->buffer_var.get()].back(),
102           op->index, op->predicate);
103     } else {
104       return expr;
105     }
106   }
Mutate_(const Store * op,const Stmt & s)107   Stmt Mutate_(const Store* op, const Stmt& s) final {
108     Stmt stmt = IRMutator::Mutate_(op, s);
109     op = stmt.as<Store>();
110     if (scope_.count(op->buffer_var.get())) {
111       return Store::make(
112           scope_[op->buffer_var.get()].back(), op->value,
113           op->index, op->predicate);
114     } else {
115       return stmt;
116     }
117   }
Mutate_(const LetStmt * op,const Stmt & s)118   Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
119     const VarExpr& v = op->var;
120     if (defined_.count(v.get())) {
121       Expr value = IRMutator::Mutate(op->value);
122       VarExpr new_var = Variable::make(v.type(), v->name_hint);
123       scope_[v.get()].push_back(new_var);
124       Stmt body = IRMutator::Mutate(op->body);
125       scope_[v.get()].pop_back();
126       return LetStmt::make(new_var, value, body);
127     } else {
128       defined_.insert(v.get());
129       return IRMutator::Mutate_(op, s);
130     }
131   }
Mutate_(const For * op,const Stmt & s)132   Stmt Mutate_(const For* op, const Stmt& s) final {
133     const VarExpr& v = op->loop_var;
134     if (defined_.count(v.get())) {
135       VarExpr new_var = Variable::make(v.type(), v->name_hint);
136       scope_[v.get()].push_back(new_var);
137       Stmt stmt = IRMutator::Mutate_(op, s);
138       scope_[v.get()].pop_back();
139       op = stmt.as<For>();
140       return For::make(
141           new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
142     } else {
143       defined_.insert(v.get());
144       return IRMutator::Mutate_(op, s);
145     }
146   }
Mutate_(const Allocate * op,const Stmt & s)147   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
148     const VarExpr& v = op->buffer_var;
149     if (defined_.count(v.get())) {
150       VarExpr new_var = Variable::make(v.type(), v->name_hint);
151       scope_[v.get()].push_back(new_var);
152       Stmt stmt = IRMutator::Mutate_(op, s);
153       scope_[v.get()].pop_back();
154       op = stmt.as<Allocate>();
155       return Allocate::make(
156           new_var, op->type, op->extents, op->condition,
157           op->body, op->new_expr, op->free_function);
158     } else {
159       defined_.insert(v.get());
160       return IRMutator::Mutate_(op, s);
161     }
162   }
Mutate_(const AttrStmt * op,const Stmt & s)163   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
164     if (const Variable* v = op->node.as<Variable>()) {
165       if (op->attr_key == attr::storage_scope) {
166         const Allocate* alloc = op->body.as<Allocate>();
167         if (alloc && op->node.same_as(alloc->buffer_var)) {
168           Stmt new_alloc = Mutate(op->body);
169           if (new_alloc.same_as(op->body)) return s;
170           alloc = new_alloc.as<Allocate>();
171           CHECK(alloc);
172           return AttrStmt::make(
173               alloc->buffer_var, op->attr_key, op->value, new_alloc);
174         }
175       }
176       Stmt stmt = IRMutator::Mutate_(op, s);
177       op = stmt.as<AttrStmt>();
178       if (scope_.count(v) && scope_[v].size() != 0) {
179         return AttrStmt::make(
180             scope_[v].back(), op->attr_key, op->value, op->body);
181       } else {
182         return stmt;
183       }
184     } else {
185       return IRMutator::Mutate_(op, s);
186     }
187   }
188 
189  private:
190   std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
191   std::unordered_set<const Variable*> defined_;
192 };
193 
194 }  // namespace
195 
VerifySSA(const Stmt & ir)196 bool VerifySSA(const Stmt& ir) {
197   IRVerifySSA v;
198   v.Visit(ir);
199   return v.is_ssa;
200 }
201 
ConvertSSA(Stmt stmt)202 Stmt ConvertSSA(Stmt stmt) {
203   return IRConvertSSA().Mutate(stmt);
204 }
205 
206 }  // namespace ir
207 }  // namespace tvm
208