1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*
21 * \file tvmjs_support.cc
22 * \brief Support functions to be linked with wasm_runtime to provide
23 * PackedFunc callbacks in tvmjs.
24 * We do not need to link this file in standalone wasm.
25 */
26
27 // configurations for the dmlc log.
28 #define DMLC_LOG_CUSTOMIZE 0
29 #define DMLC_LOG_STACK_TRACE 0
30 #define DMLC_LOG_DEBUG 0
31 #define DMLC_LOG_NODATE 1
32 #define DMLC_LOG_FATAL_THROW 0
33
34 #include <tvm/runtime/c_runtime_api.h>
35 #include <tvm/runtime/container.h>
36 #include <tvm/runtime/device_api.h>
37 #include <tvm/runtime/packed_func.h>
38 #include <tvm/runtime/registry.h>
39
40 #include "../../src/runtime/rpc/rpc_local_session.h"
41
42 extern "C" {
43 // --- Additional C API for the Wasm runtime ---
44 /*!
45 * \brief Allocate space aligned to 64 bit.
46 * \param size The size of the space.
47 * \return The allocated space.
48 */
49 TVM_DLL void* TVMWasmAllocSpace(int size);
50
51 /*!
52 * \brief Free the space allocated by TVMWasmAllocSpace.
53 * \param data The data pointer.
54 */
55 TVM_DLL void TVMWasmFreeSpace(void* data);
56
57 /*!
58 * \brief Create PackedFunc from a resource handle.
59 * \param resource_handle The handle to the resource.
60 * \param out The output PackedFunc.
61 * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
 * \return 0 if success.
63 */
64 TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out);
65
66 // --- APIs to be implemented by the frontend. ---
67 /*!
68 * \brief Wasm frontend packed function caller.
69 *
70 * \param args The arguments
71 * \param type_codes The type codes of the arguments
72 * \param num_args Number of arguments.
73 * \param ret The return value handle.
 * \param resource_handle The additional resource handle from the front-end.
75 * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
76 */
77 extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
78 void* resource_handle);
79
80 /*!
81 * \brief Wasm frontend resource finalizer.
82 * \param resource_handle The pointer to the external resource.
83 */
84 extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
85 } // extern "C"
86
/*!
 * \brief Allocate a buffer of at least \p size bytes, backed by an int64_t array.
 *
 * The byte count is rounded up to a whole number of 8-byte words. The returned
 * pointer must be released with TVMWasmFreeSpace, which deletes it as an
 * int64_t array.
 *
 * \param size The requested size in bytes.
 * \return Pointer to the newly allocated storage.
 */
void* TVMWasmAllocSpace(int size) {
  // Widen before rounding up: `size + 7` evaluated in `int` overflows
  // (undefined behaviour) when size is within 7 of INT_MAX.
  const int64_t num_words = (static_cast<int64_t>(size) + 7) / 8;
  return new int64_t[num_words];
}
91
/*!
 * \brief Release a buffer previously obtained from TVMWasmAllocSpace.
 *
 * The pointer is deleted as an int64_t array, matching how
 * TVMWasmAllocSpace allocates it.
 *
 * \param arr The buffer to release.
 */
void TVMWasmFreeSpace(void* arr) {
  int64_t* words = static_cast<int64_t*>(arr);
  delete[] words;
}
93
TVMWasmFuncCreateFromCFunc(void * resource_handle,TVMFunctionHandle * out)94 int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
95 return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer,
96 out);
97 }
98
99 namespace tvm {
100 namespace runtime {
101
102 // A special local session that can interact with async
103 // functions in the JS runtime.
104 class AsyncLocalSession : public LocalSession {
105 public:
AsyncLocalSession()106 AsyncLocalSession() {}
107
GetFunction(const std::string & name)108 PackedFuncHandle GetFunction(const std::string& name) final {
109 if (name == "runtime.RPCTimeEvaluator") {
110 return get_time_eval_placeholder_.get();
111 } else if (auto* fp = tvm::runtime::Registry::Get(name)) {
112 // return raw handle because the remote need to explicitly manage it.
113 return new PackedFunc(*fp);
114 } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) {
115 auto* rptr = new PackedFunc(*fp);
116 async_func_set_.insert(rptr);
117 return rptr;
118 } else {
119 return nullptr;
120 }
121 }
122
FreeHandle(void * handle,int type_code)123 void FreeHandle(void* handle, int type_code) final {
124 if (type_code == kTVMPackedFuncHandle) {
125 auto it = async_func_set_.find(handle);
126 if (it != async_func_set_.end()) {
127 async_func_set_.erase(it);
128 }
129 }
130 if (handle != get_time_eval_placeholder_.get()) {
131 LocalSession::FreeHandle(handle, type_code);
132 }
133 }
134
AsyncCallFunc(PackedFuncHandle func,const TVMValue * arg_values,const int * arg_type_codes,int num_args,FAsyncCallback callback)135 void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
136 int num_args, FAsyncCallback callback) final {
137 auto it = async_func_set_.find(func);
138 if (it != async_func_set_.end()) {
139 PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) {
140 int code = args[0];
141 TVMRetValue rv;
142 rv = args[1];
143 this->EncodeReturn(std::move(rv),
144 [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); });
145 });
146
147 TVMRetValue temp;
148 std::vector<TVMValue> values(arg_values, arg_values + num_args);
149 std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args);
150 values.emplace_back(TVMValue());
151 type_codes.emplace_back(0);
152
153 TVMArgsSetter setter(&values[0], &type_codes[0]);
154 // pass the callback as the last argument.
155 setter(num_args, packed_callback);
156
157 auto* pf = static_cast<PackedFunc*>(func);
158 pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp);
159 } else if (func == get_time_eval_placeholder_.get()) {
160 // special handle time evaluator.
161 try {
162 TVMArgs args(arg_values, arg_type_codes, num_args);
163 PackedFunc retfunc =
164 this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
165 TVMRetValue rv;
166 rv = retfunc;
167 this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
168 // mark as async.
169 async_func_set_.insert(encoded_args.values[1].v_handle);
170 callback(RPCCode::kReturn, encoded_args);
171 });
172 } catch (const std::runtime_error& e) {
173 this->SendException(callback, e.what());
174 }
175 } else {
176 LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback);
177 }
178 }
179
AsyncCopyToRemote(void * local_from,size_t local_from_offset,void * remote_to,size_t remote_to_offset,size_t nbytes,TVMContext remote_ctx_to,DLDataType type_hint,FAsyncCallback on_complete)180 void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
181 size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
182 DLDataType type_hint, FAsyncCallback on_complete) final {
183 TVMContext cpu_ctx;
184 cpu_ctx.device_type = kDLCPU;
185 cpu_ctx.device_id = 0;
186 try {
187 this->GetDeviceAPI(remote_ctx_to)
188 ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes,
189 cpu_ctx, remote_ctx_to, type_hint, nullptr);
190 this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete);
191 } catch (const std::runtime_error& e) {
192 this->SendException(on_complete, e.what());
193 }
194 }
195
AsyncCopyFromRemote(void * remote_from,size_t remote_from_offset,void * local_to,size_t local_to_offset,size_t nbytes,TVMContext remote_ctx_from,DLDataType type_hint,FAsyncCallback on_complete)196 void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
197 size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from,
198 DLDataType type_hint, FAsyncCallback on_complete) final {
199 TVMContext cpu_ctx;
200 cpu_ctx.device_type = kDLCPU;
201 cpu_ctx.device_id = 0;
202 try {
203 this->GetDeviceAPI(remote_ctx_from)
204 ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes,
205 remote_ctx_from, cpu_ctx, type_hint, nullptr);
206 this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete);
207 } catch (const std::runtime_error& e) {
208 this->SendException(on_complete, e.what());
209 }
210 }
211
AsyncStreamWait(TVMContext ctx,TVMStreamHandle stream,FAsyncCallback on_complete)212 void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final {
213 if (ctx.device_type == kDLCPU) {
214 TVMValue value;
215 int32_t tcode = kTVMNullptr;
216 value.v_handle = nullptr;
217 on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
218 } else {
219 CHECK(ctx.device_type == static_cast<DLDeviceType>(kDLWebGPU));
220 if (async_wait_ == nullptr) {
221 async_wait_ = tvm::runtime::Registry::Get("__async.wasm.WebGPUWaitForTasks");
222 }
223 CHECK(async_wait_ != nullptr);
224 PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) {
225 int code = args[0];
226 on_complete(static_cast<RPCCode>(code),
227 TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1));
228 });
229 (*async_wait_)(packed_callback);
230 }
231 }
232
IsAsync() const233 bool IsAsync() const final { return true; }
234
235 private:
236 std::unordered_set<void*> async_func_set_;
237 std::unique_ptr<PackedFunc> get_time_eval_placeholder_ = std::make_unique<PackedFunc>();
238 const PackedFunc* async_wait_{nullptr};
239
240 // time evaluator
GetTimeEvaluator(Optional<Module> opt_mod,std::string name,int device_type,int device_id,int number,int repeat,int min_repeat_ms)241 PackedFunc GetTimeEvaluator(Optional<Module> opt_mod, std::string name, int device_type,
242 int device_id, int number, int repeat, int min_repeat_ms) {
243 TVMContext ctx;
244 ctx.device_type = static_cast<DLDeviceType>(device_type);
245 ctx.device_id = device_id;
246
247 if (opt_mod.defined()) {
248 Module m = opt_mod.value();
249 std::string tkey = m->type_key();
250 return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
251 } else {
252 auto* pf = runtime::Registry::Get(name);
253 CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
254 return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms);
255 }
256 }
257
258 // time evaluator
WrapWasmTimeEvaluator(PackedFunc pf,TVMContext ctx,int number,int repeat,int min_repeat_ms)259 PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
260 int min_repeat_ms) {
261 auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) {
262 // the function is a async function.
263 PackedFunc on_complete = args[args.size() - 1];
264 // keep argument alive in finvoke so that they
265 // can be used throughout the async benchmark
266 std::vector<TVMValue> values(args.values, args.values + args.size() - 1);
267 std::vector<int> type_codes(args.type_codes, args.type_codes + args.size() - 1);
268
269 auto finvoke = [pf, values, type_codes](int n) {
270 TVMRetValue temp;
271 TVMArgs invoke_args(values.data(), type_codes.data(), values.size());
272 for (int i = 0; i < n; ++i) {
273 pf.CallPacked(invoke_args, &temp);
274 }
275 };
276 auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution");
277 CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function";
278 (*time_exec)(TypedPackedFunc<void(int)>(finvoke), ctx, number, repeat, min_repeat_ms,
279 on_complete);
280 };
281 return PackedFunc(ftimer);
282 }
283 };
284
__anonbfc248580702() 285 TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
286 return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
287 });
288
289 } // namespace runtime
290 } // namespace tvm
291