/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Registration of TVM schedules
 * \file schedule.cc
 */
24 
25 #include <tvm/ir/expr.h>
26 #include <tvm/runtime/module.h>
27 #include <tvm/runtime/packed_func.h>
28 #include <tvm/runtime/registry.h>
29 #include <tvm/target/generic_func.h>
30 #include <tvm/topi/cuda/dense.h>
31 #include <tvm/topi/cuda/injective.h>
32 #include <tvm/topi/cuda/normalization.h>
33 #include <tvm/topi/cuda/pooling.h>
34 #include <tvm/topi/cuda/reduction.h>
35 #include <tvm/topi/cuda/softmax.h>
36 #include <tvm/topi/detail/tensor_utils.h>
37 #include <tvm/topi/generic/default.h>
38 #include <tvm/topi/generic/extern.h>
39 #include <tvm/topi/generic/injective.h>
40 #include <tvm/topi/rocm/dense.h>
41 #include <tvm/topi/rocm/injective.h>
42 #include <tvm/topi/rocm/normalization.h>
43 #include <tvm/topi/rocm/pooling.h>
44 #include <tvm/topi/rocm/reduction.h>
45 #include <tvm/topi/rocm/softmax.h>
46 #include <tvm/topi/x86/bnn.h>
47 #include <tvm/topi/x86/default.h>
48 #include <tvm/topi/x86/injective.h>
49 
50 namespace tvm {
51 namespace topi {
52 
53 using namespace tvm;
54 using namespace tvm::runtime;
55 
__anonbd8a05830102(TVMArgs args, TVMRetValue* rv) 56 TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) {
57   *rv = tvm::Target(args[0].operator String());
58 });
59 
60 /* Generic schedules */
__anonbd8a05830202(TVMArgs args, TVMRetValue* rv) 61 TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
62   if (args[2]) {
63     *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
64   } else {
65     *rv = topi::generic::default_schedule(args[0], args[1]);
66   }
67 });
68 
__anonbd8a05830302(TVMArgs args, TVMRetValue* rv) 69 TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) {
70   *rv = topi::generic::schedule_extern(args[0], args[1]);
71 });
72 
__anonbd8a05830402(TVMArgs args, TVMRetValue* rv) 73 TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
74   *rv = topi::generic::schedule_injective(args[0], args[1]);
75 });
76 
77 TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
__anonbd8a05830502(TVMArgs args, TVMRetValue* rv) 78     .set_body([](TVMArgs args, TVMRetValue* rv) {
79       *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
80     });
81 
82 /* x86 schedules */
__anonbd8a05830602(TVMArgs args, TVMRetValue* rv) 83 TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) {
84   *rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
85 });
86 
__anonbd8a05830702(TVMArgs args, TVMRetValue* rv) 87 TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
88   *rv = topi::x86::schedule_binary_dense(args[0], args[1]);
89 });
90 
__anonbd8a05830802(TVMArgs args, TVMRetValue* rv) 91 TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) {
92   if (args[2]) {
93     *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
94   } else {
95     *rv = topi::x86::default_schedule(args[0], args[1]);
96   }
97 });
98 
__anonbd8a05830902(TVMArgs args, TVMRetValue* rv) 99 TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
100   *rv = topi::x86::schedule_injective(args[0], args[1]);
101 });
102 
103 TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
__anonbd8a05830a02(TVMArgs args, TVMRetValue* rv) 104     .set_body([](TVMArgs args, TVMRetValue* rv) {
105       *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
106     });
107 
108 /* ROCm schedules */
__anonbd8a05830b02(TVMArgs args, TVMRetValue* rv) 109 TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
110   *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
111 });
112 
__anonbd8a05830c02(TVMArgs args, TVMRetValue* rv) 113 TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
114   *rv = topi::rocm::schedule_dense(args[0], args[1]);
115 });
116 
__anonbd8a05830d02(TVMArgs args, TVMRetValue* rv) 117 TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
118   *rv = topi::rocm::schedule_injective(args[0], args[1]);
119 });
120 
121 TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
__anonbd8a05830e02(TVMArgs args, TVMRetValue* rv) 122     .set_body([](TVMArgs args, TVMRetValue* rv) {
123       *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
124     });
125 
__anonbd8a05830f02(TVMArgs args, TVMRetValue* rv) 126 TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
127   *rv = topi::rocm::schedule_pool(args[0], args[1]);
128 });
129 
__anonbd8a05831002(TVMArgs args, TVMRetValue* rv) 130 TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
131   *rv = topi::rocm::schedule_global_pool(args[0], args[1]);
132 });
133 
__anonbd8a05831102(TVMArgs args, TVMRetValue* rv) 134 TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
135   *rv = topi::rocm::schedule_reduce(args[0], args[1]);
136 });
137 
__anonbd8a05831202(TVMArgs args, TVMRetValue* rv) 138 TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
139   *rv = topi::rocm::schedule_softmax(args[0], args[1]);
140 });
141 
__anonbd8a05831302(TVMArgs args, TVMRetValue* rv) 142 TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
143   *rv = topi::rocm::schedule_lrn(args[0]);
144 });
145 
146 /* CUDA schedules */
__anonbd8a05831402(TVMArgs args, TVMRetValue* rv) 147 TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
148   *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
149 });
150 
__anonbd8a05831502(TVMArgs args, TVMRetValue* rv) 151 TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) {
152   *rv = topi::cuda::schedule_dense(args[0], args[1]);
153 });
154 
__anonbd8a05831602(TVMArgs args, TVMRetValue* rv) 155 TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) {
156   *rv = topi::cuda::schedule_injective(args[0], args[1]);
157 });
158 
159 TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
__anonbd8a05831702(TVMArgs args, TVMRetValue* rv) 160     .set_body([](TVMArgs args, TVMRetValue* rv) {
161       *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
162     });
163 
__anonbd8a05831802(TVMArgs args, TVMRetValue* rv) 164 TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
165   *rv = topi::cuda::schedule_pool(args[0], args[1]);
166 });
167 
__anonbd8a05831902(TVMArgs args, TVMRetValue* rv) 168 TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) {
169   *rv = topi::cuda::schedule_global_pool(args[0], args[1]);
170 });
171 
__anonbd8a05831a02(TVMArgs args, TVMRetValue* rv) 172 TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) {
173   *rv = topi::cuda::schedule_reduce(args[0], args[1]);
174 });
175 
__anonbd8a05831b02(TVMArgs args, TVMRetValue* rv) 176 TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) {
177   *rv = topi::cuda::schedule_softmax(args[0], args[1]);
178 });
179 
__anonbd8a05831c02(TVMArgs args, TVMRetValue* rv) 180 TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) {
181   *rv = topi::cuda::schedule_lrn(args[0]);
182 });
183 
184 /* Utility functions */
__anonbd8a05831d02(TVMArgs args, TVMRetValue* rv) 185 TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) {
186   *rv = topi::detail::is_empty_shape(args[0]);
187 });
188 
__anonbd8a05831e02(TVMArgs args, TVMRetValue* rv) 189 TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) {
190   *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
191 });
192 
193 /*! \brief Builder function for instantiating schedules. */
194 using FTVMScheduleBuilder = std::function<tvm::te::Schedule(
195     const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
196 
197 /*!
198  * \brief Helper function for registering generic functions matching the
199  * FTVMScheduleBuilder signature. The schedule builder function is wrapped
200  * with a PackedFunc suitable for passing to a tvm::GenericFunc.
201  *
202  * \param builder The schedule builder to wrap.
203  *
204  * \return The wrapped schedule builder
205  */
WrapSchedule(FTVMScheduleBuilder builder)206 inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
207   return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
208     auto target = Target::Current(false);
209     Array<Tensor> outs;
210     ObjectRef argNodeRef = args[0];
211     if (argNodeRef->type_index() == outs->type_index()) {
212       outs = args[0];
213     } else {
214       outs = Array<Tensor>{args[0]};
215     }
216 
217     *ret = builder(target, outs);
218   });
219 }
220 
// Generic dispatchers: each falls back to the generic schedule and overrides
// the entry for targets that provide a specialized implementation.
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
    .set_default(WrapSchedule(topi::generic::schedule_injective))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective));

TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax));

TVM_REGISTER_GENERIC_FUNC(schedule_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense))
    .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense));

TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
    .set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool));

// Reductions default to the auto-inlining variant so elementwise producers fuse.
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
    .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
    .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline))
    .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce));

TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack));

TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
    .set_default(WrapSchedule(topi::generic::default_schedule))
    .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense));
261 
262 /*! \brief Builder function for instantiating schedules from existing schedules. */
263 using FTVMScheduleFromExistingBuilder =
264     std::function<tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
265 
266 /*!
267  * \brief Helper function for registering generic functions matching the
268  * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
269  * with a PackedFunc suitable for passing to a tvm::GenericFunc.
270  *
271  * \param builder The schedule builder to wrap.
272  *
273  * \return The wrapped schedule builder
274  */
WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder)275 inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
276   return PackedFunc(
277       [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); });
278 }
279 
280 TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
281     .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
282     .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
283     .register_func({"cuda", "gpu"},
284                    WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
285 
286 /*! \brief Builder function for instantiating dense ops. */
287 using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(
288     const Target& target, const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
289     const tvm::te::Tensor& bias, const DataType& out_dtype)>;
290 
291 /*!
292  * \brief Helper function for registering dense ops matching the
293  * FTVMDenseOpBuilder signature. The op builder function is wrapped
294  * with a PackedFunc suitable for passing to a tvm::GenericFunc.
295  *
296  * \param builder The op builder to wrap.
297  *
298  * \return The wrapped op builder
299  */
WrapDenseOp(FTVMDenseOpBuilder builder)300 inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
301   return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
302     auto target = Target::Current(false);
303     Tensor data = args[0];
304     Tensor weight = args[1];
305     Tensor bias = args[2];
306     DataType out_dtype = args[3];
307 
308     *ret = builder(target, data, weight, bias, out_dtype);
309   });
310 }
311 
TVM_REGISTER_GENERIC_FUNC(dense)
    .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data,
                                const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                const DataType& out_dtype) {
      // Fallback: the target-agnostic reference dense implementation.
      return topi::nn::dense(data, weight, bias, out_dtype);
    }))
    .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda))
    .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm));

}  // namespace topi
}  // namespace tvm
323