1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file pack_args.h
 * \brief Utility to pack TVMArgs into other type-erased function calling conventions.
 *
 *  Two type-erased function signatures are supported.
 *   - cuda_style(void** args, int num_args);
 *      - Pack everything by address.
 *   - metal_style(void** buffers, int num_buffers,
 *                 union_32bit args[N], int num_args);
 *      - Pack buffers by address, pack the remaining parameters into a 32-bit union buffer.
 */
31 #ifndef TVM_RUNTIME_PACK_ARGS_H_
32 #define TVM_RUNTIME_PACK_ARGS_H_
33
34 #include <tvm/runtime/c_runtime_api.h>
35 #include <vector>
36 #include <cstring>
37
38 namespace tvm {
39 namespace runtime {
/*!
 * \brief argument union type of 32bit.
 *  Choose 32 bit because most GPU API do not work well with 64 bit.
 *
 *  All members share the same storage; only the member that was written
 *  last should be read back.
 */
union ArgUnion {
  /*! \brief The slot viewed as a signed 32-bit integer. */
  int32_t v_int32;
  /*! \brief The slot viewed as an unsigned 32-bit integer. */
  uint32_t v_uint32;
  /*! \brief The slot viewed as a single-precision float. */
  float v_float32;
};
/*!
 * \brief Create a packed function from a function that takes its arguments by address.
 *
 * \param f with signature (TVMArgs args, TVMRetValue* rv, void** void_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
58 template<typename F>
59 inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Create a packed function from a function that receives its
 *  non-buffer arguments packed into an array of 32-bit ArgUnion values.
 *
 * \param f with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
69 template<typename F>
70 inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Create a packed function from a function that takes packed arguments.
 *
 * \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
80 template<typename F>
81 inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
82 /*!
83 * \brief Extract number of buffer argument from the argument types.
84 * \param arg_types The argument types.
85 * \return number of buffer arguments
86 */
87 inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types);
88
89 // implementations details
90 namespace detail {
/*!
 * \brief Fixed-capacity scratch array stored inside the object itself.
 *
 *  The constructor argument is accepted only so that this template and its
 *  kSize == 0 specialization can be constructed the same way; it is ignored
 *  here and the array always holds exactly kSize elements.
 *
 * \tparam T The element type.
 * \tparam kSize The number of elements in the array.
 */
template<typename T, int kSize>
class TempArray {
 private:
  /*! \brief In-object storage; left uninitialized by the constructor. */
  T storage_[kSize];

 public:
  /*! \brief Construct the array; the requested size is not used. */
  explicit TempArray(int /* size */) {}

  /*! \return Pointer to the first element of the array. */
  T* data() { return storage_; }
};
101 template<typename T>
102 class TempArray<T, 0> {
103 public:
TempArray(int size)104 explicit TempArray(int size) : data_(size) {}
data()105 T* data() {
106 return data_.data();
107 }
108 private:
109 std::vector<T> data_;
110 };
111
/*!
 * \brief Conversion applied to a single argument before it is handed to
 *  the wrapped device function. Names read as SOURCE_TO_TARGET.
 */
enum ArgConvertCode {
  /*! \brief 64-bit integer forwarded unchanged. */
  INT64_TO_INT64 = 0,
  /*! \brief 64-bit integer narrowed to int32_t. */
  INT64_TO_INT32 = 1,
  /*! \brief 64-bit integer narrowed to uint32_t. */
  INT64_TO_UINT32 = 2,
  /*! \brief 64-bit float narrowed to float. */
  FLOAT64_TO_FLOAT32 = 3,
  /*! \brief 64-bit float forwarded unchanged. */
  FLOAT64_TO_FLOAT64 = 4,
  /*! \brief Handle forwarded unchanged. */
  HANDLE_TO_HANDLE = 5
};
121
GetArgConvertCode(TVMType t)122 inline ArgConvertCode GetArgConvertCode(TVMType t) {
123 CHECK_EQ(t.lanes, 1U)
124 << "Cannot pass vector type argument to devic function for now";
125 if (t.code == kDLInt) {
126 if (t.bits == 64U) return INT64_TO_INT64;
127 if (t.bits == 32U) return INT64_TO_INT32;
128 } else if (t.code == kDLUInt) {
129 if (t.bits == 32U) return INT64_TO_UINT32;
130 } else if (t.code == kDLFloat) {
131 if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
132 if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
133 } else if (t.code == kHandle) {
134 return HANDLE_TO_HANDLE;
135 }
136 LOG(FATAL) << "Cannot handle " << t << " as device function argument";
137 return HANDLE_TO_HANDLE;
138 }
139
140 template<int N, typename F>
PackFuncVoidAddr_(F f,const std::vector<ArgConvertCode> & codes)141 inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
142 int num_args = static_cast<int>(codes.size());
143 auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
144 TempArray<void*, N> addr_(num_args);
145 TempArray<ArgUnion, N> holder_(num_args);
146 void** addr = addr_.data();
147 ArgUnion* holder = holder_.data();
148 for (int i = 0; i < num_args; ++i) {
149 switch (codes[i]) {
150 case INT64_TO_INT64:
151 case FLOAT64_TO_FLOAT64:
152 case HANDLE_TO_HANDLE: {
153 addr[i] = (void*)&(args.values[i]); // NOLINT(*)
154 break;
155 }
156 case INT64_TO_INT32: {
157 holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
158 addr[i] = &(holder[i]);
159 break;
160 }
161 case INT64_TO_UINT32 : {
162 holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
163 addr[i] = &(holder[i]);
164 break;
165 }
166 case FLOAT64_TO_FLOAT32: {
167 holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
168 addr[i] = &(holder[i]);
169 break;
170 }
171 }
172 }
173 f(args, ret, addr);
174 };
175 return PackedFunc(ret);
176 }
177
178 template<int N, typename F>
PackFuncNonBufferArg_(F f,int base,const std::vector<ArgConvertCode> & codes)179 inline PackedFunc PackFuncNonBufferArg_(
180 F f, int base, const std::vector<ArgConvertCode>& codes) {
181 int num_args = static_cast<int>(codes.size());
182 auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
183 TempArray<ArgUnion, N> holder_(num_args);
184 ArgUnion* holder = holder_.data();
185 for (int i = 0; i < num_args; ++i) {
186 switch (codes[i]) {
187 case INT64_TO_INT64:
188 case FLOAT64_TO_FLOAT64: {
189 LOG(FATAL) << "Do not support 64bit argument to device function"; break;
190 }
191 case INT64_TO_INT32: {
192 holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
193 break;
194 }
195 case INT64_TO_UINT32 : {
196 holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
197 break;
198 }
199 case FLOAT64_TO_FLOAT32: {
200 holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
201 break;
202 }
203 case HANDLE_TO_HANDLE: {
204 LOG(FATAL) << "not reached"; break;
205 }
206 }
207 }
208 f(args, ret, holder);
209 };
210 return PackedFunc(ret);
211 }
212
213 template<int N, typename F>
PackFuncPackedArg_(F f,const std::vector<ArgConvertCode> & codes)214 inline PackedFunc PackFuncPackedArg_(
215 F f, const std::vector<ArgConvertCode>& codes) {
216 int num_args = static_cast<int>(codes.size());
217 auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
218 TempArray<uint64_t, N> pack_(num_args);
219 int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
220 int32_t* ptr = pack;
221 static_assert(sizeof(TVMValue) == 8, "invariant");
222 static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
223 for (int i = 0; i < num_args; ++i) {
224 switch (codes[i]) {
225 case HANDLE_TO_HANDLE: {
226 std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
227 ptr += sizeof(void*) / sizeof(int32_t);
228 break;
229 }
230 case INT64_TO_INT64:
231 case FLOAT64_TO_FLOAT64: {
232 std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
233 ptr += 2;
234 break;
235 }
236 case INT64_TO_INT32: {
237 *ptr = static_cast<int32_t>(args.values[i].v_int64);
238 ++ptr;
239 break;
240 }
241 case INT64_TO_UINT32 : {
242 *reinterpret_cast<uint32_t*>(ptr) =
243 static_cast<uint32_t>(args.values[i].v_int64);
244 ++ptr;
245 break;
246 }
247 case FLOAT64_TO_FLOAT32: {
248 *reinterpret_cast<float*>(ptr) =
249 static_cast<float>(args.values[i].v_float64);
250 ++ptr;
251 break;
252 }
253 default: {
254 LOG(FATAL) << "not reached"; break;
255 }
256 }
257 }
258 f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
259 };
260 return PackedFunc(ret);
261 }
262 } // namespace detail
263
264 template<typename F>
PackFuncVoidAddr(F f,const std::vector<TVMType> & arg_types)265 inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
266 std::vector<detail::ArgConvertCode> codes(arg_types.size());
267 for (size_t i = 0; i < arg_types.size(); ++i) {
268 codes[i] = detail::GetArgConvertCode(arg_types[i]);
269 }
270 size_t num_void_args = arg_types.size();
271 // specialization
272 if (num_void_args <= 4) {
273 return detail::PackFuncVoidAddr_<4>(f, codes);
274 } else if (num_void_args <= 8) {
275 return detail::PackFuncVoidAddr_<8>(f, codes);
276 } else {
277 return detail::PackFuncVoidAddr_<0>(f, codes);
278 }
279 }
280
NumBufferArgs(const std::vector<TVMType> & arg_types)281 inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) {
282 size_t base = arg_types.size();
283 for (size_t i = 0; i < arg_types.size(); ++i) {
284 if (arg_types[i].code != kHandle) {
285 base = i; break;
286 }
287 }
288 for (size_t i = base; i < arg_types.size(); ++i) {
289 CHECK(arg_types[i].code != kHandle)
290 << "Device function need to be organized";
291 }
292 return base;
293 }
294
295 template<typename F>
PackFuncNonBufferArg(F f,const std::vector<TVMType> & arg_types)296 inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) {
297 size_t num_buffer = NumBufferArgs(arg_types);
298 std::vector<detail::ArgConvertCode> codes;
299 for (size_t i = num_buffer; i < arg_types.size(); ++i) {
300 codes.push_back(detail::GetArgConvertCode(arg_types[i]));
301 }
302 int base = static_cast<int>(num_buffer);
303 size_t nargs = codes.size();
304 // specialization
305 if (nargs <= 4) {
306 return detail::PackFuncNonBufferArg_<4>(f, base, codes);
307 } else {
308 return detail::PackFuncNonBufferArg_<0>(f, base, codes);
309 }
310 }
311
312 template<typename F>
PackFuncPackedArg(F f,const std::vector<TVMType> & arg_types)313 inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
314 std::vector<detail::ArgConvertCode> codes;
315 for (size_t i = 0; i < arg_types.size(); ++i) {
316 codes.push_back(detail::GetArgConvertCode(arg_types[i]));
317 }
318 size_t nargs = codes.size();
319 // specialization
320 if (nargs <= 4) {
321 return detail::PackFuncPackedArg_<4>(f, codes);
322 } else {
323 return detail::PackFuncPackedArg_<0>(f, codes);
324 }
325 }
326 } // namespace runtime
327 } // namespace tvm
328 #endif // TVM_RUNTIME_PACK_ARGS_H_
329