1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * SSA related checks and pass.
22 *
23 * SSA requires each variable to be defined only once.
24 * \file ssa.cc
25 */
26 #include <tvm/ir.h>
27 #include <tvm/ir_visitor.h>
28 #include <tvm/ir_mutator.h>
29 #include <tvm/ir_pass.h>
30 #include <unordered_set>
31 #include <unordered_map>
32 #include <vector>
33
34 namespace tvm {
35 namespace ir {
36 namespace {
37 class IRVerifySSA final : public IRVisitor {
38 public:
39 bool is_ssa{true};
40
Visit(const NodeRef & n)41 void Visit(const NodeRef& n) final {
42 if (!is_ssa) return;
43 IRVisitor::Visit(n);
44 }
Visit_(const Let * op)45 void Visit_(const Let* op) final {
46 MarkDef(op->var.get());
47 IRVisitor::Visit_(op);
48 }
Visit_(const LetStmt * op)49 void Visit_(const LetStmt* op) final {
50 MarkDef(op->var.get());
51 IRVisitor::Visit_(op);
52 }
Visit_(const For * op)53 void Visit_(const For* op) final {
54 MarkDef(op->loop_var.get());
55 IRVisitor::Visit_(op);
56 }
Visit_(const Allocate * op)57 void Visit_(const Allocate* op) final {
58 MarkDef(op->buffer_var.get());
59 IRVisitor::Visit_(op);
60 }
61
62 private:
MarkDef(const Variable * v)63 void MarkDef(const Variable* v) {
64 if (defined_.count(v) != 0) {
65 is_ssa = false; return;
66 } else {
67 defined_[v] = 1;
68 }
69 }
70 std::unordered_map<const Variable*, int> defined_;
71 };
72
73 class IRConvertSSA final : public IRMutator {
74 public:
Mutate_(const Variable * op,const Expr & e)75 Expr Mutate_(const Variable* op, const Expr& e) final {
76 if (scope_.count(op)) {
77 return scope_[op].back();
78 } else {
79 return e;
80 }
81 }
Mutate_(const Let * op,const Expr & e)82 Expr Mutate_(const Let* op, const Expr& e) final {
83 const VarExpr& v = op->var;
84 if (defined_.count(v.get())) {
85 Expr value = IRMutator::Mutate(op->value);
86 VarExpr new_var = Variable::make(v.type(), v->name_hint);
87 scope_[v.get()].push_back(new_var);
88 Expr body = IRMutator::Mutate(op->body);
89 scope_[v.get()].pop_back();
90 return Let::make(new_var, value, body);
91 } else {
92 defined_.insert(v.get());
93 return IRMutator::Mutate_(op, e);
94 }
95 }
Mutate_(const Load * op,const Expr & e)96 Expr Mutate_(const Load* op, const Expr& e) final {
97 Expr expr = IRMutator::Mutate_(op, e);
98 op = expr.as<Load>();
99 if (scope_.count(op->buffer_var.get())) {
100 return Load::make(
101 op->type, scope_[op->buffer_var.get()].back(),
102 op->index, op->predicate);
103 } else {
104 return expr;
105 }
106 }
Mutate_(const Store * op,const Stmt & s)107 Stmt Mutate_(const Store* op, const Stmt& s) final {
108 Stmt stmt = IRMutator::Mutate_(op, s);
109 op = stmt.as<Store>();
110 if (scope_.count(op->buffer_var.get())) {
111 return Store::make(
112 scope_[op->buffer_var.get()].back(), op->value,
113 op->index, op->predicate);
114 } else {
115 return stmt;
116 }
117 }
Mutate_(const LetStmt * op,const Stmt & s)118 Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
119 const VarExpr& v = op->var;
120 if (defined_.count(v.get())) {
121 Expr value = IRMutator::Mutate(op->value);
122 VarExpr new_var = Variable::make(v.type(), v->name_hint);
123 scope_[v.get()].push_back(new_var);
124 Stmt body = IRMutator::Mutate(op->body);
125 scope_[v.get()].pop_back();
126 return LetStmt::make(new_var, value, body);
127 } else {
128 defined_.insert(v.get());
129 return IRMutator::Mutate_(op, s);
130 }
131 }
Mutate_(const For * op,const Stmt & s)132 Stmt Mutate_(const For* op, const Stmt& s) final {
133 const VarExpr& v = op->loop_var;
134 if (defined_.count(v.get())) {
135 VarExpr new_var = Variable::make(v.type(), v->name_hint);
136 scope_[v.get()].push_back(new_var);
137 Stmt stmt = IRMutator::Mutate_(op, s);
138 scope_[v.get()].pop_back();
139 op = stmt.as<For>();
140 return For::make(
141 new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
142 } else {
143 defined_.insert(v.get());
144 return IRMutator::Mutate_(op, s);
145 }
146 }
Mutate_(const Allocate * op,const Stmt & s)147 Stmt Mutate_(const Allocate* op, const Stmt& s) final {
148 const VarExpr& v = op->buffer_var;
149 if (defined_.count(v.get())) {
150 VarExpr new_var = Variable::make(v.type(), v->name_hint);
151 scope_[v.get()].push_back(new_var);
152 Stmt stmt = IRMutator::Mutate_(op, s);
153 scope_[v.get()].pop_back();
154 op = stmt.as<Allocate>();
155 return Allocate::make(
156 new_var, op->type, op->extents, op->condition,
157 op->body, op->new_expr, op->free_function);
158 } else {
159 defined_.insert(v.get());
160 return IRMutator::Mutate_(op, s);
161 }
162 }
Mutate_(const AttrStmt * op,const Stmt & s)163 Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
164 if (const Variable* v = op->node.as<Variable>()) {
165 if (op->attr_key == attr::storage_scope) {
166 const Allocate* alloc = op->body.as<Allocate>();
167 if (alloc && op->node.same_as(alloc->buffer_var)) {
168 Stmt new_alloc = Mutate(op->body);
169 if (new_alloc.same_as(op->body)) return s;
170 alloc = new_alloc.as<Allocate>();
171 CHECK(alloc);
172 return AttrStmt::make(
173 alloc->buffer_var, op->attr_key, op->value, new_alloc);
174 }
175 }
176 Stmt stmt = IRMutator::Mutate_(op, s);
177 op = stmt.as<AttrStmt>();
178 if (scope_.count(v) && scope_[v].size() != 0) {
179 return AttrStmt::make(
180 scope_[v].back(), op->attr_key, op->value, op->body);
181 } else {
182 return stmt;
183 }
184 } else {
185 return IRMutator::Mutate_(op, s);
186 }
187 }
188
189 private:
190 std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
191 std::unordered_set<const Variable*> defined_;
192 };
193
194 } // namespace
195
VerifySSA(const Stmt & ir)196 bool VerifySSA(const Stmt& ir) {
197 IRVerifySSA v;
198 v.Visit(ir);
199 return v.is_ssa;
200 }
201
ConvertSSA(Stmt stmt)202 Stmt ConvertSSA(Stmt stmt) {
203 return IRConvertSSA().Mutate(stmt);
204 }
205
206 } // namespace ir
207 } // namespace tvm
208