1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <dmlc/logging.h>
21 #include <gtest/gtest.h>
22 #include <tvm/node/functor.h>
23 #include <tvm/tir/builtin.h>
24 #include <tvm/tir/expr.h>
25 #include <tvm/tir/expr_functor.h>
26 #include <tvm/tir/op.h>
27 #include <tvm/tir/stmt_functor.h>
28 
TEST(IRF, Basic) {
  using namespace tvm;
  using namespace tvm::tir;

  // Node-type dispatch table: a Var yields the extra argument unchanged,
  // an Add yields the extra argument plus two.
  NodeFunctor<int(const ObjectRef&, int)> dispatch;
  dispatch.set_dispatch<VarNode>([](const ObjectRef&, int value) { return value; });
  dispatch.set_dispatch<AddNode>([](const ObjectRef&, int value) { return value + 2; });

  Var var_x("x");
  auto sum = var_x + 1;
  CHECK_EQ(dispatch(var_x, 2), 2);
  CHECK_EQ(dispatch(sum, 2), 4);
}
41 
TEST(IRF, CountVar) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  Var y;

  // y is referenced twice in the expression, yet the expected total is 2:
  // each distinct Var node is reported once by the post-order walk.
  const auto expr = x + 1 + y + y;
  int num_vars = 0;
  auto count_vars = [&num_vars](const ObjectRef& node) {
    if (node.as<VarNode>() != nullptr) {
      num_vars += 1;
    }
  };
  tir::PostOrderVisit(expr, count_vars);
  CHECK_EQ(num_vars, 2);
}
54 
TEST(IRF, ExprTransform) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  auto z = x + 1;

  // Folds an expression into an int: a Var evaluates to the extra argument `b`,
  // an IntImm to its value, and an Add to the sum of its operands. No other
  // node kind has a handler.
  class MyExprFunctor : public tir::ExprFunctor<int(const PrimExpr&, int)> {
   public:
    int VisitExpr_(const VarNode* op, int b) final { return b; }
    int VisitExpr_(const IntImmNode* op, int b) final { return op->value; }
    int VisitExpr_(const AddNode* op, int b) final {
      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
    }
  };
  MyExprFunctor f;
  CHECK_EQ(f(x, 2), 2);
  CHECK_EQ(f(z, 2), 3);
  // `z - 1` presumably builds a Sub node, which MyExprFunctor does not handle,
  // so the call is expected to throw dmlc::Error.
  // The former `try { ...; LOG(FATAL) << "should fail"; } catch (dmlc::Error) {}`
  // could never fail: LOG(FATAL) itself throws dmlc::Error, which that same
  // catch clause swallowed. ASSERT_THROW reports a failure if nothing is thrown.
  ASSERT_THROW(f(z - 1, 2), dmlc::Error);
}
78 
TEST(IRF, ExprVisit) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");

  // A single object acting as both an expression functor and a statement
  // functor; it tallies how many Var nodes it reaches.
  class VarCounter : public tir::ExprFunctor<void(const PrimExpr&)>,
                     public tir::StmtFunctor<void(const Stmt&)> {
   public:
    int count = 0;
    // Statement entry point: descend into the evaluated expression.
    void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
    // Addition: walk the left operand, then the right operand.
    void VisitExpr_(const AddNode* op) final {
      VisitExpr(op->a);
      VisitExpr(op->b);
    }
    // Integer constants contribute nothing.
    void VisitExpr_(const IntImmNode* op) final {}
    void VisitExpr_(const VarNode* op) final { ++count; }
  };

  const Stmt stmt = Evaluate(x + 1);
  VarCounter counter;
  counter.VisitStmt(stmt);
  CHECK_EQ(counter.count, 1);
}
102 
TEST(IRF, StmtVisitor) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");

  // Counts every VarNode the combined statement/expression visitor reaches.
  class VarCounter : public StmtExprVisitor {
   public:
    int count = 0;
    void VisitExpr_(const VarNode* op) final { ++count; }
  };

  // Builds Allocate(b, float32, extents = {x + 1, x + 1}, condition = true,
  // body = Evaluate(x + 1)). That gives three uses of `x`: two extents and the body.
  auto make_alloc = [&]() {
    PrimExpr extent = x + 1;
    Stmt evaluate = Evaluate(extent);
    Var buffer_var("b", DataType::Handle());
    return Allocate(buffer_var, DataType::Float(32), {extent, extent}, const_true(), evaluate);
  };

  VarCounter counter;
  Stmt stmt = make_alloc();
  counter(stmt);
  CHECK_EQ(counter.count, 3);
}
123 
// Exercises a combined StmtMutator/ExprMutator together with
// Array::MutateByApply. It checks when nodes are updated in place (storage
// pointer unchanged) and when they are copied because another reference is alive.
TEST(IRF, StmtMutator) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");

  // Mutator that rewrites every `a + b` to `a` (so `x + 1` becomes `x`) and
  // handles SeqStmt via VisitSeqStmt_(op, true).
  class MyVisitor : public tir::StmtMutator, public tir::ExprMutator {
   public:
    // Expose both call operators so the object accepts either a Stmt or a PrimExpr.
    using StmtMutator::operator();
    using ExprMutator::operator();

   protected:
    // implementation
    // Keep only the left operand of an addition.
    PrimExpr VisitExpr_(const AddNode* op) final { return op->a; }
    // The `true` flag presumably requests flattening of nested SeqStmt; the
    // recursive-SeqStmt case below asserts that the result is flat.
    Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); }
    // Route expressions reached from statements through the ExprMutator.
    PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); }
  };
  // Builds Allocate(b, float32, extents = {1, x + 1}, condition = true, body = Evaluate(x + 1)).
  auto fmakealloc = [&]() {
    auto z = x + 1;
    Stmt body = Evaluate(z);
    Var buffer("b", DataType::Handle());
    return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
  };

  // Builds IfThenElse(x, Evaluate(0), Evaluate(x + 1)).
  auto fmakeif = [&]() {
    auto z = x + 1;
    Stmt body = Evaluate(z);
    return IfThenElse(x, Evaluate(0), body);
  };

  MyVisitor v;
  {
    auto body = fmakealloc();
    Stmt body2 = Evaluate(1);
    // Extra reference to the Allocate's body; it forces that body to be copied.
    Stmt bref = body.as<AllocateNode>()->body;
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    // `body` is moved in, so arr[0] is the only handle to the Allocate node.
    Array<Stmt> arr{std::move(body), body2, body2};
    auto* arrptr = arr.get();
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    // arr was not shared, so its storage is reused.
    CHECK(arr.get() == arrptr);
    // inplace update body
    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
    CHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr);
    // copy because there is additional refs
    CHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
    CHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
    // The copy left the original body, seen through bref, untouched (still an Add).
    CHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
  }
  {
    Array<Stmt> arr{fmakealloc()};
    // Mutating an array that another Array also references triggers a copy.
    Array<Stmt> arr2 = arr;
    auto* arrptr = arr.get();
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr.get() != arrptr);
    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
    // arr2 still sees the unmutated original.
    CHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x));
    // mutate but no content change.
    // The Add has already been rewritten, so a second pass changes nothing and
    // arr keeps the storage it shares with arr2.
    arr2 = arr;
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr2.get() == arr.get());
  }
  {
    Array<Stmt> arr{fmakeif()};
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    // The else branch's `x + 1` is rewritten to `x`.
    CHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
    // mutate but no content change.
    auto arr2 = arr;
    arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
    CHECK(arr2.get() == arr.get());
  }

  {
    // Call arguments are mutated too: `x + 1` inside call_extern becomes `x`.
    auto body =
        Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
    auto res = v(std::move(body));
    CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
  }
  {
    Stmt body = fmakealloc();
    Stmt body2 = Evaluate(1);
    auto* ref2 = body2.get();
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    // construct a recursive SeqStmt.
    body = SeqStmt({body});
    body = SeqStmt({body, body2});
    body = SeqStmt({body, body2});
    body = v(std::move(body));
    // The nested SeqStmt gets flattened into a single level:
    // [Allocate, body2, body2].
    CHECK(body.as<SeqStmtNode>()->size() == 3);
    // No other handle to the Allocate remains, so its extents are updated in place.
    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr);
    // Unchanged statements keep their identity.
    CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
  }

  {
    // Cannot copy-on-write in place because bref holds another reference.
    Stmt body = fmakealloc();
    Stmt body2 = Evaluate(1);
    auto* extentptr = body.as<AllocateNode>()->extents.get();
    // construct a recursive SeqStmt.
    body = SeqStmt({body});
    auto bref = body;
    body = SeqStmt({body, body2});
    body = v(std::move(body));
    // The seq gets flattened; because bref still shares the inner SeqStmt,
    // the Allocate's extents are copied instead of reused.
    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
  }
}
231 
main(int argc,char ** argv)232 int main(int argc, char** argv) {
233   testing::InitGoogleTest(&argc, argv);
234   testing::FLAGS_gtest_death_test_style = "threadsafe";
235   return RUN_ALL_TESTS();
236 }
237