1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 *  SSA related checks and pass.
 *
 *  SSA requires each variable to be defined only once.
 * \file verify_ssa.cc
 */
26 #include <tvm/runtime/registry.h>
27 #include <tvm/tir/analysis.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/stmt_functor.h>
30 
31 #include <unordered_map>
32 #include <unordered_set>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 
38 class SSAVerifier final : public StmtExprVisitor {
39  public:
40   bool is_ssa_{true};
41 
VisitExpr(const PrimExpr & n)42   void VisitExpr(const PrimExpr& n) final {
43     if (!is_ssa_) return;
44     StmtExprVisitor::VisitExpr(n);
45   }
VisitStmt(const Stmt & n)46   void VisitStmt(const Stmt& n) final {
47     if (!is_ssa_) return;
48     StmtExprVisitor::VisitStmt(n);
49   }
VisitExpr_(const LetNode * op)50   void VisitExpr_(const LetNode* op) final {
51     // Weaker SSA condition
52     // A single var can be binded in multiple lets
53     // but they have to bind to the same value.
54     // This is used to enable cases when we reuse a single let
55     // expression to cosntruct a nested expr.
56     // (let x = 1 in x + 1) * (let x = 1 in x + 1)
57     auto it = def_map_.find(op->var);
58     if (it != def_map_.end()) {
59       if (!deep_equal_(it->second, op->value)) {
60         is_ssa_ = false;
61         return;
62       }
63     } else {
64       MarkDef(op->var, op->value);
65     }
66     StmtExprVisitor::VisitExpr_(op);
67   }
68 
VisitStmt_(const LetStmtNode * op)69   void VisitStmt_(const LetStmtNode* op) final {
70     MarkDef(op->var, op->value);
71     StmtExprVisitor::VisitStmt_(op);
72   }
VisitStmt_(const ForNode * op)73   void VisitStmt_(const ForNode* op) final {
74     MarkDef(op->loop_var, op->loop_var);
75     StmtExprVisitor::VisitStmt_(op);
76   }
VisitStmt_(const AllocateNode * op)77   void VisitStmt_(const AllocateNode* op) final {
78     MarkDef(op->buffer_var, op->buffer_var);
79     StmtExprVisitor::VisitStmt_(op);
80   }
81 
VisitExpr_(const VarNode * node)82   void VisitExpr_(const VarNode* node) final {
83     auto var = GetRef<Var>(node);
84     if (match_scope_) {
85       MarkDef(var, var, true);
86     }
87   }
88 
Run(const PrimFunc & func)89   void Run(const PrimFunc& func) {
90     for (auto param : func->params) {
91       MarkDef(param, param);
92     }
93 
94     for (auto kv : func->buffer_map) {
95       this->DefineBuffer(kv.second);
96     }
97     this->VisitStmt(func->body);
98   }
99 
DefineBuffer(const Buffer & buffer)100   void DefineBuffer(const Buffer& buffer) {
101     match_scope_ = true;
102     this->VisitExpr(buffer->data);
103     for (size_t i = 0; i < buffer->shape.size(); ++i) {
104       this->VisitExpr(buffer->shape[i]);
105     }
106 
107     if (buffer->strides.defined()) {
108       for (size_t i = 0; i < buffer->strides.size(); ++i) {
109         this->VisitExpr(buffer->strides[i]);
110       }
111     }
112     this->VisitExpr(buffer->elem_offset);
113 
114     match_scope_ = false;
115   }
116 
117  private:
MarkDef(const Var & var,PrimExpr value,bool allow_dup=false)118   void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) {
119     if (def_map_.count(var) != 0) {
120       if (!allow_dup) {
121         is_ssa_ = false;
122         return;
123       }
124     } else {
125       def_map_[var] = value;
126     }
127   }
128   // whether we are in match scope, where a var can occur multiple times.
129   bool match_scope_{false};
130   // deep equal
131   ExprDeepEqual deep_equal_;
132   // def map, for let, maps to the bind value, for others maps to self.
133   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> def_map_;
134 };
135 
VerifySSA(const PrimFunc & func)136 bool VerifySSA(const PrimFunc& func) {
137   SSAVerifier visitor;
138   visitor.Run(func);
139   return visitor.is_ssa_;
140 }
141 
// Register the VerifySSA check as a global function named "tir.analysis.verify_ssa".
TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA);
143 
144 namespace transform {
145 
VerifySSA()146 Pass VerifySSA() {
147   auto pass_func = [=](IRModule mod, PassContext ctx) {
148     for (auto kv : mod->functions) {
149       if (auto* n = kv.second.as<PrimFuncNode>()) {
150         auto func = GetRef<PrimFunc>(n);
151         CHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func;
152       }
153     }
154     return mod;
155   };
156   return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {});
157 }
158 
// Register the VerifySSA pass constructor as a global function named "tir.transform.VerifySSA".
TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA);
160 
161 }  // namespace transform
162 
163 }  // namespace tir
164 }  // namespace tvm
165