1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *  Exposure of pass functions.
22  * \file api_pass.cc
23  */
24 #include <tvm/expr.h>
25 #include <tvm/ir.h>
26 #include <tvm/attrs.h>
27 #include <tvm/ir_pass.h>
28 #include <tvm/ir_visitor.h>
29 #include <tvm/ir_mutator.h>
30 #include <tvm/api_registry.h>
31 
32 namespace tvm {
33 namespace ir {
34 
35 TVM_REGISTER_API("ir_pass.Simplify")
__anone4b6f6690102(TVMArgs args, TVMRetValue *ret) 36 .set_body([](TVMArgs args, TVMRetValue *ret) {
37     if (args[0].IsObjectRef<Stmt>()) {
38       if (args.size() > 1) {
39         *ret = Simplify(args[0].operator Stmt(), args[1]);
40       } else {
41         *ret = Simplify(args[0].operator Stmt());
42       }
43     } else {
44       if (args.size() > 1) {
45         *ret = Simplify(args[0].operator Expr(), args[1]);
46       } else {
47         *ret = Simplify(args[0].operator Expr());
48       }
49     }
50   });
51 
52 TVM_REGISTER_API("ir_pass.CanonicalSimplify")
__anone4b6f6690202(TVMArgs args, TVMRetValue *ret) 53 .set_body([](TVMArgs args, TVMRetValue *ret) {
54     if (args[0].IsObjectRef<Stmt>()) {
55       if (args.size() > 1) {
56         *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
57       } else {
58         *ret = CanonicalSimplify(args[0].operator Stmt());
59       }
60     } else {
61       if (args.size() > 1) {
62         *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
63       } else {
64         *ret = CanonicalSimplify(args[0].operator Expr());
65       }
66     }
67   });
68 
69 TVM_REGISTER_API("ir_pass.Substitute")
__anone4b6f6690302(TVMArgs args, TVMRetValue *ret) 70 .set_body([](TVMArgs args, TVMRetValue *ret) {
71     if (args[0].IsObjectRef<Stmt>()) {
72       *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
73     } else {
74       *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
75     }
76   });
77 
78 TVM_REGISTER_API("ir_pass.Equal")
__anone4b6f6690402(TVMArgs args, TVMRetValue *ret) 79 .set_body([](TVMArgs args, TVMRetValue *ret) {
80     if (args[0].IsObjectRef<Stmt>()) {
81       *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
82     } else {
83       *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
84     }
85   });
86 
87 TVM_REGISTER_API("ir_pass.StorageFlatten")
__anone4b6f6690502(TVMArgs args, TVMRetValue *ret) 88 .set_body([](TVMArgs args, TVMRetValue *ret) {
89     if (args.size() <= 3) {
90       *ret = StorageFlatten(args[0], args[1], args[2]);
91     } else {
92       *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
93     }
94   });
95 
96 TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
97 .set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
__anone4b6f6690602(const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) 98   ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
99       return RewriteForTensorCore(stmt, schedule, extern_buffer);
100   });
101 
102 TVM_REGISTER_API("ir_pass.AttrsEqual")
__anone4b6f6690702(const NodeRef& lhs, const NodeRef& rhs) 103 .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
104     return AttrsEqual()(lhs, rhs);
105   });
106 
107 TVM_REGISTER_API("ir_pass.AttrsHash")
__anone4b6f6690802(const NodeRef &node) 108 .set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
109     return AttrsHash()(node);
110   });
111 
112 
113 TVM_REGISTER_API("ir_pass.ExprUseVar")
__anone4b6f6690902(TVMArgs args, TVMRetValue *ret) 114 .set_body([](TVMArgs args, TVMRetValue *ret) {
115     *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
116   });
117 
118 TVM_REGISTER_API("ir_pass.PostOrderVisit")
__anone4b6f6690a02(TVMArgs args, TVMRetValue *ret) 119 .set_body([](TVMArgs args, TVMRetValue *ret) {
120     PackedFunc f = args[1];
121     ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
122         f(n);
123       });
124   });
125 
126 TVM_REGISTER_API("ir_pass.LowerStorageAccess")
__anone4b6f6690c02(TVMArgs args, TVMRetValue *ret) 127 .set_body([](TVMArgs args, TVMRetValue *ret) {
128   LoweredFunc f = args[0];
129   auto n = make_node<LoweredFuncNode>(*f.operator->());
130   n->body = LowerStorageAccessInfo(f->body);
131   *ret = LoweredFunc(n);
132 });
133 
/*!
 * \brief Register the function named PassName under the global name
 *        "ir_pass.<PassName>" by handing it to set_body_typed.
 *
 * \note The expansion already ends with ';'. A further ';' written at the
 *       use site is therefore just an empty declaration. The final macro
 *       line deliberately has no trailing line-continuation backslash, so
 *       the definition cannot absorb whatever follows it.
 */
#define REGISTER_PASS(PassName)                                   \
  TVM_REGISTER_API("ir_pass."#PassName)                           \
  .set_body_typed(PassName);
139 
140 REGISTER_PASS(ConvertSSA);
141 REGISTER_PASS(VerifySSA);
142 REGISTER_PASS(RewriteUnsafeSelect);
143 REGISTER_PASS(Inline);
144 REGISTER_PASS(IRTransform);
145 REGISTER_PASS(VectorizeLoop);
146 REGISTER_PASS(SkipVectorize);
147 REGISTER_PASS(UnrollLoop);
148 REGISTER_PASS(InjectCopyIntrin);
149 REGISTER_PASS(ThreadSync);
150 REGISTER_PASS(MakeAPI);
151 REGISTER_PASS(BindDeviceType);
152 REGISTER_PASS(SplitHostDevice);
153 REGISTER_PASS(StorageRewrite);
154 REGISTER_PASS(CoProcSync);
155 REGISTER_PASS(LowerStorageAccessInfo);
156 REGISTER_PASS(LowerDeviceStorageAccessInfo)
157 REGISTER_PASS(InjectVirtualThread);
158 REGISTER_PASS(InjectPrefetch);
159 REGISTER_PASS(InjectDoubleBuffer);
160 REGISTER_PASS(LoopPartition);
161 REGISTER_PASS(RemoveNoOp);
162 REGISTER_PASS(SplitPipeline);
163 REGISTER_PASS(LiftAttrScope);
164 REGISTER_PASS(NarrowChannelAccess);
165 REGISTER_PASS(LowerThreadAllreduce);
166 REGISTER_PASS(LowerWarpMemory);
167 REGISTER_PASS(RemapThreadAxis);
168 REGISTER_PASS(LowerIntrin);
169 REGISTER_PASS(LowerCustomDatatypes);
170 REGISTER_PASS(LowerTVMBuiltin);
171 REGISTER_PASS(CombineContextCall);
172 REGISTER_PASS(VerifyMemory);
173 REGISTER_PASS(VerifyGPUCode);
174 REGISTER_PASS(DecorateDeviceScope);
175 REGISTER_PASS(InstrumentBoundCheckers);
176 REGISTER_PASS(VerifyCompactBuffer);
177 REGISTER_PASS(HoistIfThenElse);
178 REGISTER_PASS(InferFragment)
179 }  // namespace ir
180 }  // namespace tvm
181