1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*
21  * \file tvmjs_support.cc
22  * \brief Support functions to be linked with wasm_runtime to provide
23  *        PackedFunc callbacks in tvmjs.
24  *        We do not need to link this file in standalone wasm.
25  */
26 
27 // configurations for the dmlc log.
28 #define DMLC_LOG_CUSTOMIZE 0
29 #define DMLC_LOG_STACK_TRACE 0
30 #define DMLC_LOG_DEBUG 0
31 #define DMLC_LOG_NODATE 1
32 #define DMLC_LOG_FATAL_THROW 0
33 
34 #include <tvm/runtime/c_runtime_api.h>
35 #include <tvm/runtime/container.h>
36 #include <tvm/runtime/device_api.h>
37 #include <tvm/runtime/packed_func.h>
38 #include <tvm/runtime/registry.h>
39 
40 #include "../../src/runtime/rpc/rpc_local_session.h"
41 
42 extern "C" {
43 // --- Additional C API for the Wasm runtime ---
44 /*!
45  * \brief Allocate space aligned to 64 bit.
46  * \param size The size of the space.
47  * \return The allocated space.
48  */
49 TVM_DLL void* TVMWasmAllocSpace(int size);
50 
51 /*!
52  * \brief Free the space allocated by TVMWasmAllocSpace.
53  * \param data The data pointer.
54  */
55 TVM_DLL void TVMWasmFreeSpace(void* data);
56 
57 /*!
58  * \brief Create PackedFunc from a resource handle.
59  * \param resource_handle The handle to the resource.
60  * \param out The output PackedFunc.
61  * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
62 3A * \return 0 if success.
63  */
64 TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out);
65 
66 // --- APIs to be implemented by the frontend. ---
67 /*!
68  * \brief Wasm frontend packed function caller.
69  *
70  * \param args The arguments
71  * \param type_codes The type codes of the arguments
72  * \param num_args Number of arguments.
73  * \param ret The return value handle.
74  * \param resource_handle The handle additional resouce handle from fron-end.
75  * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
76  */
77 extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
78                               void* resource_handle);
79 
80 /*!
81  * \brief Wasm frontend resource finalizer.
82  * \param resource_handle The pointer to the external resource.
83  */
84 extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
85 }  // extern "C"
86 
/*!
 * \brief Allocate a buffer of at least \p size bytes with int64_t alignment.
 *
 * The storage is an int64_t array rounded up to whole 8-byte words, so it
 * must be released with TVMWasmFreeSpace (which uses delete[] on int64_t*).
 *
 * \param size The requested size in bytes.
 * \return The allocated space.
 */
void* TVMWasmAllocSpace(int size) {
  // Widen before adding the rounding term: `size + 7` in int overflows
  // (undefined behavior) for sizes close to INT_MAX.
  const int64_t num_words = (static_cast<int64_t>(size) + 7) / 8;
  return new int64_t[num_words];
}
91 
/*!
 * \brief Release storage obtained from TVMWasmAllocSpace.
 * \param arr Pointer previously returned by TVMWasmAllocSpace; nullptr is a no-op.
 */
void TVMWasmFreeSpace(void* arr) {
  int64_t* words = static_cast<int64_t*>(arr);
  delete[] words;
}
93 
TVMWasmFuncCreateFromCFunc(void * resource_handle,TVMFunctionHandle * out)94 int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
95   return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer,
96                                 out);
97 }
98 
99 namespace tvm {
100 namespace runtime {
101 
102 // A special local session that can interact with async
103 // functions in the JS runtime.
104 class AsyncLocalSession : public LocalSession {
105  public:
AsyncLocalSession()106   AsyncLocalSession() {}
107 
GetFunction(const std::string & name)108   PackedFuncHandle GetFunction(const std::string& name) final {
109     if (name == "runtime.RPCTimeEvaluator") {
110       return get_time_eval_placeholder_.get();
111     } else if (auto* fp = tvm::runtime::Registry::Get(name)) {
112       // return raw handle because the remote need to explicitly manage it.
113       return new PackedFunc(*fp);
114     } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) {
115       auto* rptr = new PackedFunc(*fp);
116       async_func_set_.insert(rptr);
117       return rptr;
118     } else {
119       return nullptr;
120     }
121   }
122 
FreeHandle(void * handle,int type_code)123   void FreeHandle(void* handle, int type_code) final {
124     if (type_code == kTVMPackedFuncHandle) {
125       auto it = async_func_set_.find(handle);
126       if (it != async_func_set_.end()) {
127         async_func_set_.erase(it);
128       }
129     }
130     if (handle != get_time_eval_placeholder_.get()) {
131       LocalSession::FreeHandle(handle, type_code);
132     }
133   }
134 
AsyncCallFunc(PackedFuncHandle func,const TVMValue * arg_values,const int * arg_type_codes,int num_args,FAsyncCallback callback)135   void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
136                      int num_args, FAsyncCallback callback) final {
137     auto it = async_func_set_.find(func);
138     if (it != async_func_set_.end()) {
139       PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) {
140         int code = args[0];
141         TVMRetValue rv;
142         rv = args[1];
143         this->EncodeReturn(std::move(rv),
144                            [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); });
145       });
146 
147       TVMRetValue temp;
148       std::vector<TVMValue> values(arg_values, arg_values + num_args);
149       std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args);
150       values.emplace_back(TVMValue());
151       type_codes.emplace_back(0);
152 
153       TVMArgsSetter setter(&values[0], &type_codes[0]);
154       // pass the callback as the last argument.
155       setter(num_args, packed_callback);
156 
157       auto* pf = static_cast<PackedFunc*>(func);
158       pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp);
159     } else if (func == get_time_eval_placeholder_.get()) {
160       // special handle time evaluator.
161       try {
162         TVMArgs args(arg_values, arg_type_codes, num_args);
163         PackedFunc retfunc =
164             this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
165         TVMRetValue rv;
166         rv = retfunc;
167         this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
168           // mark as async.
169           async_func_set_.insert(encoded_args.values[1].v_handle);
170           callback(RPCCode::kReturn, encoded_args);
171         });
172       } catch (const std::runtime_error& e) {
173         this->SendException(callback, e.what());
174       }
175     } else {
176       LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback);
177     }
178   }
179 
AsyncCopyToRemote(void * local_from,size_t local_from_offset,void * remote_to,size_t remote_to_offset,size_t nbytes,TVMContext remote_ctx_to,DLDataType type_hint,FAsyncCallback on_complete)180   void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
181                          size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
182                          DLDataType type_hint, FAsyncCallback on_complete) final {
183     TVMContext cpu_ctx;
184     cpu_ctx.device_type = kDLCPU;
185     cpu_ctx.device_id = 0;
186     try {
187       this->GetDeviceAPI(remote_ctx_to)
188           ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes,
189                            cpu_ctx, remote_ctx_to, type_hint, nullptr);
190       this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete);
191     } catch (const std::runtime_error& e) {
192       this->SendException(on_complete, e.what());
193     }
194   }
195 
AsyncCopyFromRemote(void * remote_from,size_t remote_from_offset,void * local_to,size_t local_to_offset,size_t nbytes,TVMContext remote_ctx_from,DLDataType type_hint,FAsyncCallback on_complete)196   void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
197                            size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from,
198                            DLDataType type_hint, FAsyncCallback on_complete) final {
199     TVMContext cpu_ctx;
200     cpu_ctx.device_type = kDLCPU;
201     cpu_ctx.device_id = 0;
202     try {
203       this->GetDeviceAPI(remote_ctx_from)
204           ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes,
205                            remote_ctx_from, cpu_ctx, type_hint, nullptr);
206       this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete);
207     } catch (const std::runtime_error& e) {
208       this->SendException(on_complete, e.what());
209     }
210   }
211 
AsyncStreamWait(TVMContext ctx,TVMStreamHandle stream,FAsyncCallback on_complete)212   void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final {
213     if (ctx.device_type == kDLCPU) {
214       TVMValue value;
215       int32_t tcode = kTVMNullptr;
216       value.v_handle = nullptr;
217       on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
218     } else {
219       CHECK(ctx.device_type == static_cast<DLDeviceType>(kDLWebGPU));
220       if (async_wait_ == nullptr) {
221         async_wait_ = tvm::runtime::Registry::Get("__async.wasm.WebGPUWaitForTasks");
222       }
223       CHECK(async_wait_ != nullptr);
224       PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) {
225         int code = args[0];
226         on_complete(static_cast<RPCCode>(code),
227                     TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1));
228       });
229       (*async_wait_)(packed_callback);
230     }
231   }
232 
IsAsync() const233   bool IsAsync() const final { return true; }
234 
235  private:
236   std::unordered_set<void*> async_func_set_;
237   std::unique_ptr<PackedFunc> get_time_eval_placeholder_ = std::make_unique<PackedFunc>();
238   const PackedFunc* async_wait_{nullptr};
239 
240   // time evaluator
GetTimeEvaluator(Optional<Module> opt_mod,std::string name,int device_type,int device_id,int number,int repeat,int min_repeat_ms)241   PackedFunc GetTimeEvaluator(Optional<Module> opt_mod, std::string name, int device_type,
242                               int device_id, int number, int repeat, int min_repeat_ms) {
243     TVMContext ctx;
244     ctx.device_type = static_cast<DLDeviceType>(device_type);
245     ctx.device_id = device_id;
246 
247     if (opt_mod.defined()) {
248       Module m = opt_mod.value();
249       std::string tkey = m->type_key();
250       return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
251     } else {
252       auto* pf = runtime::Registry::Get(name);
253       CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
254       return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms);
255     }
256   }
257 
258   // time evaluator
WrapWasmTimeEvaluator(PackedFunc pf,TVMContext ctx,int number,int repeat,int min_repeat_ms)259   PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
260                                    int min_repeat_ms) {
261     auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) {
262       // the function is a async function.
263       PackedFunc on_complete = args[args.size() - 1];
264       // keep argument alive in finvoke so that they
265       // can be used throughout the async benchmark
266       std::vector<TVMValue> values(args.values, args.values + args.size() - 1);
267       std::vector<int> type_codes(args.type_codes, args.type_codes + args.size() - 1);
268 
269       auto finvoke = [pf, values, type_codes](int n) {
270         TVMRetValue temp;
271         TVMArgs invoke_args(values.data(), type_codes.data(), values.size());
272         for (int i = 0; i < n; ++i) {
273           pf.CallPacked(invoke_args, &temp);
274         }
275       };
276       auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution");
277       CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function";
278       (*time_exec)(TypedPackedFunc<void(int)>(finvoke), ctx, number, repeat, min_repeat_ms,
279                    on_complete);
280     };
281     return PackedFunc(ftimer);
282   }
283 };
284 
__anonbfc248580702() 285 TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
286   return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
287 });
288 
289 }  // namespace runtime
290 }  // namespace tvm
291