/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20 #include <dmlc/logging.h>
21 #include <gtest/gtest.h>
22 #include <tvm/node/functor.h>
23 #include <tvm/tir/builtin.h>
24 #include <tvm/tir/expr.h>
25 #include <tvm/tir/expr_functor.h>
26 #include <tvm/tir/op.h>
27 #include <tvm/tir/stmt_functor.h>
28
TEST(IRF, Basic) {
  using namespace tvm;
  using namespace tvm::tir;

  // A NodeFunctor dispatches on the runtime node type of its first argument:
  // a VarNode returns the extra int unchanged, an AddNode returns it plus two.
  NodeFunctor<int(const ObjectRef& n, int b)> dispatch;
  dispatch.set_dispatch<VarNode>([](const ObjectRef&, int b) { return b; });
  dispatch.set_dispatch<AddNode>([](const ObjectRef&, int b) { return b + 2; });

  Var var("x");
  PrimExpr sum = var + 1;
  CHECK_EQ(dispatch(var, 2), 2);
  CHECK_EQ(dispatch(sum, 2), 4);
}
41
TEST(IRF, CountVar) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  Var y;
  PrimExpr expr = x + 1 + y + y;

  // Count VarNodes reported by PostOrderVisit. `y` occurs twice in `expr`, yet the
  // expected total is 2, i.e. the shared `y` node is reported only once.
  int var_count = 0;
  auto count_vars = [&var_count](const ObjectRef& node) {
    if (node.as<VarNode>() != nullptr) {
      ++var_count;
    }
  };
  tir::PostOrderVisit(expr, count_vars);
  CHECK_EQ(var_count, 2);
}
54
TEST(IRF, ExprTransform) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  auto z = x + 1;

  // Evaluates an expression to an int:
  //   - a Var yields the extra argument `b`;
  //   - an IntImm yields its value;
  //   - an Add sums its evaluated operands.
  // Every other node kind is deliberately left unhandled.
  class MyExprFunctor : public tir::ExprFunctor<int(const PrimExpr&, int)> {
   public:
    int VisitExpr_(const VarNode* op, int b) final { return b; }
    int VisitExpr_(const IntImmNode* op, int b) final { return op->value; }
    int VisitExpr_(const AddNode* op, int b) final {
      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
    }
  };
  MyExprFunctor f;
  CHECK_EQ(f(x, 2), 2);
  CHECK_EQ(f(z, 2), 3);
  // `z - 1` builds a node kind (presumably Sub) that MyExprFunctor does not override,
  // so dispatching it must raise dmlc::Error.
  // ASSERT_THROW is used instead of try { ...; LOG(FATAL) } catch (dmlc::Error).
  // With dmlc's default throwing LOG(FATAL), that pattern's own LOG(FATAL) would be
  // caught by the same handler, so the test could never fail.
  ASSERT_THROW(f(z - 1, 2), dmlc::Error);
}
78
TEST(IRF, ExprVisit) {
  using namespace tvm;
  using namespace tvm::tir;

  // Hand-rolled visitor that combines an expression functor and a statement functor
  // and counts how many VarNodes it reaches.
  class VarCounter : public tir::ExprFunctor<void(const PrimExpr&)>,
                     public tir::StmtFunctor<void(const Stmt&)> {
   public:
    int num_vars = 0;

    void VisitExpr_(const VarNode*) final { num_vars += 1; }
    void VisitExpr_(const IntImmNode*) final {}
    void VisitExpr_(const AddNode* op) final {
      VisitExpr(op->a);
      VisitExpr(op->b);
    }
    void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
  };

  Var x("x");
  Stmt stmt = Evaluate(x + 1);
  VarCounter counter;
  counter.VisitStmt(stmt);
  CHECK_EQ(counter.num_vars, 1);
}
102
TEST(IRF, StmtVisitor) {
  using namespace tvm;
  using namespace tvm::tir;

  // Counts every VarNode reached while StmtExprVisitor walks a statement.
  class VarCounter : public StmtExprVisitor {
   public:
    int num_vars = 0;
    void VisitExpr_(const VarNode*) final { ++num_vars; }
  };

  Var x("x");
  // An Allocate whose two extents and whose Evaluate body each contain `x + 1`.
  auto make_allocate = [&x]() {
    PrimExpr extent = x + 1;
    Stmt inner = Evaluate(extent);
    Var buffer_var("b", DataType::Handle());
    return Allocate(buffer_var, DataType::Float(32), {extent, extent}, const_true(), inner);
  };

  VarCounter counter;
  counter(make_allocate());
  CHECK_EQ(counter.num_vars, 3);
}
123
TEST(IRF, StmtMutator) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");

  // Mutator shared by every scenario below:
  //   - it rewrites each Add node to its left operand, so `x + 1` becomes `x`;
  //   - it flattens nested SeqStmt nodes;
  //   - expression visits issued by StmtMutator are routed to ExprMutator::VisitExpr.
  class MyVisitor : public tir::StmtMutator, public tir::ExprMutator {
   public:
    using StmtMutator::operator();
    using ExprMutator::operator();

   protected:
    // implementation
    PrimExpr VisitExpr_(const AddNode* op) final { return op->a; }
    Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); }
    PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); }
  };
  // Builds Allocate(extents = {1, x + 1}, body = Evaluate(x + 1)).
  auto fmakealloc = [&]() {
    auto z = x + 1;
    Stmt body = Evaluate(z);
    Var buffer("b", DataType::Handle());
    return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
  };

  // Builds IfThenElse(x, Evaluate(0), Evaluate(x + 1)).
  auto fmakeif = [&]() {
    auto z = x + 1;
    Stmt body = Evaluate(z);
    return IfThenElse(x, Evaluate(0), body);
  };

  MyVisitor v;
  {
    auto body = fmakealloc();
    Stmt body2 = Evaluate(1);
    Stmt bref = body.as<AllocateNode>()->body;
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    Array<Stmt> arr{std::move(body), body2, body2};
    auto* arrptr = arr.get();
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    // `arr` is not shared, so the array itself is updated in place.
    CHECK(arr.get() == arrptr);
    // `body` was moved from, so `arr` holds the only reference to the Allocate.
    // It is therefore updated in place and keeps its extents array.
    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
    CHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr);
    // The Allocate's body is also referenced by `bref`, so it is copied, not mutated in
    // place; `bref` still sees the original Add.
    CHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
    CHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
    CHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
  }
  {
    Array<Stmt> arr{fmakealloc()};
    // `arr2` shares the underlying array with `arr`, so mutating `arr` must trigger a copy.
    Array<Stmt> arr2 = arr;
    auto* arrptr = arr.get();
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr.get() != arrptr);
    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
    CHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x));
    // Mutating again changes no element, so the shared array is not copied.
    arr2 = arr;
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr2.get() == arr.get());
  }
  {
    Array<Stmt> arr{fmakeif()};
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
    // Mutating again changes no element, so the shared array is not copied.
    auto arr2 = arr;
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr2.get() == arr.get());
  }

  {
    // Add nodes nested inside call arguments are rewritten as well.
    auto body =
        Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
    auto res = v(std::move(body));
    CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
  }
  {
    Stmt body = fmakealloc();
    Stmt body2 = Evaluate(1);
    auto* ref2 = body2.get();
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    // Construct a nested SeqStmt: (((alloc)), body2), body2.
    body = SeqStmt({body});
    body = SeqStmt({body, body2});
    body = SeqStmt({body, body2});
    body = v(std::move(body));
    // The nested sequences are flattened into three statements.
    // The Allocate is updated in place and keeps its extents array.
    CHECK_EQ(body.as<SeqStmtNode>()->size(), 3U);
    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr);
    CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
  }

  {
    // `bref` keeps a second reference to the inner SeqStmt, so the Allocate inside it
    // cannot be updated in place (copy-on-write) and must be copied.
    Stmt body = fmakealloc();
    Stmt body2 = Evaluate(1);
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    // Construct a nested SeqStmt.
    body = SeqStmt({body});
    auto bref = body;
    body = SeqStmt({body, body2});
    body = v(std::move(body));
    // The copied Allocate no longer shares the original extents array.
    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
  }
}
231
main(int argc,char ** argv)232 int main(int argc, char** argv) {
233 testing::InitGoogleTest(&argc, argv);
234 testing::FLAGS_gtest_death_test_style = "threadsafe";
235 return RUN_ALL_TESTS();
236 }
237