1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
 * Exposure of pass functions.
22 * \file api_pass.cc
23 */
24 #include <tvm/expr.h>
25 #include <tvm/ir.h>
26 #include <tvm/attrs.h>
27 #include <tvm/ir_pass.h>
28 #include <tvm/ir_visitor.h>
29 #include <tvm/ir_mutator.h>
30 #include <tvm/api_registry.h>
31
32 namespace tvm {
33 namespace ir {
34
35 TVM_REGISTER_API("ir_pass.Simplify")
__anone4b6f6690102(TVMArgs args, TVMRetValue *ret) 36 .set_body([](TVMArgs args, TVMRetValue *ret) {
37 if (args[0].IsObjectRef<Stmt>()) {
38 if (args.size() > 1) {
39 *ret = Simplify(args[0].operator Stmt(), args[1]);
40 } else {
41 *ret = Simplify(args[0].operator Stmt());
42 }
43 } else {
44 if (args.size() > 1) {
45 *ret = Simplify(args[0].operator Expr(), args[1]);
46 } else {
47 *ret = Simplify(args[0].operator Expr());
48 }
49 }
50 });
51
52 TVM_REGISTER_API("ir_pass.CanonicalSimplify")
__anone4b6f6690202(TVMArgs args, TVMRetValue *ret) 53 .set_body([](TVMArgs args, TVMRetValue *ret) {
54 if (args[0].IsObjectRef<Stmt>()) {
55 if (args.size() > 1) {
56 *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
57 } else {
58 *ret = CanonicalSimplify(args[0].operator Stmt());
59 }
60 } else {
61 if (args.size() > 1) {
62 *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
63 } else {
64 *ret = CanonicalSimplify(args[0].operator Expr());
65 }
66 }
67 });
68
69 TVM_REGISTER_API("ir_pass.Substitute")
__anone4b6f6690302(TVMArgs args, TVMRetValue *ret) 70 .set_body([](TVMArgs args, TVMRetValue *ret) {
71 if (args[0].IsObjectRef<Stmt>()) {
72 *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
73 } else {
74 *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
75 }
76 });
77
78 TVM_REGISTER_API("ir_pass.Equal")
__anone4b6f6690402(TVMArgs args, TVMRetValue *ret) 79 .set_body([](TVMArgs args, TVMRetValue *ret) {
80 if (args[0].IsObjectRef<Stmt>()) {
81 *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
82 } else {
83 *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
84 }
85 });
86
87 TVM_REGISTER_API("ir_pass.StorageFlatten")
__anone4b6f6690502(TVMArgs args, TVMRetValue *ret) 88 .set_body([](TVMArgs args, TVMRetValue *ret) {
89 if (args.size() <= 3) {
90 *ret = StorageFlatten(args[0], args[1], args[2]);
91 } else {
92 *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
93 }
94 });
95
96 TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
97 .set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
__anone4b6f6690602(const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) 98 ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
99 return RewriteForTensorCore(stmt, schedule, extern_buffer);
100 });
101
102 TVM_REGISTER_API("ir_pass.AttrsEqual")
__anone4b6f6690702(const NodeRef& lhs, const NodeRef& rhs) 103 .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
104 return AttrsEqual()(lhs, rhs);
105 });
106
107 TVM_REGISTER_API("ir_pass.AttrsHash")
__anone4b6f6690802(const NodeRef &node) 108 .set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
109 return AttrsHash()(node);
110 });
111
112
113 TVM_REGISTER_API("ir_pass.ExprUseVar")
__anone4b6f6690902(TVMArgs args, TVMRetValue *ret) 114 .set_body([](TVMArgs args, TVMRetValue *ret) {
115 *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
116 });
117
118 TVM_REGISTER_API("ir_pass.PostOrderVisit")
__anone4b6f6690a02(TVMArgs args, TVMRetValue *ret) 119 .set_body([](TVMArgs args, TVMRetValue *ret) {
120 PackedFunc f = args[1];
121 ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
122 f(n);
123 });
124 });
125
126 TVM_REGISTER_API("ir_pass.LowerStorageAccess")
__anone4b6f6690c02(TVMArgs args, TVMRetValue *ret) 127 .set_body([](TVMArgs args, TVMRetValue *ret) {
128 LoweredFunc f = args[0];
129 auto n = make_node<LoweredFuncNode>(*f.operator->());
130 n->body = LowerStorageAccessInfo(f->body);
131 *ret = LoweredFunc(n);
132 });
133
134 // make from two arguments
135 #define REGISTER_PASS(PassName) \
136 TVM_REGISTER_API("ir_pass."#PassName) \
137 .set_body_typed(PassName); \
138
139
140 REGISTER_PASS(ConvertSSA);
141 REGISTER_PASS(VerifySSA);
142 REGISTER_PASS(RewriteUnsafeSelect);
143 REGISTER_PASS(Inline);
144 REGISTER_PASS(IRTransform);
145 REGISTER_PASS(VectorizeLoop);
146 REGISTER_PASS(SkipVectorize);
147 REGISTER_PASS(UnrollLoop);
148 REGISTER_PASS(InjectCopyIntrin);
149 REGISTER_PASS(ThreadSync);
150 REGISTER_PASS(MakeAPI);
151 REGISTER_PASS(BindDeviceType);
152 REGISTER_PASS(SplitHostDevice);
153 REGISTER_PASS(StorageRewrite);
154 REGISTER_PASS(CoProcSync);
155 REGISTER_PASS(LowerStorageAccessInfo);
156 REGISTER_PASS(LowerDeviceStorageAccessInfo)
157 REGISTER_PASS(InjectVirtualThread);
158 REGISTER_PASS(InjectPrefetch);
159 REGISTER_PASS(InjectDoubleBuffer);
160 REGISTER_PASS(LoopPartition);
161 REGISTER_PASS(RemoveNoOp);
162 REGISTER_PASS(SplitPipeline);
163 REGISTER_PASS(LiftAttrScope);
164 REGISTER_PASS(NarrowChannelAccess);
165 REGISTER_PASS(LowerThreadAllreduce);
166 REGISTER_PASS(LowerWarpMemory);
167 REGISTER_PASS(RemapThreadAxis);
168 REGISTER_PASS(LowerIntrin);
169 REGISTER_PASS(LowerCustomDatatypes);
170 REGISTER_PASS(LowerTVMBuiltin);
171 REGISTER_PASS(CombineContextCall);
172 REGISTER_PASS(VerifyMemory);
173 REGISTER_PASS(VerifyGPUCode);
174 REGISTER_PASS(DecorateDeviceScope);
175 REGISTER_PASS(InstrumentBoundCheckers);
176 REGISTER_PASS(VerifyCompactBuffer);
177 REGISTER_PASS(HoistIfThenElse);
178 REGISTER_PASS(InferFragment)
179 } // namespace ir
180 } // namespace tvm
181