1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file pack_args.h
 * \brief Utility to pack TVMArgs into other type-erased function calling conventions.
23  *
 *  Three type-erased calling conventions are supported.
 *   - cuda_style(void** args, int num_args);
 *      - Pass every argument by address.
 *   - metal_style(void** buffers, int num_buffers,
 *                 union_32bit args[N], int num_args);
 *      - Pass buffers by address; pack the remaining parameters into a 32-bit union buffer.
 *   - packed_style(void* pack_args, size_t nbytes);  (see PackFuncPackedArg)
 *      - Pack all arguments back to back into a single contiguous buffer.
30  */
31 #ifndef TVM_RUNTIME_PACK_ARGS_H_
32 #define TVM_RUNTIME_PACK_ARGS_H_
33 
#include <tvm/runtime/c_runtime_api.h>
#include <cassert>
#include <cstring>
#include <vector>
37 
38 namespace tvm {
39 namespace runtime {
/*!
 * \brief 32-bit argument slot holding a signed int, an unsigned int, or a float.
 *
 * 32 bits is chosen because most GPU APIs do not work well with 64-bit
 * scalar arguments.
 */
union ArgUnion {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};
// Arrays of ArgUnion are handed to device functions as a packed 32-bit buffer,
// so every slot must be exactly 32 bits wide.
static_assert(sizeof(ArgUnion) == 4, "ArgUnion must be exactly 32 bits");
49 /*!
50  * \brief Create a packed function from void addr types.
51  *
52  * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
53  * \param arg_types The arguments type information.
54  * \tparam F the function type
55  *
56  * \return The wrapped packed function.
57  */
58 template<typename F>
59 inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
60 /*!
61  * \brief Create a packed function that from function only packs buffer arguments.
62  *
63  * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
64  * \param arg_types The arguments type information.
65  * \tparam F the function type
66  *
67  * \return The wrapped packed function.
68  */
69 template<typename F>
70 inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
71 /*!
72  * \brief Create a packed function that from function that takes a packed arguments.
73  *
74  * \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
75  * \param arg_types The arguments that wish to get from
76  * \tparam F the function type
77  *
78  * \return The wrapped packed function.
79  */
80 template<typename F>
81 inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
82 /*!
83  * \brief Extract number of buffer argument from the argument types.
84  * \param arg_types The argument types.
85  * \return number of buffer arguments
86  */
87 inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types);
88 
89 // implementations details
90 namespace detail {
/*!
 * \brief Scratch array with inline storage for up to kSize elements.
 *
 * The requested size must not exceed kSize; this is checked in debug builds
 * only. Use kSize == 0 to select the heap-backed specialization below, which
 * allocates exactly `size` elements.
 *
 * \tparam T element type
 * \tparam kSize inline capacity (0 = dynamic)
 */
template<typename T, int kSize>
class TempArray {
 public:
  explicit TempArray(int size) {
    // The inline buffer is fixed; overflowing it would silently corrupt memory.
    assert(size <= kSize);
    (void)size;  // unused when NDEBUG is defined
  }
  T* data() {
    return data_;
  }
 private:
  T data_[kSize];
};
/*! \brief Heap-backed TempArray used when no inline capacity fits. */
template<typename T>
class TempArray<T, 0> {
 public:
  explicit TempArray(int size) : data_(size) {}
  T* data() {
    return data_.data();
  }
 private:
  std::vector<T> data_;
};
111 
/*!
 * \brief How one TVMValue argument is re-encoded before reaching a device function.
 *
 * Names read SOURCE_TO_DESTINATION, e.g. INT64_TO_INT32 narrows the 64-bit
 * integer held in a TVMValue to a 32-bit signed integer.
 */
enum ArgConvertCode {
  INT64_TO_INT64 = 0,      // 64-bit integer, passed through unchanged
  INT64_TO_INT32 = 1,      // 64-bit integer narrowed to int32_t
  INT64_TO_UINT32 = 2,     // 64-bit integer narrowed to uint32_t
  FLOAT64_TO_FLOAT32 = 3,  // double narrowed to float
  FLOAT64_TO_FLOAT64 = 4,  // double, passed through unchanged
  HANDLE_TO_HANDLE = 5     // opaque handle, passed through unchanged
};
121 
GetArgConvertCode(TVMType t)122 inline ArgConvertCode GetArgConvertCode(TVMType t) {
123   CHECK_EQ(t.lanes, 1U)
124       << "Cannot pass vector type argument to devic function for now";
125   if (t.code == kDLInt) {
126     if (t.bits == 64U) return INT64_TO_INT64;
127     if (t.bits == 32U) return INT64_TO_INT32;
128   } else if (t.code == kDLUInt) {
129     if (t.bits == 32U) return INT64_TO_UINT32;
130   } else if (t.code == kDLFloat) {
131     if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
132     if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
133   } else if (t.code == kHandle) {
134     return HANDLE_TO_HANDLE;
135   }
136   LOG(FATAL) << "Cannot handle " << t << " as device function argument";
137   return HANDLE_TO_HANDLE;
138 }
139 
140 template<int N, typename F>
PackFuncVoidAddr_(F f,const std::vector<ArgConvertCode> & codes)141 inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
142   int num_args = static_cast<int>(codes.size());
143   auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
144     TempArray<void*, N> addr_(num_args);
145     TempArray<ArgUnion, N> holder_(num_args);
146     void** addr = addr_.data();
147     ArgUnion* holder = holder_.data();
148     for (int i = 0; i < num_args; ++i) {
149       switch (codes[i]) {
150         case INT64_TO_INT64:
151         case FLOAT64_TO_FLOAT64:
152         case HANDLE_TO_HANDLE: {
153           addr[i] = (void*)&(args.values[i]);  // NOLINT(*)
154           break;
155         }
156         case INT64_TO_INT32: {
157           holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
158           addr[i] = &(holder[i]);
159           break;
160         }
161         case INT64_TO_UINT32 : {
162           holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
163           addr[i] = &(holder[i]);
164           break;
165         }
166         case FLOAT64_TO_FLOAT32: {
167           holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
168           addr[i] = &(holder[i]);
169           break;
170         }
171       }
172     }
173     f(args, ret, addr);
174   };
175   return PackedFunc(ret);
176 }
177 
178 template<int N, typename F>
PackFuncNonBufferArg_(F f,int base,const std::vector<ArgConvertCode> & codes)179 inline PackedFunc PackFuncNonBufferArg_(
180     F f, int base, const std::vector<ArgConvertCode>& codes) {
181   int num_args = static_cast<int>(codes.size());
182   auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
183     TempArray<ArgUnion, N> holder_(num_args);
184     ArgUnion* holder = holder_.data();
185     for (int i = 0; i < num_args; ++i) {
186       switch (codes[i]) {
187         case INT64_TO_INT64:
188         case FLOAT64_TO_FLOAT64: {
189           LOG(FATAL) << "Do not support 64bit argument to device function"; break;
190         }
191         case INT64_TO_INT32: {
192           holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
193           break;
194         }
195         case INT64_TO_UINT32 : {
196           holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
197           break;
198         }
199         case FLOAT64_TO_FLOAT32: {
200           holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
201           break;
202         }
203         case HANDLE_TO_HANDLE: {
204           LOG(FATAL) << "not reached"; break;
205         }
206       }
207     }
208     f(args, ret, holder);
209   };
210   return PackedFunc(ret);
211 }
212 
213 template<int N, typename F>
PackFuncPackedArg_(F f,const std::vector<ArgConvertCode> & codes)214 inline PackedFunc PackFuncPackedArg_(
215     F f, const std::vector<ArgConvertCode>& codes) {
216   int num_args = static_cast<int>(codes.size());
217   auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
218     TempArray<uint64_t, N> pack_(num_args);
219     int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
220     int32_t* ptr = pack;
221     static_assert(sizeof(TVMValue) == 8, "invariant");
222     static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
223     for (int i = 0; i < num_args; ++i) {
224       switch (codes[i]) {
225         case HANDLE_TO_HANDLE: {
226           std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
227           ptr += sizeof(void*) / sizeof(int32_t);
228           break;
229         }
230         case INT64_TO_INT64:
231         case FLOAT64_TO_FLOAT64: {
232           std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
233           ptr += 2;
234           break;
235         }
236         case INT64_TO_INT32: {
237           *ptr = static_cast<int32_t>(args.values[i].v_int64);
238           ++ptr;
239           break;
240         }
241         case INT64_TO_UINT32 : {
242           *reinterpret_cast<uint32_t*>(ptr) =
243               static_cast<uint32_t>(args.values[i].v_int64);
244           ++ptr;
245           break;
246         }
247         case FLOAT64_TO_FLOAT32: {
248           *reinterpret_cast<float*>(ptr) =
249               static_cast<float>(args.values[i].v_float64);
250           ++ptr;
251           break;
252         }
253         default: {
254           LOG(FATAL) << "not reached"; break;
255         }
256       }
257     }
258     f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
259   };
260   return PackedFunc(ret);
261 }
262 }  // namespace detail
263 
264 template<typename F>
PackFuncVoidAddr(F f,const std::vector<TVMType> & arg_types)265 inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
266   std::vector<detail::ArgConvertCode> codes(arg_types.size());
267   for (size_t i = 0; i < arg_types.size(); ++i) {
268     codes[i] = detail::GetArgConvertCode(arg_types[i]);
269   }
270   size_t num_void_args = arg_types.size();
271   // specialization
272   if (num_void_args <= 4) {
273     return detail::PackFuncVoidAddr_<4>(f, codes);
274   } else if (num_void_args <= 8) {
275     return detail::PackFuncVoidAddr_<8>(f, codes);
276   } else {
277     return detail::PackFuncVoidAddr_<0>(f, codes);
278   }
279 }
280 
NumBufferArgs(const std::vector<TVMType> & arg_types)281 inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) {
282   size_t base = arg_types.size();
283   for (size_t i = 0; i < arg_types.size(); ++i) {
284     if (arg_types[i].code != kHandle) {
285       base = i; break;
286     }
287   }
288   for (size_t i = base; i < arg_types.size(); ++i) {
289     CHECK(arg_types[i].code != kHandle)
290         << "Device function need to be organized";
291   }
292   return base;
293 }
294 
295 template<typename F>
PackFuncNonBufferArg(F f,const std::vector<TVMType> & arg_types)296 inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) {
297   size_t num_buffer = NumBufferArgs(arg_types);
298   std::vector<detail::ArgConvertCode> codes;
299   for (size_t i = num_buffer; i < arg_types.size(); ++i) {
300     codes.push_back(detail::GetArgConvertCode(arg_types[i]));
301   }
302   int base = static_cast<int>(num_buffer);
303   size_t nargs = codes.size();
304   // specialization
305   if (nargs <= 4) {
306     return detail::PackFuncNonBufferArg_<4>(f, base, codes);
307   } else {
308     return detail::PackFuncNonBufferArg_<0>(f, base, codes);
309   }
310 }
311 
312 template<typename F>
PackFuncPackedArg(F f,const std::vector<TVMType> & arg_types)313 inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
314   std::vector<detail::ArgConvertCode> codes;
315   for (size_t i = 0; i < arg_types.size(); ++i) {
316     codes.push_back(detail::GetArgConvertCode(arg_types[i]));
317   }
318   size_t nargs = codes.size();
319   // specialization
320   if (nargs <= 4) {
321     return detail::PackFuncPackedArg_<4>(f, codes);
322   } else {
323     return detail::PackFuncPackedArg_<0>(f, codes);
324   }
325 }
326 }  // namespace runtime
327 }  // namespace tvm
328 #endif  // TVM_RUNTIME_PACK_ARGS_H_
329