1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \brief Registration of TVM schedules
22 * \file schedule.cc
23 */
24
25 #include <tvm/ir/expr.h>
26 #include <tvm/runtime/module.h>
27 #include <tvm/runtime/packed_func.h>
28 #include <tvm/runtime/registry.h>
29 #include <tvm/target/generic_func.h>
30 #include <tvm/topi/cuda/dense.h>
31 #include <tvm/topi/cuda/injective.h>
32 #include <tvm/topi/cuda/normalization.h>
33 #include <tvm/topi/cuda/pooling.h>
34 #include <tvm/topi/cuda/reduction.h>
35 #include <tvm/topi/cuda/softmax.h>
36 #include <tvm/topi/detail/tensor_utils.h>
37 #include <tvm/topi/generic/default.h>
38 #include <tvm/topi/generic/extern.h>
39 #include <tvm/topi/generic/injective.h>
40 #include <tvm/topi/rocm/dense.h>
41 #include <tvm/topi/rocm/injective.h>
42 #include <tvm/topi/rocm/normalization.h>
43 #include <tvm/topi/rocm/pooling.h>
44 #include <tvm/topi/rocm/reduction.h>
45 #include <tvm/topi/rocm/softmax.h>
46 #include <tvm/topi/x86/bnn.h>
47 #include <tvm/topi/x86/default.h>
48 #include <tvm/topi/x86/injective.h>
49
50 namespace tvm {
51 namespace topi {
52
53 using namespace tvm;
54 using namespace tvm::runtime;
55
__anonbd8a05830102(TVMArgs args, TVMRetValue* rv) 56 TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) {
57 *rv = tvm::Target(args[0].operator String());
58 });
59
60 /* Generic schedules */
__anonbd8a05830202(TVMArgs args, TVMRetValue* rv) 61 TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
62 if (args[2]) {
63 *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
64 } else {
65 *rv = topi::generic::default_schedule(args[0], args[1]);
66 }
67 });
68
__anonbd8a05830302(TVMArgs args, TVMRetValue* rv) 69 TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) {
70 *rv = topi::generic::schedule_extern(args[0], args[1]);
71 });
72
__anonbd8a05830402(TVMArgs args, TVMRetValue* rv) 73 TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
74 *rv = topi::generic::schedule_injective(args[0], args[1]);
75 });
76
77 TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
__anonbd8a05830502(TVMArgs args, TVMRetValue* rv) 78 .set_body([](TVMArgs args, TVMRetValue* rv) {
79 *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
80 });
81
82 /* x86 schedules */
__anonbd8a05830602(TVMArgs args, TVMRetValue* rv) 83 TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) {
84 *rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
85 });
86
__anonbd8a05830702(TVMArgs args, TVMRetValue* rv) 87 TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
88 *rv = topi::x86::schedule_binary_dense(args[0], args[1]);
89 });
90
__anonbd8a05830802(TVMArgs args, TVMRetValue* rv) 91 TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
92 if (args[2]) {
93 *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
94 } else {
95 *rv = topi::x86::default_schedule(args[0], args[1]);
96 }
97 });
98
__anonbd8a05830902(TVMArgs args, TVMRetValue* rv) 99 TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
100 *rv = topi::x86::schedule_injective(args[0], args[1]);
101 });
102
103 TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
__anonbd8a05830a02(TVMArgs args, TVMRetValue* rv) 104 .set_body([](TVMArgs args, TVMRetValue* rv) {
105 *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
106 });
107
108 /* ROCm schedules */
__anonbd8a05830b02(TVMArgs args, TVMRetValue* rv) 109 TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
110 *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
111 });
112
__anonbd8a05830c02(TVMArgs args, TVMRetValue* rv) 113 TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
114 *rv = topi::rocm::schedule_dense(args[0], args[1]);
115 });
116
__anonbd8a05830d02(TVMArgs args, TVMRetValue* rv) 117 TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
118 *rv = topi::rocm::schedule_injective(args[0], args[1]);
119 });
120
121 TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
__anonbd8a05830e02(TVMArgs args, TVMRetValue* rv) 122 .set_body([](TVMArgs args, TVMRetValue* rv) {
123 *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
124 });
125
__anonbd8a05830f02(TVMArgs args, TVMRetValue* rv) 126 TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
127 *rv = topi::rocm::schedule_pool(args[0], args[1]);
128 });
129
__anonbd8a05831002(TVMArgs args, TVMRetValue* rv) 130 TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
131 *rv = topi::rocm::schedule_global_pool(args[0], args[1]);
132 });
133
__anonbd8a05831102(TVMArgs args, TVMRetValue* rv) 134 TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
135 *rv = topi::rocm::schedule_reduce(args[0], args[1]);
136 });
137
__anonbd8a05831202(TVMArgs args, TVMRetValue* rv) 138 TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
139 *rv = topi::rocm::schedule_softmax(args[0], args[1]);
140 });
141
__anonbd8a05831302(TVMArgs args, TVMRetValue* rv) 142 TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
143 *rv = topi::rocm::schedule_lrn(args[0]);
144 });
145
146 /* CUDA schedules */
__anonbd8a05831402(TVMArgs args, TVMRetValue* rv) 147 TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
148 *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
149 });
150
__anonbd8a05831502(TVMArgs args, TVMRetValue* rv) 151 TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
152 *rv = topi::cuda::schedule_dense(args[0], args[1]);
153 });
154
__anonbd8a05831602(TVMArgs args, TVMRetValue* rv) 155 TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
156 *rv = topi::cuda::schedule_injective(args[0], args[1]);
157 });
158
159 TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
__anonbd8a05831702(TVMArgs args, TVMRetValue* rv) 160 .set_body([](TVMArgs args, TVMRetValue* rv) {
161 *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
162 });
163
__anonbd8a05831802(TVMArgs args, TVMRetValue* rv) 164 TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
165 *rv = topi::cuda::schedule_pool(args[0], args[1]);
166 });
167
__anonbd8a05831902(TVMArgs args, TVMRetValue* rv) 168 TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
169 *rv = topi::cuda::schedule_global_pool(args[0], args[1]);
170 });
171
__anonbd8a05831a02(TVMArgs args, TVMRetValue* rv) 172 TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
173 *rv = topi::cuda::schedule_reduce(args[0], args[1]);
174 });
175
__anonbd8a05831b02(TVMArgs args, TVMRetValue* rv) 176 TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
177 *rv = topi::cuda::schedule_softmax(args[0], args[1]);
178 });
179
__anonbd8a05831c02(TVMArgs args, TVMRetValue* rv) 180 TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
181 *rv = topi::cuda::schedule_lrn(args[0]);
182 });
183
184 /* Utility functions */
__anonbd8a05831d02(TVMArgs args, TVMRetValue* rv) 185 TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) {
186 *rv = topi::detail::is_empty_shape(args[0]);
187 });
188
__anonbd8a05831e02(TVMArgs args, TVMRetValue* rv) 189 TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) {
190 *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
191 });
192
/*!
 * \brief Builder function for instantiating schedules.
 * Takes the compilation target and the output tensors of the computation
 * being scheduled, and returns the constructed schedule.
 */
using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
    const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
196
197 /*!
198 * \brief Helper function for registering generic functions matching the
199 * FTVMScheduleBuilder signature. The schedule builder function is wrapped
200 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
201 *
202 * \param builder The schedule builder to wrap.
203 *
204 * \return The wrapped schedule builder
205 */
WrapSchedule(FTVMScheduleBuilder builder)206 inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
207 return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
208 auto target = Target::Current(false);
209 Array<Tensor> outs;
210 ObjectRef argNodeRef = args[0];
211 if (argNodeRef->type_index() == outs->type_index()) {
212 outs = args[0];
213 } else {
214 outs = Array<Tensor>{args[0]};
215 }
216
217 *ret = builder(target, outs);
218 });
219 }
220
// Target-dispatched schedule builders: each GenericFunc selects an
// implementation from the current target's keys ("cpu", "cuda"/"gpu",
// "rocm"), falling back to the registered default when no key matches.
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
    .set_default(WrapSchedule(topi::generic::schedule_injective))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective));

TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax));

TVM_REGISTER_GENERIC_FUNC(schedule_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense))
    .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense));

// No target-specific variants registered; every target uses the default.
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
    .set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool));

// Reductions use the auto-inline default variants on CPU targets.
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
    .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce));

TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack));

TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense));
261
262 /*! \brief Builder function for instantiating schedules from existing schedules. */
263 using FTVMScheduleFromExistingBuilder =
264 std::function<tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
265
266 /*!
267 * \brief Helper function for registering generic functions matching the
268 * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
269 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
270 *
271 * \param builder The schedule builder to wrap.
272 *
273 * \return The wrapped schedule builder
274 */
WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder)275 inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
276 return PackedFunc(
277 [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); });
278 }
279
// Target-dispatched builder that extends an already-constructed schedule
// with an injective stage; dispatch follows the current target's keys.
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
    .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
    .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
    .register_func({"cuda", "gpu"},
                   WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
285
286 /*! \brief Builder function for instantiating dense ops. */
287 using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(
288 const Target& target, const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
289 const tvm::te::Tensor& bias, const DataType& out_dtype)>;
290
291 /*!
292 * \brief Helper function for registering dense ops matching the
293 * FTVMDenseOpBuilder signature. The op builder function is wrapped
294 * with a PackedFunc suitable for passing to a tvm::GenericFunc.
295 *
296 * \param builder The op builder to wrap.
297 *
298 * \return The wrapped op builder
299 */
WrapDenseOp(FTVMDenseOpBuilder builder)300 inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
301 return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
302 auto target = Target::Current(false);
303 Tensor data = args[0];
304 Tensor weight = args[1];
305 Tensor bias = args[2];
306 DataType out_dtype = args[3];
307
308 *ret = builder(target, data, weight, bias, out_dtype);
309 });
310 }
311
// Target-dispatched dense op: the default forwards to the reference
// topi::nn::dense implementation; CUDA/GPU and ROCm targets get their
// specialized variants.
TVM_REGISTER_GENERIC_FUNC(dense)
    .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data,
                                const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                const DataType& out_dtype) {
      return topi::nn::dense(data, weight, bias, out_dtype);
    }))
    .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda))
    .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm));
320
321 } // namespace topi
322 } // namespace tvm
323