1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file c_runtime_api.cc
22  * \brief Device specific implementations
23  */
24 #include <dmlc/thread_local.h>
25 #include <tvm/runtime/c_backend_api.h>
26 #include <tvm/runtime/c_runtime_api.h>
27 #include <tvm/runtime/device_api.h>
28 #include <tvm/runtime/module.h>
29 #include <tvm/runtime/packed_func.h>
30 #include <tvm/runtime/registry.h>
31 
#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <cstdlib>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>
38 
39 #include "object_internal.h"
40 #include "runtime_base.h"
41 
42 namespace tvm {
43 namespace runtime {
44 
GetCustomTypeName(uint8_t type_code)45 std::string GetCustomTypeName(uint8_t type_code) {
46   auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_name");
47   CHECK(f) << "Function runtime._datatype_get_type_name not found";
48   return (*f)(type_code).operator std::string();
49 }
50 
GetCustomTypeCode(const std::string & type_name)51 uint8_t GetCustomTypeCode(const std::string& type_name) {
52   auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_code");
53   CHECK(f) << "Function runtime._datatype_get_type_code not found";
54   return (*f)(type_name).operator int();
55 }
56 
GetCustomTypeRegistered(uint8_t type_code)57 bool GetCustomTypeRegistered(uint8_t type_code) {
58   auto f = tvm::runtime::Registry::Get("runtime._datatype_get_type_registered");
59   CHECK(f) << "Function runtime._datatype_get_type_registered not found";
60   return (*f)(type_code).operator bool();
61 }
62 
ParseCustomDatatype(const std::string & s,const char ** scan)63 uint8_t ParseCustomDatatype(const std::string& s, const char** scan) {
64   CHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string";
65 
66   auto tmp = s.c_str();
67 
68   CHECK(s.c_str() == tmp);
69   *scan = s.c_str() + 6;
70   CHECK(s.c_str() == tmp);
71   if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s;
72   CHECK(s.c_str() == tmp);
73   *scan += 1;
74   CHECK(s.c_str() == tmp);
75   size_t custom_name_len = 0;
76   CHECK(s.c_str() == tmp);
77   while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']')
78     ++custom_name_len;
79   CHECK(s.c_str() == tmp);
80   if (*(*scan + custom_name_len) != ']')
81     LOG(FATAL) << "expected closing brace after 'custom' type in" << s;
82   CHECK(s.c_str() == tmp);
83   *scan += custom_name_len + 1;
84   CHECK(s.c_str() == tmp);
85 
86   auto type_name = s.substr(7, custom_name_len);
87   CHECK(s.c_str() == tmp);
88   return GetCustomTypeCode(type_name);
89 }
90 
91 class DeviceAPIManager {
92  public:
93   static const int kMaxDeviceAPI = 32;
94   // Get API
Get(const TVMContext & ctx)95   static DeviceAPI* Get(const TVMContext& ctx) { return Get(ctx.device_type); }
Get(int dev_type,bool allow_missing=false)96   static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
97     return Global()->GetAPI(dev_type, allow_missing);
98   }
99 
100  private:
101   std::array<DeviceAPI*, kMaxDeviceAPI> api_;
102   DeviceAPI* rpc_api_{nullptr};
103   std::mutex mutex_;
104   // constructor
DeviceAPIManager()105   DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
106   // Global static variable.
Global()107   static DeviceAPIManager* Global() {
108     static DeviceAPIManager* inst = new DeviceAPIManager();
109     return inst;
110   }
111   // Get or initialize API.
GetAPI(int type,bool allow_missing)112   DeviceAPI* GetAPI(int type, bool allow_missing) {
113     if (type < kRPCSessMask) {
114       if (api_[type] != nullptr) return api_[type];
115       std::lock_guard<std::mutex> lock(mutex_);
116       if (api_[type] != nullptr) return api_[type];
117       api_[type] = GetAPI(DeviceName(type), allow_missing);
118       return api_[type];
119     } else {
120       if (rpc_api_ != nullptr) return rpc_api_;
121       std::lock_guard<std::mutex> lock(mutex_);
122       if (rpc_api_ != nullptr) return rpc_api_;
123       rpc_api_ = GetAPI("rpc", allow_missing);
124       return rpc_api_;
125     }
126   }
GetAPI(const std::string name,bool allow_missing)127   DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
128     std::string factory = "device_api." + name;
129     auto* f = Registry::Get(factory);
130     if (f == nullptr) {
131       CHECK(allow_missing) << "Device API " << name << " is not enabled.";
132       return nullptr;
133     }
134     void* ptr = (*f)();
135     return static_cast<DeviceAPI*>(ptr);
136   }
137 };
138 
Get(TVMContext ctx,bool allow_missing)139 DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
140   return DeviceAPIManager::Get(static_cast<int>(ctx.device_type), allow_missing);
141 }
142 
AllocWorkspace(TVMContext ctx,size_t size,DLDataType type_hint)143 void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
144   return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
145 }
146 
FreeWorkspace(TVMContext ctx,void * ptr)147 void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); }
148 
CreateStream(TVMContext ctx)149 TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
150   LOG(FATAL) << "Device does not support stream api.";
151   return nullptr;
152 }
153 
FreeStream(TVMContext ctx,TVMStreamHandle stream)154 void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
155   LOG(FATAL) << "Device does not support stream api.";
156 }
157 
SyncStreamFromTo(TVMContext ctx,TVMStreamHandle event_src,TVMStreamHandle event_dst)158 void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src,
159                                  TVMStreamHandle event_dst) {
160   LOG(FATAL) << "Device does not support stream api.";
161 }
162 
163 //--------------------------------------------------------
164 // Error handling mechanism
165 // -------------------------------------------------------
166 // Standard error message format, {} means optional
167 //--------------------------------------------------------
168 // {error_type:} {message0}
169 // {message1}
170 // {message2}
// {Stack trace:}    // stack traces follow this line
//   {trace 0}       // each trace line starts with two spaces.
173 //   {trace 1}
174 //   {trace 2}
175 //--------------------------------------------------------
176 /*!
177  * \brief Normalize error message
178  *
179  *  Parse them header generated by by LOG(FATAL) and CHECK
180  *  and reformat the message into the standard format.
181  *
182  *  This function will also merge all the stack traces into
183  *  one trace and trim them.
184  *
185  * \param err_msg The error message.
186  * \return normalized message.
187  */
NormalizeError(std::string err_msg)188 std::string NormalizeError(std::string err_msg) {
189   // ------------------------------------------------------------------------
190   // log with header, {} indicates optional
191   //-------------------------------------------------------------------------
192   // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
193   // {message1}
194   // Stack trace:
195   //   {stack trace 0}
196   //   {stack trace 1}
197   //-------------------------------------------------------------------------
198   // Normalzied version
199   //-------------------------------------------------------------------------
200   // error_type: check_msg message0
201   // {message1}
202   // Stack trace:
203   //   File file_name, line lineno
204   //   {stack trace 0}
205   //   {stack trace 1}
206   //-------------------------------------------------------------------------
207   int line_number = 0;
208   std::istringstream is(err_msg);
209   std::string line, file_name, error_type, check_msg;
210 
211   // Parse log header and set the fields,
212   // Return true if it the log is in correct format,
213   // return false if something is wrong.
214   auto parse_log_header = [&]() {
215     // skip timestamp
216     if (is.peek() != '[') {
217       getline(is, line);
218       return true;
219     }
220     if (!(is >> line)) return false;
221     // get filename
222     while (is.peek() == ' ') is.get();
223 #ifdef _MSC_VER  // handle volume separator ":" in Windows path
224     std::string drive;
225     if (!getline(is, drive, ':')) return false;
226     if (!getline(is, file_name, ':')) return false;
227     file_name = drive + ":" + file_name;
228 #else
229     if (!getline(is, file_name, ':')) return false;
230 #endif
231     // get line number
232     if (!(is >> line_number)) return false;
233     // get rest of the message.
234     while (is.peek() == ' ' || is.peek() == ':') is.get();
235     if (!getline(is, line)) return false;
236     // detect check message, rewrite to remote extra :
237     if (line.compare(0, 13, "Check failed:") == 0) {
238       size_t end_pos = line.find(':', 13);
239       if (end_pos == std::string::npos) return false;
240       check_msg = line.substr(0, end_pos + 1) + ' ';
241       line = line.substr(end_pos + 1);
242     }
243     return true;
244   };
245   // if not in correct format, do not do any rewrite.
246   if (!parse_log_header()) return err_msg;
247   // Parse error type.
248   {
249     size_t start_pos = 0, end_pos;
250     for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {
251     }
252     for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
253       char ch = line[end_pos];
254       if (ch == ':') {
255         error_type = line.substr(start_pos, end_pos - start_pos);
256         break;
257       }
258       // [A-Z0-9a-z_.]
259       if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break;
260     }
261     if (error_type.length() != 0) {
262       // if we successfully detected error_type: trim the following space.
263       for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' ';
264            ++start_pos) {
265       }
266       line = line.substr(start_pos);
267     } else {
268       // did not detect error_type, use default value.
269       line = line.substr(start_pos);
270       error_type = "TVMError";
271     }
272   }
273   // Seperate out stack trace.
274   std::ostringstream os;
275   os << error_type << ": " << check_msg << line << '\n';
276 
277   bool trace_mode = true;
278   std::vector<std::string> stack_trace;
279   while (getline(is, line)) {
280     if (trace_mode) {
281       if (line.compare(0, 2, "  ") == 0) {
282         stack_trace.push_back(line);
283       } else {
284         trace_mode = false;
285         // remove EOL trailing stacktrace.
286         if (line.length() == 0) continue;
287       }
288     }
289     if (!trace_mode) {
290       if (line.compare(0, 11, "Stack trace") == 0) {
291         trace_mode = true;
292       } else {
293         os << line << '\n';
294       }
295     }
296   }
297   if (stack_trace.size() != 0 || file_name.length() != 0) {
298     os << "Stack trace:\n";
299     if (file_name.length() != 0) {
300       os << "  File \"" << file_name << "\", line " << line_number << "\n";
301     }
302     // Print out stack traces, optionally trim the c++ traces
303     // about the frontends (as they will be provided by the frontends).
304     bool ffi_boundary = false;
305     for (const auto& line : stack_trace) {
306       // Heuristic to detect python ffi.
307       if (line.find("libffi.so") != std::string::npos ||
308           line.find("core.cpython") != std::string::npos) {
309         ffi_boundary = true;
310       }
311       // If the backtrace is not c++ backtrace with the prefix "  [bt]",
312       // then we can stop trimming.
313       if (ffi_boundary && line.compare(0, 6, "  [bt]") != 0) {
314         ffi_boundary = false;
315       }
316       if (!ffi_boundary) {
317         os << line << '\n';
318       }
319       // The line after TVMFuncCall cound be in FFI.
320       if (line.find("(TVMFuncCall") != std::string::npos) {
321         ffi_boundary = true;
322       }
323     }
324   }
325   return os.str();
326 }
327 
328 }  // namespace runtime
329 }  // namespace tvm
330 
331 using namespace tvm::runtime;
332 
333 struct TVMRuntimeEntry {
334   std::string ret_str;
335   std::string last_error;
336   TVMByteArray ret_bytes;
337 };
338 
339 typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
340 
TVMGetLastError()341 const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); }
342 
TVMAPIHandleException(const std::runtime_error & e)343 int TVMAPIHandleException(const std::runtime_error& e) {
344   TVMAPISetLastError(NormalizeError(e.what()).c_str());
345   return -1;
346 }
347 
TVMAPISetLastError(const char * msg)348 void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; }
349 
TVMModLoadFromFile(const char * file_name,const char * format,TVMModuleHandle * out)350 int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) {
351   API_BEGIN();
352   TVMRetValue ret;
353   ret = Module::LoadFromFile(file_name, format);
354   TVMValue val;
355   int type_code;
356   ret.MoveToCHost(&val, &type_code);
357   *out = val.v_handle;
358   API_END();
359 }
360 
TVMModImport(TVMModuleHandle mod,TVMModuleHandle dep)361 int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) {
362   API_BEGIN();
363   ObjectInternal::GetModuleNode(mod)->Import(GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
364   API_END();
365 }
366 
TVMModGetFunction(TVMModuleHandle mod,const char * func_name,int query_imports,TVMFunctionHandle * func)367 int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
368                       TVMFunctionHandle* func) {
369   API_BEGIN();
370   PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0);
371   if (pf != nullptr) {
372     *func = new PackedFunc(pf);
373   } else {
374     *func = nullptr;
375   }
376   API_END();
377 }
378 
TVMModFree(TVMModuleHandle mod)379 int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); }
380 
TVMBackendGetFuncFromEnv(void * mod_node,const char * func_name,TVMFunctionHandle * func)381 int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) {
382   API_BEGIN();
383   *func = (TVMFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
384   API_END();
385 }
386 
TVMBackendAllocWorkspace(int device_type,int device_id,uint64_t size,int dtype_code_hint,int dtype_bits_hint)387 void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
388                                int dtype_bits_hint) {
389   TVMContext ctx;
390   ctx.device_type = static_cast<DLDeviceType>(device_type);
391   ctx.device_id = device_id;
392 
393   DLDataType type_hint;
394   type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
395   type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
396   type_hint.lanes = 1;
397 
398   return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size), type_hint);
399 }
400 
TVMBackendFreeWorkspace(int device_type,int device_id,void * ptr)401 int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
402   TVMContext ctx;
403   ctx.device_type = static_cast<DLDeviceType>(device_type);
404   ctx.device_id = device_id;
405   DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
406   return 0;
407 }
408 
TVMBackendRunOnce(void ** handle,int (* f)(void *),void * cdata,int nbytes)409 int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
410   if (*handle == nullptr) {
411     *handle = reinterpret_cast<void*>(1);
412     return (*f)(cdata);
413   }
414   return 0;
415 }
416 
TVMFuncFree(TVMFunctionHandle func)417 int TVMFuncFree(TVMFunctionHandle func) {
418   API_BEGIN();
419   delete static_cast<PackedFunc*>(func);
420   API_END();
421 }
422 
TVMFuncCall(TVMFunctionHandle func,TVMValue * args,int * arg_type_codes,int num_args,TVMValue * ret_val,int * ret_type_code)423 int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
424                 TVMValue* ret_val, int* ret_type_code) {
425   API_BEGIN();
426 
427   TVMRetValue rv;
428   (*static_cast<const PackedFunc*>(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);
429   // handle return string.
430   if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) {
431     TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
432     if (rv.type_code() != kTVMDataType) {
433       e->ret_str = *rv.ptr<std::string>();
434     } else {
435       e->ret_str = rv.operator std::string();
436     }
437     if (rv.type_code() == kTVMBytes) {
438       e->ret_bytes.data = e->ret_str.c_str();
439       e->ret_bytes.size = e->ret_str.length();
440       *ret_type_code = kTVMBytes;
441       ret_val->v_handle = &(e->ret_bytes);
442     } else {
443       *ret_type_code = kTVMStr;
444       ret_val->v_str = e->ret_str.c_str();
445     }
446   } else {
447     rv.MoveToCHost(ret_val, ret_type_code);
448   }
449   API_END();
450 }
451 
TVMCFuncSetReturn(TVMRetValueHandle ret,TVMValue * value,int * type_code,int num_ret)452 int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) {
453   API_BEGIN();
454   CHECK_EQ(num_ret, 1);
455   TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
456   *rv = TVMArgValue(value[0], type_code[0]);
457   API_END();
458 }
459 
TVMFuncCreateFromCFunc(TVMPackedCFunc func,void * resource_handle,TVMPackedCFuncFinalizer fin,TVMFunctionHandle * out)460 int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin,
461                            TVMFunctionHandle* out) {
462   API_BEGIN();
463   if (fin == nullptr) {
464     *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) {
465       int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
466                      args.num_args, rv, resource_handle);
467       if (ret != 0) {
468         throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
469       }
470     });
471   } else {
472     // wrap it in a shared_ptr, with fin as deleter.
473     // so fin will be called when the lambda went out of scope.
474     std::shared_ptr<void> rpack(resource_handle, fin);
475     *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) {
476       int ret = func(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
477                      args.num_args, rv, rpack.get());
478       if (ret != 0) {
479         throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
480       }
481     });
482   }
483   API_END();
484 }
485 
TVMStreamCreate(int device_type,int device_id,TVMStreamHandle * out)486 int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
487   API_BEGIN();
488   TVMContext ctx;
489   ctx.device_type = static_cast<DLDeviceType>(device_type);
490   ctx.device_id = device_id;
491   *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
492   API_END();
493 }
494 
TVMStreamFree(int device_type,int device_id,TVMStreamHandle stream)495 int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
496   API_BEGIN();
497   TVMContext ctx;
498   ctx.device_type = static_cast<DLDeviceType>(device_type);
499   ctx.device_id = device_id;
500   DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
501   API_END();
502 }
503 
TVMSetStream(int device_type,int device_id,TVMStreamHandle stream)504 int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
505   API_BEGIN();
506   TVMContext ctx;
507   ctx.device_type = static_cast<DLDeviceType>(device_type);
508   ctx.device_id = device_id;
509   DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
510   API_END();
511 }
512 
TVMSynchronize(int device_type,int device_id,TVMStreamHandle stream)513 int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
514   API_BEGIN();
515   TVMContext ctx;
516   ctx.device_type = static_cast<DLDeviceType>(device_type);
517   ctx.device_id = device_id;
518   DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
519   API_END();
520 }
521 
TVMStreamStreamSynchronize(int device_type,int device_id,TVMStreamHandle src,TVMStreamHandle dst)522 int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src,
523                                TVMStreamHandle dst) {
524   API_BEGIN();
525   TVMContext ctx;
526   ctx.device_type = static_cast<DLDeviceType>(device_type);
527   ctx.device_id = device_id;
528   DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
529   API_END();
530 }
531 
TVMCbArgToReturn(TVMValue * value,int * code)532 int TVMCbArgToReturn(TVMValue* value, int* code) {
533   API_BEGIN();
534   tvm::runtime::TVMRetValue rv;
535   rv = tvm::runtime::TVMMovableArgValue_(*value, *code);
536   rv.MoveToCHost(value, code);
537   API_END();
538 }
539 
TVMDeviceAllocDataSpace(DLContext ctx,size_t nbytes,size_t alignment,DLDataType type_hint,void ** out_data)540 int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint,
541                             void** out_data) {
542   API_BEGIN();
543   out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint);
544   API_END();
545 }
546 
TVMDeviceFreeDataSpace(DLContext ctx,void * ptr)547 int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) {
548   API_BEGIN();
549   DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr);
550   API_END();
551 }
552 
TVMDeviceCopyDataFromTo(const void * from,size_t from_offset,void * to,size_t to_offset,size_t num_bytes,TVMContext ctx_from,TVMContext ctx_to,DLDataType type_hint,TVMStreamHandle stream)553 int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
554                             size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to,
555                             DLDataType type_hint, TVMStreamHandle stream) {
556   API_BEGIN();
557   TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to;
558   DeviceAPIManager::Get(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, num_bytes, ctx_from,
559                                              ctx_to, type_hint, stream);
560   API_END();
561 }
562 
563 // set device api
564 TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
__anonf686cb9d0402(TVMArgs args, TVMRetValue* ret) 565     .set_body([](TVMArgs args, TVMRetValue* ret) {
566       TVMContext ctx;
567       ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
568       ctx.device_id = args[1];
569       DeviceAPIManager::Get(ctx)->SetDevice(ctx);
570     });
571 
572 // set device api
__anonf686cb9d0502(TVMArgs args, TVMRetValue* ret) 573 TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) {
574   TVMContext ctx;
575   ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
576   ctx.device_id = args[1];
577 
578   DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
579   if (kind == kExist) {
580     DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
581     if (api != nullptr) {
582       api->GetAttr(ctx, kind, ret);
583     } else {
584       *ret = 0;
585     }
586   } else {
587     DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
588   }
589 });
590 
// Expose the C API function TVMSetStream as the packed function "runtime.TVMSetStream".
TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream);
592