1 #include "WasmExecutor.h"
2
3 #include "CodeGen_WebAssembly.h"
4 #include "Error.h"
5 #include "Float16.h"
6 #include "Func.h"
7 #include "ImageParam.h"
8 #include "JITModule.h"
9 #if WITH_WABT
10 #include "LLVM_Headers.h"
11 #endif
12 #include "LLVM_Output.h"
13 #include "LLVM_Runtime_Linker.h"
14 #include "Target.h"
15
16 #include <cmath>
17 #include <csignal>
18 #include <cstdlib>
19 #include <mutex>
20 #include <sstream>
21 #include <unordered_map>
22 #include <vector>
23
24 #if WITH_WABT
25 #include "wabt-src/src/binary-reader.h"
26 #include "wabt-src/src/error-formatter.h"
27 #include "wabt-src/src/feature.h"
28 #include "wabt-src/src/interp/binary-reader-interp.h"
29 #include "wabt-src/src/interp/interp-util.h"
30 #include "wabt-src/src/interp/interp-wasi.h"
31 #include "wabt-src/src/interp/interp.h"
32 #include "wabt-src/src/option-parser.h"
33 #include "wabt-src/src/stream.h"
34 #endif
35
36 namespace Halide {
37 namespace Internal {
38
// Suffix appended to a function name to form the name of its trampoline.
// "_argv" is deliberately NOT used: a function with that name may already
// exist, and such a function returns an int instead of taking a trailing
// pointer argument that receives the result value.
static constexpr char kTrampolineSuffix[] = "_trampoline";
44
45 #if WITH_WABT
46
47 namespace {
48
49 // ---------------------
50 // General debug helpers
51 // ---------------------
52
// Debugging the WebAssembly JIT support is usually disconnected from the rest of
// HL_DEBUG_CODEGEN, so it has its own compile-time verbosity knob (0 = all off).
#define WASM_DEBUG_LEVEL 0

// A stand-in output stream that discards everything streamed into it, so that
// debug statements cost nothing when debugging is compiled out.
struct debug_sink {
    debug_sink() = default;

    template<typename T>
    inline debug_sink &operator<<(T && /* discarded */) {
        return *this;
    }
};

#if WASM_DEBUG_LEVEL
// Verbosity x maps to Halide::Internal::debug(0) when x <= WASM_DEBUG_LEVEL,
// otherwise to debug(255) (presumably suppressed; see Halide's Debug.h).
#define wdebug(x) Halide::Internal::debug(((x) <= WASM_DEBUG_LEVEL) ? 0 : 255)
#define wassert(x) internal_assert(x)
#else
// Compiled out: both macros expand to a no-op sink and never evaluate x.
#define wdebug(x) debug_sink()
#define wassert(x) debug_sink()
#endif
72
73 // ---------------------
74 // BDMalloc
75 // ---------------------
76
// Round p up to the next multiple of `alignment`, which must be a power of two.
template<typename T>
inline T align_up(T p, int alignment = 32) {
    const T low_bits = static_cast<T>(alignment - 1);
    return (p + low_bits) & ~low_bits;
}
81
82 // Debugging our Malloc is extremely noisy and usually undesired
83
84 #define BDMALLOC_DEBUG_LEVEL 0
85 #if BDMALLOC_DEBUG_LEVEL
86 #define bddebug(x) Halide::Internal::debug(((x) <= BDMALLOC_DEBUG_LEVEL) ? 0 : 255)
87 #else
88 #define bddebug(x) debug_sink()
89 #endif
90
91 // BDMalloc aka BrainDeadMalloc. This is an *extremely* simple-minded implementation
92 // of malloc/free on top of a WasmMemoryObject, and is intended to be just barely adequate
93 // to allow Halide's JIT-based tests to pass. It is neither memory-efficient nor performant,
94 // nor has it been particularly well-vetted for potential buffer overruns and such.
95 class BDMalloc {
96 struct Region {
97 uint32_t size : 31;
98 uint32_t used : 1;
99 };
100
101 uint32_t total_size = 0;
102 std::map<uint32_t, Region> regions;
103
104 public:
105 BDMalloc() = default;
106
get_total_size() const107 uint32_t get_total_size() const {
108 return total_size;
109 }
110
inited() const111 bool inited() const {
112 return total_size > 0;
113 }
114
init(uint32_t total_size,uint32_t heap_start=1)115 void init(uint32_t total_size, uint32_t heap_start = 1) {
116 this->total_size = total_size;
117 regions.clear();
118
119 internal_assert(heap_start < total_size);
120 // Area before heap_start is permanently off-limits
121 regions[0] = {heap_start, true};
122 // Everything else is free
123 regions[heap_start] = {total_size - heap_start, false};
124 }
125
reset()126 void reset() {
127 this->total_size = 0;
128 regions.clear();
129 }
130
alloc_region(uint32_t requested_size)131 uint32_t alloc_region(uint32_t requested_size) {
132 internal_assert(requested_size > 0);
133
134 bddebug(1) << "begin alloc_region " << requested_size << "\n";
135 validate();
136
137 // TODO: this would be faster with a basic free list,
138 // but for most Halide test code, there aren't enough allocations
139 // for this to be worthwhile, or at least that's my observation from
140 // a first run-thru. Consider adding it if any of the JIT-based tests
141 // seem unreasonably slow; as it is, a linear search for the first free
142 // block of adequate size is (apparently) performant enough.
143
144 // alignment and min-block-size are the same for our purposes here.
145 constexpr uint32_t kAlignment = 32;
146 const uint32_t size = std::max(align_up((uint32_t)requested_size, kAlignment), kAlignment);
147
148 constexpr uint32_t kMaxAllocSize = 0x7fffffff;
149 internal_assert(size <= kMaxAllocSize);
150 bddebug(2) << "size -> " << size << "\n";
151
152 for (auto it = regions.begin(); it != regions.end(); it++) {
153 const uint32_t start = it->first;
154 Region &r = it->second;
155 if (!r.used && r.size >= size) {
156 bddebug(2) << "alloc @ " << start << "," << (uint32_t)r.size << "\n";
157 if (r.size > size + kAlignment) {
158 // Split the block
159 const uint32_t r2_start = start + size;
160 const uint32_t r2_size = r.size - size;
161 regions[r2_start] = {r2_size, false};
162 r.size = size;
163 bddebug(2) << "split: r-> " << start << "," << (uint32_t)r.size << "," << (start + r.size) << "\n";
164 bddebug(2) << "split: r2-> " << r2_start << "," << r2_size << "," << (r2_start + r2_size) << "\n";
165 }
166 // Just return the block
167 r.used = true;
168 bddebug(1) << "end alloc_region " << requested_size << "\n";
169 validate();
170 return start;
171 }
172 }
173 bddebug(1) << "fail alloc_region " << requested_size << "\n";
174 validate();
175 return 0;
176 }
177
free_region(uint32_t start)178 void free_region(uint32_t start) {
179 bddebug(1) << "begin free_region " << start << "\n";
180 validate();
181
182 // Can't free region at zero
183 if (!start) {
184 return;
185 }
186
187 internal_assert(start > 0);
188 auto it = regions.find(start);
189 internal_assert(it != regions.end());
190 internal_assert(it->second.used);
191 it->second.used = false;
192 // If prev region is free, combine with it
193 if (it != regions.begin()) {
194 auto prev = std::prev(it);
195 if (!prev->second.used) {
196 bddebug(2) << "combine prev: " << prev->first << " w/ " << it->first << "\n";
197 prev->second.size += it->second.size;
198 regions.erase(it);
199 it = prev;
200 }
201 }
202 // If next region is free, combine with it
203 auto next = std::next(it);
204 if (next != regions.end() && !next->second.used) {
205 bddebug(2) << "combine next: " << next->first << " w/ " << it->first << " "
206 << "\n";
207 it->second.size += next->second.size;
208 regions.erase(next);
209 }
210
211 bddebug(1) << "end free_region " << start << "\n";
212 validate();
213 }
214
grow_total_size(uint32_t new_total_size)215 void grow_total_size(uint32_t new_total_size) {
216 bddebug(1) << "begin grow_total_size " << new_total_size << "\n";
217 validate();
218
219 internal_assert(new_total_size > total_size);
220 auto it = regions.rbegin();
221 const uint32_t start = it->first;
222 Region &r = it->second;
223 uint32_t r_end = start + r.size;
224 internal_assert(r_end == total_size);
225 uint32_t delta = new_total_size - r_end;
226 if (r.used) {
227 // Add a free region after the last one
228 regions[r_end] = {delta, false};
229 } else {
230 // Just extend the last (free) region
231 r.size += delta;
232 }
233
234 // bookkeeping
235 total_size = new_total_size;
236
237 bddebug(1) << "end grow_total_size " << new_total_size << "\n";
238 validate();
239 }
240
validate() const241 void validate() const {
242 internal_assert(total_size > 0);
243 #if (BDMALLOC_DEBUG_LEVEL >= 1) || (WASM_DEBUG_LEVEL >= 1)
244 uint32_t prev_end = 0;
245 bool prev_used = false;
246 for (auto it : regions) {
247 const uint32_t start = it.first;
248 const Region &r = it.second;
249 bddebug(2) << "R: " << start << ".." << (start + r.size - 1) << "," << r.used << "\n";
250 wassert(start == prev_end) << "start " << start << " prev_end " << prev_end << "\n";
251 // it's OK to have two used regions in a row, but not two free ones
252 wassert(!(!prev_used && !r.used));
253 prev_end = start + r.size;
254 prev_used = r.used;
255 }
256 wassert(prev_end == total_size) << "prev_end " << prev_end << " total_size " << total_size << "\n";
257 bddebug(2) << "\n";
258 #endif
259 }
260 };
261
262 // ---------------------
263 // General Wasm helpers
264 // ---------------------
265
// A pointer inside the wasm32 guest, expressed as a byte offset into its linear memory.
using wasm32_ptr_t = int32_t;

// Guest-side stand-in for the JITUserContext pointer (the real host pointer
// can't be represented in 32 bits); 0 is treated as "no user context".
constexpr wasm32_ptr_t kMagicJitUserContextValue = -1;

// TODO: vector codegen can underread allocated buffers; we need to deliberately
// allocate extra and return a pointer partway in to avoid out-of-bounds access
// failures. https://github.com/halide/Halide/issues/3738
constexpr size_t kExtraMallocSlop = 32;
274
compile_to_wasm(const Module & module,const std::string & fn_name)275 std::vector<char> compile_to_wasm(const Module &module, const std::string &fn_name) {
276 static std::mutex link_lock;
277 std::lock_guard<std::mutex> lock(link_lock);
278
279 llvm::LLVMContext context;
280 std::unique_ptr<llvm::Module> fn_module;
281
282 // Default wasm stack size is ~64k, but schedules with lots of
283 // alloca usage (heavily inlined, or tracing enabled) can blow thru
284 // this, which crashes in amusing ways, so ask for extra stack space
285 // for the alloca usage.
286 size_t stack_size = 65536;
287 {
288 std::unique_ptr<CodeGen_WebAssembly> cg(new CodeGen_WebAssembly(module.target()));
289 cg->set_context(context);
290 fn_module = cg->compile(module);
291 stack_size += cg->get_requested_alloca_total();
292 }
293
294 stack_size = align_up(stack_size);
295 wdebug(1) << "Requesting stack size of " << stack_size << "\n";
296
297 std::unique_ptr<llvm::Module> llvm_module =
298 link_with_wasm_jit_runtime(&context, module.target(), std::move(fn_module));
299
300 llvm::SmallVector<char, 4096> object;
301 llvm::raw_svector_ostream object_stream(object);
302 compile_llvm_module_to_object(*llvm_module, object_stream);
303
304 // TODO: surely there's a better way that doesn't require spooling things
305 // out to temp files
306 TemporaryFile obj_file("", ".o");
307 write_entire_file(obj_file.pathname(), object.data(), object.size());
308 #if WASM_DEBUG_LEVEL
309 obj_file.detach();
310 wdebug(1) << "Dumping obj_file to " << obj_file.pathname() << "\n";
311 #endif
312
313 TemporaryFile wasm_output("", ".wasm");
314
315 std::string lld_arg_strs[] = {
316 "HalideJITLinker",
317 // For debugging purposes:
318 // "--verbose",
319 // "-error-limit=0",
320 // "--print-gc-sections",
321 "--export=__heap_base",
322 "--allow-undefined",
323 "-zstack-size=" + std::to_string(stack_size),
324 obj_file.pathname(),
325 "--entry=" + fn_name,
326 "-o",
327 wasm_output.pathname()};
328
329 constexpr int c = sizeof(lld_arg_strs) / sizeof(lld_arg_strs[0]);
330 const char *lld_args[c];
331 for (int i = 0; i < c; ++i) {
332 lld_args[i] = lld_arg_strs[i].c_str();
333 }
334
335 // lld will temporarily hijack the signal handlers to ensure that temp files get cleaned up,
336 // but rather than preserving custom handlers in place, it restores the default handlers.
337 // This conflicts with some of our testing infrastructure, which relies on a SIGABRT handler
338 // set at global-ctor time to stay set. Therefore we'll save and restore this ourselves.
339 // Note that we must restore it before using internal_error (and also on the non-error path).
340 auto old_abort_handler = std::signal(SIGABRT, SIG_DFL);
341
342 #if LLVM_VERSION >= 110
343 if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), llvm::errs())) {
344 std::signal(SIGABRT, old_abort_handler);
345 internal_error << "lld::wasm::link failed\n";
346 }
347 #elif LLVM_VERSION >= 100
348 std::string lld_errs_string;
349 llvm::raw_string_ostream lld_errs(lld_errs_string);
350
351 if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), llvm::errs())) {
352 std::signal(SIGABRT, old_abort_handler);
353 internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
354 }
355 #else
356 std::string lld_errs_string;
357 llvm::raw_string_ostream lld_errs(lld_errs_string);
358
359 if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, lld_errs)) {
360 std::signal(SIGABRT, old_abort_handler);
361 internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
362 }
363 #endif
364
365 std::signal(SIGABRT, old_abort_handler);
366
367 #if WASM_DEBUG_LEVEL
368 wasm_output.detach();
369 wdebug(1) << "Dumping linked wasm to " << wasm_output.pathname() << "\n";
370 #endif
371
372 return read_entire_file(wasm_output.pathname());
373 }
374
halide_type_code(halide_type_code_t code,int bits)375 inline constexpr int halide_type_code(halide_type_code_t code, int bits) {
376 return ((int)code) | (bits << 8);
377 }
378
379 // dynamic_type_dispatch is a utility for functors that want to be able
380 // to dynamically dispatch a halide_type_t to type-specialized code.
381 // To use it, a functor must be a *templated* class, e.g.
382 //
383 // template<typename T> class MyFunctor { int operator()(arg1, arg2...); };
384 //
385 // dynamic_type_dispatch() is called with a halide_type_t as the first argument,
386 // followed by the arguments to the Functor's operator():
387 //
388 // auto result = dynamic_type_dispatch<MyFunctor>(some_halide_type, arg1, arg2);
389 //
390 // Note that this means that the functor must be able to instantiate its
391 // operator() for all the Halide scalar types; it also means that all those
392 // variants *will* be instantiated (increasing code size), so this approach
393 // should only be used when strictly necessary.
394
395 // clang-format off
396 template<template<typename> class Functor, typename... Args>
dynamic_type_dispatch(const halide_type_t & type,Args &&...args)397 auto dynamic_type_dispatch(const halide_type_t &type, Args &&... args) -> decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...)) {
398
399 #define HANDLE_CASE(CODE, BITS, TYPE) \
400 case halide_type_code(CODE, BITS): \
401 return Functor<TYPE>()(std::forward<Args>(args)...);
402
403 switch (halide_type_code((halide_type_code_t)type.code, type.bits)) {
404 HANDLE_CASE(halide_type_bfloat, 16, bfloat16_t)
405 HANDLE_CASE(halide_type_float, 16, float16_t)
406 HANDLE_CASE(halide_type_float, 32, float)
407 HANDLE_CASE(halide_type_float, 64, double)
408 HANDLE_CASE(halide_type_int, 8, int8_t)
409 HANDLE_CASE(halide_type_int, 16, int16_t)
410 HANDLE_CASE(halide_type_int, 32, int32_t)
411 HANDLE_CASE(halide_type_int, 64, int64_t)
412 HANDLE_CASE(halide_type_uint, 1, bool)
413 HANDLE_CASE(halide_type_uint, 8, uint8_t)
414 HANDLE_CASE(halide_type_uint, 16, uint16_t)
415 HANDLE_CASE(halide_type_uint, 32, uint32_t)
416 HANDLE_CASE(halide_type_uint, 64, uint64_t)
417 HANDLE_CASE(halide_type_handle, 64, void *)
418 default:
419 internal_error;
420 using ReturnType = decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...));
421 return ReturnType();
422 }
423
424 #undef HANDLE_CASE
425 }
426 // clang-format on
427
to_string(const wabt::MemoryStream & m)428 std::string to_string(const wabt::MemoryStream &m) {
429 wabt::OutputBuffer &o = const_cast<wabt::MemoryStream *>(&m)->output_buffer();
430 return std::string((const char *)o.data.data(), o.data.size());
431 }
432
433 struct WabtContext {
434 JITUserContext *const jit_user_context;
435 wabt::interp::Memory &memory;
436 BDMalloc &bdmalloc;
437
WabtContextHalide::Internal::__anona20e47b00111::WabtContext438 explicit WabtContext(JITUserContext *jit_user_context, wabt::interp::Memory &memory, BDMalloc &bdmalloc)
439 : jit_user_context(jit_user_context), memory(memory), bdmalloc(bdmalloc) {
440 }
441
442 WabtContext(const WabtContext &) = delete;
443 WabtContext(WabtContext &&) = delete;
444 void operator=(const WabtContext &) = delete;
445 void operator=(WabtContext &&) = delete;
446 };
447
get_wabt_context(wabt::interp::Thread & thread)448 WabtContext &get_wabt_context(wabt::interp::Thread &thread) {
449 void *host_info = thread.host_info();
450 wassert(host_info);
451 return *(WabtContext *)host_info;
452 }
453
get_wasm_memory_base(WabtContext & wabt_context)454 uint8_t *get_wasm_memory_base(WabtContext &wabt_context) {
455 return wabt_context.memory.UnsafeData();
456 }
457
wabt_malloc(WabtContext & wabt_context,size_t size)458 wasm32_ptr_t wabt_malloc(WabtContext &wabt_context, size_t size) {
459 wasm32_ptr_t p = wabt_context.bdmalloc.alloc_region(size);
460 if (!p) {
461 constexpr int kWasmPageSize = 65536;
462 const int32_t pages_needed = (size + kWasmPageSize - 1) / 65536;
463 wdebug(1) << "attempting to grow by pages: " << pages_needed << "\n";
464
465 wabt::Result r = wabt_context.memory.Grow(pages_needed);
466 internal_assert(Succeeded(r)) << "Memory::Grow() failed";
467
468 wabt_context.bdmalloc.grow_total_size(wabt_context.memory.ByteSize());
469 p = wabt_context.bdmalloc.alloc_region(size);
470 }
471
472 wdebug(2) << "allocation of " << size << " at: " << p << "\n";
473 return p;
474 }
475
wabt_free(WabtContext & wabt_context,wasm32_ptr_t ptr)476 void wabt_free(WabtContext &wabt_context, wasm32_ptr_t ptr) {
477 wdebug(2) << "freeing ptr at: " << ptr << "\n";
478 wabt_context.bdmalloc.free_region(ptr);
479 }
480
481 // Some internal code can call halide_error(null, ...), so this needs to be resilient to that.
482 // Callers must expect null and not crash.
get_jit_user_context(WabtContext & wabt_context,const wabt::interp::Value & arg)483 JITUserContext *get_jit_user_context(WabtContext &wabt_context, const wabt::interp::Value &arg) {
484 int32_t ucon_magic = arg.Get<int32_t>();
485 if (ucon_magic == 0) {
486 return nullptr;
487 }
488 wassert(ucon_magic == kMagicJitUserContextValue);
489 JITUserContext *jit_user_context = wabt_context.jit_user_context;
490 wassert(jit_user_context);
491 return jit_user_context;
492 }
493
494 // -----------------------
495 // halide_buffer_t <-> wasm_halide_buffer_t helpers
496 // -----------------------
497
498 struct wasm_halide_buffer_t {
499 uint64_t device;
500 wasm32_ptr_t device_interface; // halide_device_interface_t*
501 wasm32_ptr_t host; // uint8_t*
502 uint64_t flags;
503 halide_type_t type;
504 int32_t dimensions;
505 wasm32_ptr_t dim; // halide_dimension_t*
506 wasm32_ptr_t padding; // always zero
507 };
508
dump_hostbuf(WabtContext & wabt_context,const halide_buffer_t * buf,const std::string & label)509 void dump_hostbuf(WabtContext &wabt_context, const halide_buffer_t *buf, const std::string &label) {
510 #if WASM_DEBUG_LEVEL >= 2
511 const halide_dimension_t *dim = buf->dim;
512 const uint8_t *host = buf->host;
513
514 wdebug(1) << label << " = " << (void *)buf << " = {\n";
515 wdebug(1) << " device = " << buf->device << "\n";
516 wdebug(1) << " device_interface = " << buf->device_interface << "\n";
517 wdebug(1) << " host = " << (void *)host << " = {\n";
518 if (host) {
519 wdebug(1) << " " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
520 }
521 wdebug(1) << " }\n";
522 wdebug(1) << " flags = " << buf->flags << "\n";
523 wdebug(1) << " type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
524 wdebug(1) << " dimensions = " << buf->dimensions << "\n";
525 wdebug(1) << " dim = " << (void *)buf->dim << " = {\n";
526 for (int i = 0; i < buf->dimensions; i++) {
527 const auto &d = dim[i];
528 wdebug(1) << " {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
529 }
530 wdebug(1) << " }\n";
531 wdebug(1) << " padding = " << buf->padding << "\n";
532 wdebug(1) << "}\n";
533 #endif
534 }
535
dump_wasmbuf(WabtContext & wabt_context,wasm32_ptr_t buf_ptr,const std::string & label)536 void dump_wasmbuf(WabtContext &wabt_context, wasm32_ptr_t buf_ptr, const std::string &label) {
537 #if WASM_DEBUG_LEVEL >= 2
538 wassert(buf_ptr);
539
540 uint8_t *base = get_wasm_memory_base(wabt_context);
541 wasm_halide_buffer_t *buf = (wasm_halide_buffer_t *)(base + buf_ptr);
542 halide_dimension_t *dim = buf->dim ? (halide_dimension_t *)(base + buf->dim) : nullptr;
543 uint8_t *host = buf->host ? (base + buf->host) : nullptr;
544
545 wdebug(1) << label << " = " << buf_ptr << " -> " << (void *)buf << " = {\n";
546 wdebug(1) << " device = " << buf->device << "\n";
547 wdebug(1) << " device_interface = " << buf->device_interface << "\n";
548 wdebug(1) << " host = " << buf->host << " -> " << (void *)host << " = {\n";
549 if (host) {
550 wdebug(1) << " " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
551 }
552 wdebug(1) << " }\n";
553 wdebug(1) << " flags = " << buf->flags << "\n";
554 wdebug(1) << " type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
555 wdebug(1) << " dimensions = " << buf->dimensions << "\n";
556 wdebug(1) << " dim = " << buf->dim << " -> " << (void *)dim << " = {\n";
557 for (int i = 0; i < buf->dimensions; i++) {
558 const auto &d = dim[i];
559 wdebug(1) << " {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
560 }
561 wdebug(1) << " }\n";
562 wdebug(1) << " padding = " << buf->padding << "\n";
563 wdebug(1) << "}\n";
564 #endif
565 }
566
567 static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
568 static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
569 static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");
570
571 // Given a halide_buffer_t on the host, allocate a wasm_halide_buffer_t in wasm
572 // memory space and copy all relevant data. The resulting buf is laid out in
573 // contiguous memory, and can be free with a single free().
hostbuf_to_wasmbuf(WabtContext & wabt_context,const halide_buffer_t * src)574 wasm32_ptr_t hostbuf_to_wasmbuf(WabtContext &wabt_context, const halide_buffer_t *src) {
575 static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
576 static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
577 static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");
578
579 wdebug(2) << "\nhostbuf_to_wasmbuf:\n";
580 if (!src) {
581 return 0;
582 }
583
584 dump_hostbuf(wabt_context, src, "src");
585
586 wassert(src->device == 0);
587 wassert(src->device_interface == nullptr);
588
589 // Assume our malloc() has everything 32-byte aligned,
590 // and insert enough padding for host to also be 32-byte aligned.
591 const size_t dims_size_in_bytes = sizeof(halide_dimension_t) * src->dimensions;
592 const size_t dims_offset = sizeof(wasm_halide_buffer_t);
593 const size_t mem_needed_base = sizeof(wasm_halide_buffer_t) + dims_size_in_bytes;
594 const size_t host_offset = align_up(mem_needed_base);
595 const size_t host_size_in_bytes = src->size_in_bytes();
596 const size_t mem_needed = host_offset + host_size_in_bytes;
597
598 const wasm32_ptr_t dst_ptr = wabt_malloc(wabt_context, mem_needed);
599 wassert(dst_ptr);
600
601 uint8_t *base = get_wasm_memory_base(wabt_context);
602
603 wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(base + dst_ptr);
604 dst->device = 0;
605 dst->device_interface = 0;
606 dst->host = src->host ? (dst_ptr + host_offset) : 0;
607 dst->flags = src->flags;
608 dst->type = src->type;
609 dst->dimensions = src->dimensions;
610 dst->dim = src->dimensions ? (dst_ptr + dims_offset) : 0;
611 dst->padding = 0;
612
613 if (src->dim) {
614 memcpy(base + dst->dim, src->dim, dims_size_in_bytes);
615 }
616 if (src->host) {
617 memcpy(base + dst->host, src->host, host_size_in_bytes);
618 }
619
620 dump_wasmbuf(wabt_context, dst_ptr, "dst");
621
622 return dst_ptr;
623 }
624
625 // Given a pointer to a wasm_halide_buffer_t in wasm memory space,
626 // allocate a Buffer<> on the host and copy all relevant data.
wasmbuf_to_hostbuf(WabtContext & wabt_context,wasm32_ptr_t src_ptr,Halide::Runtime::Buffer<> & dst)627 void wasmbuf_to_hostbuf(WabtContext &wabt_context, wasm32_ptr_t src_ptr, Halide::Runtime::Buffer<> &dst) {
628 wdebug(2) << "\nwasmbuf_to_hostbuf:\n";
629
630 dump_wasmbuf(wabt_context, src_ptr, "src");
631
632 wassert(src_ptr);
633
634 uint8_t *base = get_wasm_memory_base(wabt_context);
635
636 wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(base + src_ptr);
637
638 wassert(src->device == 0);
639 wassert(src->device_interface == 0);
640
641 halide_buffer_t dst_tmp;
642 dst_tmp.device = 0;
643 dst_tmp.device_interface = 0;
644 dst_tmp.host = nullptr; // src->host ? (base + src->host) : nullptr;
645 dst_tmp.flags = src->flags;
646 dst_tmp.type = src->type;
647 dst_tmp.dimensions = src->dimensions;
648 dst_tmp.dim = src->dim ? (halide_dimension_t *)(base + src->dim) : nullptr;
649 dst_tmp.padding = 0;
650
651 dump_hostbuf(wabt_context, &dst_tmp, "dst_tmp");
652
653 dst = Halide::Runtime::Buffer<>(dst_tmp);
654 if (src->host) {
655 // Don't use dst.copy(); it can tweak strides in ways that matter.
656 dst.allocate();
657 const size_t host_size_in_bytes = dst.raw_buffer()->size_in_bytes();
658 memcpy(dst.raw_buffer()->host, base + src->host, host_size_in_bytes);
659 }
660 dump_hostbuf(wabt_context, dst.raw_buffer(), "dst");
661 }
662
663 // Given a wasm_halide_buffer_t, copy possibly-changed data into a halide_buffer_t.
664 // Both buffers are asserted to match in type and dimensions.
copy_wasmbuf_to_existing_hostbuf(WabtContext & wabt_context,wasm32_ptr_t src_ptr,halide_buffer_t * dst)665 void copy_wasmbuf_to_existing_hostbuf(WabtContext &wabt_context, wasm32_ptr_t src_ptr, halide_buffer_t *dst) {
666 wassert(src_ptr && dst);
667
668 wdebug(2) << "\ncopy_wasmbuf_to_existing_hostbuf:\n";
669 dump_wasmbuf(wabt_context, src_ptr, "src");
670
671 uint8_t *base = get_wasm_memory_base(wabt_context);
672
673 wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(base + src_ptr);
674 wassert(src->device == 0);
675 wassert(src->device_interface == 0);
676 wassert(src->dimensions == dst->dimensions);
677 wassert(src->type == dst->type);
678
679 dump_hostbuf(wabt_context, dst, "dst_pre");
680
681 if (src->dimensions) {
682 memcpy(dst->dim, base + src->dim, sizeof(halide_dimension_t) * src->dimensions);
683 }
684 if (src->host) {
685 size_t host_size_in_bytes = dst->size_in_bytes();
686 memcpy(dst->host, base + src->host, host_size_in_bytes);
687 }
688
689 dst->device = 0;
690 dst->device_interface = 0;
691 dst->flags = src->flags;
692
693 dump_hostbuf(wabt_context, dst, "dst_post");
694 }
695
696 // Given a halide_buffer_t, copy possibly-changed data into a wasm_halide_buffer_t.
697 // Both buffers are asserted to match in type and dimensions.
copy_hostbuf_to_existing_wasmbuf(WabtContext & wabt_context,const halide_buffer_t * src,wasm32_ptr_t dst_ptr)698 void copy_hostbuf_to_existing_wasmbuf(WabtContext &wabt_context, const halide_buffer_t *src, wasm32_ptr_t dst_ptr) {
699 wassert(src && dst_ptr);
700
701 wdebug(1) << "\ncopy_hostbuf_to_existing_wasmbuf:\n";
702 dump_hostbuf(wabt_context, src, "src");
703
704 uint8_t *base = get_wasm_memory_base(wabt_context);
705
706 wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(base + dst_ptr);
707 wassert(src->device == 0);
708 wassert(src->device_interface == 0);
709 wassert(src->dimensions == dst->dimensions);
710 wassert(src->type == dst->type);
711
712 dump_wasmbuf(wabt_context, dst_ptr, "dst_pre");
713
714 if (src->dimensions) {
715 memcpy(base + dst->dim, src->dim, sizeof(halide_dimension_t) * src->dimensions);
716 }
717 if (src->host) {
718 size_t host_size_in_bytes = src->size_in_bytes();
719 memcpy(base + dst->host, src->host, host_size_in_bytes);
720 }
721
722 dst->device = 0;
723 dst->device_interface = 0;
724 dst->flags = src->flags;
725
726 dump_wasmbuf(wabt_context, dst_ptr, "dst_post");
727 }
728
729 // --------------------------------------------------
730 // Helpers for converting to/from wabt::interp::Value
731 // --------------------------------------------------
732
733 template<typename T>
734 struct LoadValue {
operator ()Halide::Internal::__anona20e47b00111::LoadValue735 inline wabt::interp::Value operator()(const void *src) {
736 const T val = *(const T *)(src);
737 return wabt::interp::Value::Make(val);
738 }
739 };
740
741 template<>
operator ()(const void * src)742 inline wabt::interp::Value LoadValue<bool>::operator()(const void *src) {
743 // WABT doesn't do bools. Stash as u8 for now.
744 const uint8_t val = *(const uint8_t *)src;
745 return wabt::interp::Value::Make(val);
746 }
747
748 template<>
operator ()(const void * src)749 inline wabt::interp::Value LoadValue<void *>::operator()(const void *src) {
750 // Halide 'handle' types are always uint64, even on 32-bit systems
751 const uint64_t val = *(const uint64_t *)src;
752 return wabt::interp::Value::Make(val);
753 }
754
755 template<>
operator ()(const void * src)756 inline wabt::interp::Value LoadValue<float16_t>::operator()(const void *src) {
757 const uint16_t val = *(const uint16_t *)src;
758 return wabt::interp::Value::Make(val);
759 }
760
761 template<>
operator ()(const void * src)762 inline wabt::interp::Value LoadValue<bfloat16_t>::operator()(const void *src) {
763 const uint16_t val = *(const uint16_t *)src;
764 return wabt::interp::Value::Make(val);
765 }
766
load_value(const Type & t,const void * src)767 inline wabt::interp::Value load_value(const Type &t, const void *src) {
768 return dynamic_type_dispatch<LoadValue>(t, src);
769 }
770
771 template<typename T>
load_value(const T & val)772 inline wabt::interp::Value load_value(const T &val) {
773 return LoadValue<T>()(&val);
774 }
775
776 // -----
777
778 template<typename T>
779 struct StoreValue {
operator ()Halide::Internal::__anona20e47b00111::StoreValue780 inline void operator()(const wabt::interp::Value &src, void *dst) {
781 *(T *)dst = src.Get<T>();
782 }
783 };
784
785 template<>
operator ()(const wabt::interp::Value & src,void * dst)786 inline void StoreValue<bool>::operator()(const wabt::interp::Value &src, void *dst) {
787 // WABT doesn't do bools. Stash as u8 for now.
788 *(uint8_t *)dst = src.Get<uint8_t>();
789 }
790
791 template<>
operator ()(const wabt::interp::Value & src,void * dst)792 inline void StoreValue<void *>::operator()(const wabt::interp::Value &src, void *dst) {
793 // Halide 'handle' types are always uint64, even on 32-bit systems
794 *(uint64_t *)dst = src.Get<uint64_t>();
795 }
796
797 template<>
operator ()(const wabt::interp::Value & src,void * dst)798 inline void StoreValue<float16_t>::operator()(const wabt::interp::Value &src, void *dst) {
799 *(uint16_t *)dst = src.Get<uint16_t>();
800 }
801
802 template<>
operator ()(const wabt::interp::Value & src,void * dst)803 inline void StoreValue<bfloat16_t>::operator()(const wabt::interp::Value &src, void *dst) {
804 *(uint16_t *)dst = src.Get<uint16_t>();
805 }
806
store_value(const Type & t,const wabt::interp::Value & src,void * dst)807 inline void store_value(const Type &t, const wabt::interp::Value &src, void *dst) {
808 dynamic_type_dispatch<StoreValue>(t, src, dst);
809 }
810
811 template<typename T>
store_value(const wabt::interp::Value & src,T * dst)812 inline void store_value(const wabt::interp::Value &src, T *dst) {
813 StoreValue<T>()(src, dst);
814 }
815
816 // --------------------------------------------------
817 // Host Callback Functions
818 // --------------------------------------------------
819
820 template<typename T, T some_func(T)>
wabt_posix_math_1(wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)821 wabt::Result wabt_posix_math_1(wabt::interp::Thread &thread,
822 const wabt::interp::Values &args,
823 wabt::interp::Values &results,
824 wabt::interp::Trap::Ptr *trap) {
825 wassert(args.size() == 1);
826 const T in = args[0].Get<T>();
827 const T out = some_func(in);
828 results[0] = wabt::interp::Value::Make(out);
829 return wabt::Result::Ok;
830 }
831
832 template<typename T, T some_func(T, T)>
wabt_posix_math_2(wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)833 wabt::Result wabt_posix_math_2(wabt::interp::Thread &thread,
834 const wabt::interp::Values &args,
835 wabt::interp::Values &results,
836 wabt::interp::Trap::Ptr *trap) {
837 wassert(args.size() == 2);
838 const T in1 = args[0].Get<T>();
839 const T in2 = args[1].Get<T>();
840 const T out = some_func(in1, in2);
841 results[0] = wabt::interp::Value::Make(out);
842 return wabt::Result::Ok;
843 }
844
// WABT_HOST_CALLBACK(x) starts the definition of `wabt_jit_<x>_callback`, a function
// with the (thread, args, results, trap) signature wabt expects for host functions.
#define WABT_HOST_CALLBACK(x)             \
    wabt::Result wabt_jit_##x##_callback( \
        wabt::interp::Thread &thread,     \
        const wabt::interp::Values &args, \
        wabt::interp::Values &results,    \
        wabt::interp::Trap::Ptr *trap)

// WABT_HOST_CALLBACK_UNIMPLEMENTED(x) defines `wabt_jit_<x>_callback` as a stub
// that reports an internal error naming `x` whenever it is invoked.
#define WABT_HOST_CALLBACK_UNIMPLEMENTED(x)                                          \
    WABT_HOST_CALLBACK(x) {                                                          \
        internal_error << "WebAssembly JIT does not yet support the " #x "() call."; \
        return wabt::Result::Ok;                                                     \
    }
856
// Accepts and ignores atexit registrations coming from the wasm module.
WABT_HOST_CALLBACK(__cxa_atexit) {
    return wabt::Result::Ok;
}
861
// Compiler-rt helper: widens a half-precision bit pattern (args[0]) to a float32 result.
WABT_HOST_CALLBACK(__extendhfsf2) {
    const float16_t half = float16_t::make_from_bits(args[0].Get<uint16_t>());
    results[0] = wabt::interp::Value::Make(static_cast<float>(half));
    return wabt::Result::Ok;
}
868
// Compiler-rt helper: narrows a float32 (args[0]) to half precision and returns its 16-bit pattern.
WABT_HOST_CALLBACK(__truncsfhf2) {
    const float16_t half(args[0].Get<float>());
    results[0] = wabt::interp::Value::Make(half.to_bits());
    return wabt::Result::Ok;
}
875
// abort() called from wasm code terminates the host process.
WABT_HOST_CALLBACK(abort) {
    std::abort();
    return wabt::Result::Ok;  // not reached
}
880
// libc file APIs that the WebAssembly JIT does not implement. Each stub is
// registered in the host callback table; calling it reports an internal error.
WABT_HOST_CALLBACK_UNIMPLEMENTED(fclose)

WABT_HOST_CALLBACK_UNIMPLEMENTED(fileno)

WABT_HOST_CALLBACK_UNIMPLEMENTED(fopen)
886
887 WABT_HOST_CALLBACK(free) {
888 WabtContext &wabt_context = get_wabt_context(thread);
889
890 wasm32_ptr_t p = args[0].Get<int32_t>();
891 if (p) p -= kExtraMallocSlop;
892 wabt_free(wabt_context, p);
893 return wabt::Result::Ok;
894 }
895
// fwrite() is not implemented by the WebAssembly JIT; calling it reports an internal error.
WABT_HOST_CALLBACK_UNIMPLEMENTED(fwrite)
897
// Host implementation of getenv(). Looks up the NUL-terminated variable name
// at wasm address args[0] in the host environment. If the variable is set, its
// value is copied into a fresh wasm heap allocation and that address is
// returned. Otherwise, or if the allocation fails, 0 is returned.
WABT_HOST_CALLBACK(getenv) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const int32_t s = args[0].Get<int32_t>();

    const char *e = getenv((const char *)get_wasm_memory_base(wabt_context) + s);

    wasm32_ptr_t r = 0;
    if (e) {
        // TODO: this string is leaked
        const size_t len = strlen(e) + 1;  // include the terminating NUL
        r = wabt_malloc(wabt_context, len);
        // Treat 0 as allocation failure (as the malloc callback does) and report
        // "unset" rather than writing through wasm address 0.
        if (r) {
            // Fetch the memory base only after allocating, in case wabt_malloc
            // grows (and thereby relocates) linear memory.
            // NOTE(review): confirm whether wabt_malloc can grow memory; re-fetching is harmless either way.
            uint8_t *base = get_wasm_memory_base(wabt_context);
            memcpy(base + r, e, len);
        }
    }
    results[0] = wabt::interp::Value::Make(r);
    return wabt::Result::Ok;
}
916
// Routes halide_print() from wasm code to the JITUserContext's custom_print
// handler when one is installed, and to std::cout otherwise.
WABT_HOST_CALLBACK(halide_print) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 2);

    JITUserContext *const ucon = get_jit_user_context(wabt_context, args[0]);
    const char *const message =
        reinterpret_cast<const char *>(get_wasm_memory_base(wabt_context)) + args[1].Get<int32_t>();

    const bool has_handler = ucon != nullptr && ucon->handlers.custom_print != nullptr;
    if (!has_handler) {
        std::cout << message;
        return wabt::Result::Ok;
    }
    (*ucon->handlers.custom_print)(ucon, message);
    return wabt::Result::Ok;
}
935
// Host side of halide_trace_helper(). Rebuilds a halide_trace_event_t from the
// 12 wasm arguments, translating wasm addresses into host pointers. The event
// is passed to the JITUserContext's custom_trace handler if there is one;
// otherwise it is dropped. results[0] receives the handler's return value, or 0.
WABT_HOST_CALLBACK(halide_trace_helper) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 12);

    uint8_t *base = get_wasm_memory_base(wabt_context);

    // Map a wasm32 address to a host pointer. The wasm null pointer (0) maps to
    // nullptr rather than to the first byte of linear memory. This covers every
    // pointer field (not just value/coordinates), so a null trace_tag or func
    // reaches the handler as nullptr, just as it would with the native JIT.
    const auto wasm_to_host = [base](wasm32_ptr_t addr) -> void * {
        return addr ? static_cast<void *>(base + addr) : nullptr;
    };

    JITUserContext *jit_user_context = get_jit_user_context(wabt_context, args[0]);

    const wasm32_ptr_t func_name_ptr = args[1].Get<int32_t>();
    const wasm32_ptr_t value_ptr = args[2].Get<int32_t>();
    const wasm32_ptr_t coordinates_ptr = args[3].Get<int32_t>();
    const int type_code = args[4].Get<int32_t>();
    const int type_bits = args[5].Get<int32_t>();
    const int type_lanes = args[6].Get<int32_t>();
    const int trace_code = args[7].Get<int32_t>();
    const int parent_id = args[8].Get<int32_t>();
    const int value_index = args[9].Get<int32_t>();
    const int dimensions = args[10].Get<int32_t>();
    const wasm32_ptr_t trace_tag_ptr = args[11].Get<int32_t>();

    wassert(dimensions >= 0 && dimensions < 1024);  // not a hard limit, just a sanity check

    halide_trace_event_t event = {};
    event.func = (const char *)wasm_to_host(func_name_ptr);
    event.value = wasm_to_host(value_ptr);
    event.coordinates = (int32_t *)wasm_to_host(coordinates_ptr);
    event.trace_tag = (const char *)wasm_to_host(trace_tag_ptr);
    event.type.code = (halide_type_code_t)type_code;
    event.type.bits = (uint8_t)type_bits;
    event.type.lanes = (uint16_t)type_lanes;
    event.event = (halide_trace_event_code_t)trace_code;
    event.parent_id = parent_id;
    event.value_index = value_index;
    event.dimensions = dimensions;

    int32_t result = 0;
    if (jit_user_context && jit_user_context->handlers.custom_trace != nullptr) {
        result = (*jit_user_context->handlers.custom_trace)(jit_user_context, &event);
    } else {
        debug(0) << "Dropping trace event due to lack of trace handler.\n";
    }

    results[0] = wabt::interp::Value::Make(result);
    return wabt::Result::Ok;
}
982
// Routes halide_error() from wasm code to the JITUserContext's custom_error
// handler when one is installed; otherwise raises a Halide runtime error.
WABT_HOST_CALLBACK(halide_error) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 2);

    JITUserContext *const ucon = get_jit_user_context(wabt_context, args[0]);
    const char *const message =
        reinterpret_cast<const char *>(get_wasm_memory_base(wabt_context)) + args[1].Get<int32_t>();

    if (ucon == nullptr || ucon->handlers.custom_error == nullptr) {
        halide_runtime_error << message;
    } else {
        (*ucon->handlers.custom_error)(ucon, message);
    }
    return wabt::Result::Ok;
}
1001
// Host implementation of malloc(). Reserves kExtraMallocSlop extra bytes in
// front of the block and returns the address just past them, which the free
// callback undoes. Returns 0 when the underlying allocation fails.
WABT_HOST_CALLBACK(malloc) {
    WabtContext &wabt_context = get_wabt_context(thread);

    // The wasm32 size_t argument is an unsigned 32-bit value. Read it as such and
    // widen it to the host size_t before adding the slop, so requests of 2 GiB or
    // more are neither sign-extended into a ~2^64 size nor wrapped around.
    const size_t size = static_cast<size_t>(args[0].Get<uint32_t>()) + kExtraMallocSlop;
    wasm32_ptr_t p = wabt_malloc(wabt_context, size);
    if (p) {
        p += kExtraMallocSlop;
    }
    results[0] = wabt::interp::Value::Make(p);
    return wabt::Result::Ok;
}
1011
// Host implementation of memcpy() within wasm linear memory; returns the destination address.
WABT_HOST_CALLBACK(memcpy) {
    WabtContext &wabt_context = get_wabt_context(thread);
    uint8_t *const mem = get_wasm_memory_base(wabt_context);

    const int32_t dst_addr = args[0].Get<int32_t>();
    const int32_t src_addr = args[1].Get<int32_t>();
    const int32_t count = args[2].Get<int32_t>();

    memcpy(mem + dst_addr, mem + src_addr, count);
    results[0] = wabt::interp::Value::Make(dst_addr);
    return wabt::Result::Ok;
}
1026
// Host implementation of memset() within wasm linear memory; returns the destination address.
WABT_HOST_CALLBACK(memset) {
    WabtContext &wabt_context = get_wabt_context(thread);
    uint8_t *const mem = get_wasm_memory_base(wabt_context);

    const int32_t dest_addr = args[0].Get<int32_t>();
    const int32_t fill = args[1].Get<int32_t>();
    const int32_t count = args[2].Get<int32_t>();

    memset(mem + dest_addr, fill, count);
    results[0] = wabt::interp::Value::Make(dest_addr);
    return wabt::Result::Ok;
}
1040
// Host implementation of memcmp() over two ranges of wasm linear memory.
WABT_HOST_CALLBACK(memcmp) {
    WabtContext &wabt_context = get_wabt_context(thread);
    uint8_t *const mem = get_wasm_memory_base(wabt_context);

    const int32_t lhs_addr = args[0].Get<int32_t>();
    const int32_t rhs_addr = args[1].Get<int32_t>();
    const int32_t count = args[2].Get<int32_t>();

    const int32_t cmp = memcmp(mem + lhs_addr, mem + rhs_addr, count);
    results[0] = wabt::interp::Value::Make(cmp);
    return wabt::Result::Ok;
}
1055
// Host implementation of strlen() for a NUL-terminated string in wasm linear memory.
WABT_HOST_CALLBACK(strlen) {
    WabtContext &wabt_context = get_wabt_context(thread);
    const int32_t str_addr = args[0].Get<int32_t>();

    const char *const str = reinterpret_cast<char *>(get_wasm_memory_base(wabt_context)) + str_addr;
    const int32_t len = static_cast<int32_t>(strlen(str));

    results[0] = wabt::interp::Value::Make(len);
    return wabt::Result::Ok;
}
1066
// write() is not implemented by the WebAssembly JIT; calling it reports an internal error.
WABT_HOST_CALLBACK_UNIMPLEMENTED(write)
1068
1069 using HostCallbackMap = std::unordered_map<std::string, wabt::interp::HostFunc::Callback>;
1070
1071 // clang-format off
get_host_callback_map()1072 const HostCallbackMap &get_host_callback_map() {
1073
1074 static HostCallbackMap m = {
1075 // General runtime functions.
1076
1077 #define DEFINE_CALLBACK(f) { #f, wabt_jit_##f##_callback },
1078
1079 DEFINE_CALLBACK(__cxa_atexit)
1080 DEFINE_CALLBACK(__extendhfsf2)
1081 DEFINE_CALLBACK(__truncsfhf2)
1082 DEFINE_CALLBACK(abort)
1083 DEFINE_CALLBACK(fclose)
1084 DEFINE_CALLBACK(fileno)
1085 DEFINE_CALLBACK(fopen)
1086 DEFINE_CALLBACK(free)
1087 DEFINE_CALLBACK(fwrite)
1088 DEFINE_CALLBACK(getenv)
1089 DEFINE_CALLBACK(halide_error)
1090 DEFINE_CALLBACK(halide_print)
1091 DEFINE_CALLBACK(halide_trace_helper)
1092 DEFINE_CALLBACK(malloc)
1093 DEFINE_CALLBACK(memcmp)
1094 DEFINE_CALLBACK(memcpy)
1095 DEFINE_CALLBACK(memset)
1096 DEFINE_CALLBACK(strlen)
1097 DEFINE_CALLBACK(write)
1098
1099 #undef DEFINE_CALLBACK
1100
1101 // Posix math.
1102 #define DEFINE_POSIX_MATH_CALLBACK(t, f) { #f, wabt_posix_math_1<t, ::f> },
1103
1104 DEFINE_POSIX_MATH_CALLBACK(double, acos)
1105 DEFINE_POSIX_MATH_CALLBACK(double, acosh)
1106 DEFINE_POSIX_MATH_CALLBACK(double, asin)
1107 DEFINE_POSIX_MATH_CALLBACK(double, asinh)
1108 DEFINE_POSIX_MATH_CALLBACK(double, atan)
1109 DEFINE_POSIX_MATH_CALLBACK(double, atanh)
1110 DEFINE_POSIX_MATH_CALLBACK(double, cos)
1111 DEFINE_POSIX_MATH_CALLBACK(double, cosh)
1112 DEFINE_POSIX_MATH_CALLBACK(double, exp)
1113 DEFINE_POSIX_MATH_CALLBACK(double, log)
1114 DEFINE_POSIX_MATH_CALLBACK(double, round)
1115 DEFINE_POSIX_MATH_CALLBACK(double, sin)
1116 DEFINE_POSIX_MATH_CALLBACK(double, sinh)
1117 DEFINE_POSIX_MATH_CALLBACK(double, tan)
1118 DEFINE_POSIX_MATH_CALLBACK(double, tanh)
1119
1120 DEFINE_POSIX_MATH_CALLBACK(float, acosf)
1121 DEFINE_POSIX_MATH_CALLBACK(float, acoshf)
1122 DEFINE_POSIX_MATH_CALLBACK(float, asinf)
1123 DEFINE_POSIX_MATH_CALLBACK(float, asinhf)
1124 DEFINE_POSIX_MATH_CALLBACK(float, atanf)
1125 DEFINE_POSIX_MATH_CALLBACK(float, atanhf)
1126 DEFINE_POSIX_MATH_CALLBACK(float, cosf)
1127 DEFINE_POSIX_MATH_CALLBACK(float, coshf)
1128 DEFINE_POSIX_MATH_CALLBACK(float, expf)
1129 DEFINE_POSIX_MATH_CALLBACK(float, logf)
1130 DEFINE_POSIX_MATH_CALLBACK(float, roundf)
1131 DEFINE_POSIX_MATH_CALLBACK(float, sinf)
1132 DEFINE_POSIX_MATH_CALLBACK(float, sinhf)
1133 DEFINE_POSIX_MATH_CALLBACK(float, tanf)
1134 DEFINE_POSIX_MATH_CALLBACK(float, tanhf)
1135
1136 #undef DEFINE_POSIX_MATH_CALLBACK
1137
1138 #define DEFINE_POSIX_MATH_CALLBACK2(t, f) { #f, wabt_posix_math_2<t, ::f> },
1139
1140 DEFINE_POSIX_MATH_CALLBACK2(float, atan2f)
1141 DEFINE_POSIX_MATH_CALLBACK2(double, atan2)
1142 DEFINE_POSIX_MATH_CALLBACK2(float, fminf)
1143 DEFINE_POSIX_MATH_CALLBACK2(double, fmin)
1144 DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf)
1145 DEFINE_POSIX_MATH_CALLBACK2(double, fmax)
1146 DEFINE_POSIX_MATH_CALLBACK2(float, powf)
1147 DEFINE_POSIX_MATH_CALLBACK2(double, pow)
1148
1149 #undef DEFINE_POSIX_MATH_CALLBACK2
1150 };
1151
1152 return m;
1153 }
1154 // clang-format on
1155
1156 // --------------------------------------------------
// Extern (define_extern) Callback Functions
1158 // --------------------------------------------------
1159
1160 struct ExternArgType {
1161 halide_type_t type;
1162 bool is_void;
1163 bool is_buffer;
1164 };
1165
1166 using TrampolineFn = void (*)(void **);
1167
extern_callback_wrapper(const std::vector<ExternArgType> & arg_types,TrampolineFn trampoline_fn,wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)1168 wabt::Result extern_callback_wrapper(const std::vector<ExternArgType> &arg_types,
1169 TrampolineFn trampoline_fn,
1170 wabt::interp::Thread &thread,
1171 const wabt::interp::Values &args,
1172 wabt::interp::Values &results,
1173 wabt::interp::Trap::Ptr *trap) {
1174 WabtContext &wabt_context = get_wabt_context(thread);
1175
1176 wassert(arg_types.size() >= 1);
1177 const size_t arg_types_len = arg_types.size() - 1;
1178 const ExternArgType &ret_type = arg_types[0];
1179
1180 // There's wasted space here, but that's ok.
1181 std::vector<Halide::Runtime::Buffer<>> buffers(arg_types_len);
1182 std::vector<uint64_t> scalars(arg_types_len, 0);
1183 std::vector<void *> trampoline_args(arg_types_len, nullptr);
1184
1185 for (size_t i = 0; i < arg_types_len; ++i) {
1186 const auto &a = arg_types[i + 1];
1187 if (a.is_buffer) {
1188 const wasm32_ptr_t buf_ptr = args[i].Get<int32_t>();
1189 wasmbuf_to_hostbuf(wabt_context, buf_ptr, buffers[i]);
1190 trampoline_args[i] = buffers[i].raw_buffer();
1191 } else {
1192 store_value(a.type, args[i], &scalars[i]);
1193 trampoline_args[i] = &scalars[i];
1194 }
1195 }
1196
1197 // The return value (if any) is always scalar.
1198 uint64_t ret_val = 0;
1199 const bool has_retval = !ret_type.is_void;
1200 internal_assert(!ret_type.is_buffer);
1201 if (has_retval) {
1202 trampoline_args.push_back(&ret_val);
1203 }
1204 (*trampoline_fn)(trampoline_args.data());
1205
1206 if (has_retval) {
1207 results[0] = dynamic_type_dispatch<LoadValue>(ret_type.type, (void *)&ret_val);
1208 }
1209
1210 // Progagate buffer data backwards. Note that for arbitrary extern functions,
1211 // we have no idea which buffers might be "input only", so we copy all data for all of them.
1212 for (size_t i = 0; i < arg_types_len; ++i) {
1213 const auto &a = arg_types[i + 1];
1214 if (a.is_buffer) {
1215 const wasm32_ptr_t buf_ptr = args[i].Get<int32_t>();
1216 copy_hostbuf_to_existing_wasmbuf(wabt_context, buffers[i], buf_ptr);
1217 }
1218 }
1219
1220 return wabt::Result::Ok;
1221 }
1222
// Returns true for import names that must never be bound to a define_extern
// trampoline; both are already served by the built-in host callback table.
bool should_skip_extern_symbol(const std::string &name) {
    static const std::set<std::string> kSkipped = {"halide_print", "halide_error"};
    return kSkipped.find(name) != kSkipped.end();
}
1229
make_extern_callback(wabt::interp::Store & store,const std::map<std::string,Halide::JITExtern> & jit_externs,const JITModule & trampolines,const wabt::interp::ImportDesc & import)1230 wabt::interp::HostFunc::Ptr make_extern_callback(wabt::interp::Store &store,
1231 const std::map<std::string, Halide::JITExtern> &jit_externs,
1232 const JITModule &trampolines,
1233 const wabt::interp::ImportDesc &import) {
1234 const std::string &fn_name = import.type.name;
1235 if (should_skip_extern_symbol(fn_name)) {
1236 wdebug(1) << "Skipping extern symbol: " << fn_name << "\n";
1237 return wabt::interp::HostFunc::Ptr();
1238 }
1239
1240 const auto it = jit_externs.find(fn_name);
1241 if (it == jit_externs.end()) {
1242 wdebug(1) << "Extern symbol not found in JIT Externs: " << fn_name << "\n";
1243 return wabt::interp::HostFunc::Ptr();
1244 }
1245 const ExternSignature &sig = it->second.extern_c_function().signature();
1246
1247 const auto &tramp_it = trampolines.exports().find(fn_name + kTrampolineSuffix);
1248 if (tramp_it == trampolines.exports().end()) {
1249 wdebug(1) << "Extern symbol not found in trampolines: " << fn_name << "\n";
1250 return wabt::interp::HostFunc::Ptr();
1251 }
1252 TrampolineFn trampoline_fn = (TrampolineFn)tramp_it->second.address;
1253
1254 const size_t arg_count = sig.arg_types().size();
1255
1256 std::vector<ExternArgType> arg_types;
1257
1258 if (sig.is_void_return()) {
1259 const bool is_void = true;
1260 const bool is_buffer = false;
1261 // Specifying a type here with bits == 0 should trigger a proper 'void' return type
1262 arg_types.push_back(ExternArgType{{halide_type_int, 0, 0}, is_void, is_buffer});
1263 } else {
1264 const Type &t = sig.ret_type();
1265 const bool is_void = false;
1266 const bool is_buffer = (t == type_of<halide_buffer_t *>());
1267 user_assert(t.lanes() == 1) << "Halide Extern functions cannot return vector values.";
1268 user_assert(!is_buffer) << "Halide Extern functions cannot return halide_buffer_t.";
1269 arg_types.push_back(ExternArgType{t, is_void, is_buffer});
1270 }
1271 for (size_t i = 0; i < arg_count; ++i) {
1272 const Type &t = sig.arg_types()[i];
1273 const bool is_void = false;
1274 const bool is_buffer = (t == type_of<halide_buffer_t *>());
1275 user_assert(t.lanes() == 1) << "Halide Extern functions cannot accept vector values as arguments.";
1276 arg_types.push_back(ExternArgType{t, is_void, is_buffer});
1277 }
1278
1279 const auto callback_wrapper =
1280 [arg_types, trampoline_fn](wabt::interp::Thread &thread,
1281 const wabt::interp::Values &args,
1282 wabt::interp::Values &results,
1283 wabt::interp::Trap::Ptr *trap) -> wabt::Result {
1284 return extern_callback_wrapper(arg_types, trampoline_fn, thread, args, results, trap);
1285 };
1286
1287 auto func_type = *wabt::cast<wabt::interp::FuncType>(import.type.type.get());
1288 auto host_func = wabt::interp::HostFunc::New(store, func_type, callback_wrapper);
1289
1290 return host_func;
1291 }
1292
calc_features(const Target & target)1293 wabt::Features calc_features(const Target &target) {
1294 wabt::Features f;
1295 if (target.has_feature(Target::WasmSignExt)) {
1296 f.enable_sign_extension();
1297 }
1298 if (target.has_feature(Target::WasmSimd128)) {
1299 f.enable_simd();
1300 }
1301 if (target.has_feature(Target::WasmSatFloatToInt)) {
1302 f.enable_sat_float_to_int();
1303 }
1304 return f;
1305 }
1306
1307 } // namespace
1308
1309 #endif // WITH_WABT
1310
// Implementation object behind WasmModule. It holds the pipeline's arguments,
// the define_extern table and host trampolines and, when built WITH_WABT, the
// wabt interpreter state for the compiled wasm module.
struct WasmModuleContents {
    // Reference count returned by the ref_count<WasmModuleContents> specialization in this file.
    mutable RefCount ref_count;

    const Target target;
    // Pipeline signature; run() interprets args[i] according to arguments[i].
    const std::vector<Argument> arguments;
    std::map<std::string, Halide::JITExtern> jit_externs;
    std::vector<JITModule> extern_deps;
    // Host-side trampolines (symbols named with kTrampolineSuffix) used to call jit_externs.
    JITModule trampolines;

#if WITH_WABT
    // Heap allocator over wasm linear memory; initialized in the constructor from __heap_base.
    BDMalloc bdmalloc;
    wabt::interp::Store store;
    wabt::interp::Module::Ptr module;
    wabt::interp::Instance::Ptr instance;
    // NOTE(review): appears unused here; run() builds its own Thread::Options.
    wabt::interp::Thread::Options thread_options;
    // The module's exported "memory" object.
    wabt::interp::Memory::Ptr memory;
#endif

    WasmModuleContents(
        const Module &halide_module,
        const std::vector<Argument> &arguments,
        const std::string &fn_name,
        const std::map<std::string, Halide::JITExtern> &jit_externs,
        const std::vector<JITModule> &extern_deps);

    // Runs the exported function; args[i] points at the value for arguments[i].
    int run(const void **args);

    ~WasmModuleContents();
};
1340
WasmModuleContents(const Module & halide_module,const std::vector<Argument> & arguments,const std::string & fn_name,const std::map<std::string,Halide::JITExtern> & jit_externs,const std::vector<JITModule> & extern_deps)1341 WasmModuleContents::WasmModuleContents(
1342 const Module &halide_module,
1343 const std::vector<Argument> &arguments,
1344 const std::string &fn_name,
1345 const std::map<std::string, Halide::JITExtern> &jit_externs,
1346 const std::vector<JITModule> &extern_deps)
1347 : target(halide_module.target()),
1348 arguments(arguments),
1349 jit_externs(jit_externs),
1350 extern_deps(extern_deps),
1351 trampolines(JITModule::make_trampolines_module(get_host_target(), jit_externs, kTrampolineSuffix, extern_deps)) {
1352
1353 #if WITH_WABT
1354 user_assert(LLVM_VERSION >= 110) << "Using the WebAssembly JIT is only supported under LLVM 11+.";
1355
1356 wdebug(1) << "Compiling wasm function " << fn_name << "\n";
1357
1358 // Compile halide into wasm bytecode.
1359 std::vector<char> final_wasm = compile_to_wasm(halide_module, fn_name);
1360
1361 store = wabt::interp::Store(calc_features(halide_module.target()));
1362
1363 // Create a wabt Module for it.
1364 wabt::MemoryStream log_stream;
1365 constexpr bool kReadDebugNames = true;
1366 constexpr bool kStopOnFirstError = true;
1367 constexpr bool kFailOnCustomSectionError = true;
1368 wabt::ReadBinaryOptions options(store.features(),
1369 &log_stream,
1370 kReadDebugNames,
1371 kStopOnFirstError,
1372 kFailOnCustomSectionError);
1373 wabt::Errors errors;
1374 wabt::interp::ModuleDesc module_desc;
1375 wabt::Result r = wabt::interp::ReadBinaryInterp(final_wasm.data(),
1376 final_wasm.size(),
1377 options,
1378 &errors,
1379 &module_desc);
1380 internal_assert(Succeeded(r))
1381 << "ReadBinaryInterp failed:\n"
1382 << wabt::FormatErrorsToString(errors, wabt::Location::Type::Binary) << "\n"
1383 << " log: " << to_string(log_stream) << "\n";
1384
1385 if (WASM_DEBUG_LEVEL >= 2) {
1386 wabt::MemoryStream dis_stream;
1387 module_desc.istream.Disassemble(&dis_stream);
1388 wdebug(WASM_DEBUG_LEVEL) << "Disassembly:\n"
1389 << to_string(dis_stream) << "\n";
1390 }
1391
1392 module = wabt::interp::Module::New(store, module_desc);
1393
1394 // Bind all imports to our callbacks.
1395 wabt::interp::RefVec imports;
1396 const HostCallbackMap &host_callback_map = get_host_callback_map();
1397 for (const auto &import : module->desc().imports) {
1398 wdebug(1) << "import=" << import.type.module << "." << import.type.name << "\n";
1399 if (import.type.type->kind == wabt::interp::ExternKind::Func && import.type.module == "env") {
1400 auto it = host_callback_map.find(import.type.name);
1401 if (it != host_callback_map.end()) {
1402 auto func_type = *wabt::cast<wabt::interp::FuncType>(import.type.type.get());
1403 auto host_func = wabt::interp::HostFunc::New(store, func_type, it->second);
1404 imports.push_back(host_func.ref());
1405 continue;
1406 }
1407
1408 // If it's not one of the standard host callbacks, assume it must be
1409 // a define_extern, and look for it in the jit_externs.
1410 auto host_func = make_extern_callback(store, jit_externs, trampolines, import);
1411 imports.push_back(host_func.ref());
1412 continue;
1413 }
1414 // By default, just push a null reference. This won't resolve, and
1415 // instantiation will fail.
1416 imports.push_back(wabt::interp::Ref::Null);
1417 }
1418
1419 wabt::interp::RefPtr<wabt::interp::Trap> trap;
1420 instance = wabt::interp::Instance::Instantiate(store, module.ref(), imports, &trap);
1421 internal_assert(instance) << "Error initializing module: " << trap->message() << "\n";
1422
1423 int32_t heap_base = -1;
1424
1425 for (const auto &e : module_desc.exports) {
1426 if (e.type.name == "__heap_base") {
1427 internal_assert(e.type.type->kind == wabt::ExternalKind::Global);
1428 heap_base = store.UnsafeGet<wabt::interp::Global>(instance->globals()[e.index])->Get().Get<int32_t>();
1429 wdebug(1) << "__heap_base is " << heap_base << "\n";
1430 continue;
1431 }
1432 if (e.type.name == "memory") {
1433 internal_assert(e.type.type->kind == wabt::ExternalKind::Memory);
1434 internal_assert(!memory.get()) << "Expected exactly one memory object but saw " << (void *)memory.get();
1435 memory = store.UnsafeGet<wabt::interp::Memory>(instance->memories()[e.index]);
1436 wdebug(1) << "heap_size is " << memory->ByteSize() << "\n";
1437 continue;
1438 }
1439 }
1440 internal_assert(heap_base >= 0) << "__heap_base not found";
1441 internal_assert(memory->ByteSize() > 0) << "memory size is unlikely";
1442
1443 bdmalloc.init(memory->ByteSize(), heap_base);
1444
1445 #endif // WITH_WABT
1446 }
1447
run(const void ** args)1448 int WasmModuleContents::run(const void **args) {
1449 #if WITH_WABT
1450 const auto &module_desc = module->desc();
1451
1452 wabt::interp::FuncType *func_type = nullptr;
1453 wabt::interp::RefPtr<wabt::interp::Func> func;
1454 std::string func_name;
1455
1456 for (const auto &e : module_desc.exports) {
1457 if (e.type.type->kind == wabt::ExternalKind::Func) {
1458 wdebug(1) << "Selecting export '" << e.type.name << "'\n";
1459 internal_assert(!func_type && !func) << "Multiple exported funcs found";
1460 func_type = wabt::cast<wabt::interp::FuncType>(e.type.type.get());
1461 func = store.UnsafeGet<wabt::interp::Func>(instance->funcs()[e.index]);
1462 func_name = e.type.name;
1463 continue;
1464 }
1465 }
1466
1467 JITUserContext *jit_user_context = nullptr;
1468 for (size_t i = 0; i < arguments.size(); i++) {
1469 const Argument &arg = arguments[i];
1470 const void *arg_ptr = args[i];
1471 if (arg.name == "__user_context") {
1472 jit_user_context = *(JITUserContext **)const_cast<void *>(arg_ptr);
1473 }
1474 }
1475
1476 WabtContext wabt_context(jit_user_context, *memory, bdmalloc);
1477
1478 wabt::interp::Values wabt_args;
1479 wabt::interp::Values wabt_results;
1480 wabt::interp::Trap::Ptr trap;
1481
1482 std::vector<wasm32_ptr_t> wbufs(arguments.size(), 0);
1483
1484 for (size_t i = 0; i < arguments.size(); i++) {
1485 const Argument &arg = arguments[i];
1486 const void *arg_ptr = args[i];
1487 if (arg.is_buffer()) {
1488 halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
1489 // It's OK for this to be null (let Halide asserts handle it)
1490 wasm32_ptr_t wbuf = hostbuf_to_wasmbuf(wabt_context, buf);
1491 wbufs[i] = wbuf;
1492 wabt_args.push_back(load_value(wbuf));
1493 } else {
1494 if (arg.name == "__user_context") {
1495 wabt_args.push_back(wabt::interp::Value::Make(kMagicJitUserContextValue));
1496 } else {
1497 wabt_args.push_back(load_value(arg.type, arg_ptr));
1498 }
1499 }
1500 }
1501
1502 wabt::interp::Thread::Options options;
1503 wabt::interp::Thread::Ptr thread = wabt::interp::Thread::New(store, options);
1504 thread->set_host_info(&wabt_context);
1505
1506 auto r = func->Call(*thread, wabt_args, wabt_results, &trap);
1507 if (WASM_DEBUG_LEVEL >= 2) {
1508 wabt::MemoryStream call_stream;
1509 WriteCall(&call_stream, func_name, *func_type, wabt_args, wabt_results, trap);
1510 wdebug(WASM_DEBUG_LEVEL) << to_string(call_stream) << "\n";
1511 }
1512 internal_assert(Succeeded(r)) << "Func::Call failed: " << trap->message() << "\n";
1513 internal_assert(wabt_results.size() == 1);
1514 int32_t result = wabt_results[0].Get<int32_t>();
1515
1516 wdebug(1) << "Result is " << result << "\n";
1517
1518 if (result == 0) {
1519 // Update any output buffers
1520 for (size_t i = 0; i < arguments.size(); i++) {
1521 const Argument &arg = arguments[i];
1522 const void *arg_ptr = args[i];
1523 if (arg.is_buffer()) {
1524 halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
1525 copy_wasmbuf_to_existing_hostbuf(wabt_context, wbufs[i], buf);
1526 }
1527 }
1528 }
1529
1530 for (wasm32_ptr_t p : wbufs) {
1531 wabt_free(wabt_context, p);
1532 }
1533
1534 // Don't do this: things allocated by Halide runtime might need to persist
1535 // between multiple invocations of the same function.
1536 // bdmalloc.reset();
1537
1538 return result;
1539
1540 #endif
1541
1542 internal_error << "WasmExecutor is not configured correctly";
1543 return -1;
1544 }
1545
~WasmModuleContents()1546 WasmModuleContents::~WasmModuleContents() {
1547 #if WITH_WABT
1548 // nothing
1549 #endif
1550 }
1551
1552 template<>
ref_count(const WasmModuleContents * p)1553 RefCount &ref_count<WasmModuleContents>(const WasmModuleContents *p) noexcept {
1554 return p->ref_count;
1555 }
1556
1557 template<>
destroy(const WasmModuleContents * p)1558 void destroy<WasmModuleContents>(const WasmModuleContents *p) {
1559 delete p;
1560 }
1561
1562 /*static*/
can_jit_target(const Target & target)1563 bool WasmModule::can_jit_target(const Target &target) {
1564 #if WITH_WABT
1565 if (target.arch == Target::WebAssembly) {
1566 return true;
1567 }
1568 #endif
1569 return false;
1570 }
1571
1572 /*static*/
compile(const Module & module,const std::vector<Argument> & arguments,const std::string & fn_name,const std::map<std::string,Halide::JITExtern> & jit_externs,const std::vector<JITModule> & extern_deps)1573 WasmModule WasmModule::compile(
1574 const Module &module,
1575 const std::vector<Argument> &arguments,
1576 const std::string &fn_name,
1577 const std::map<std::string, Halide::JITExtern> &jit_externs,
1578 const std::vector<JITModule> &extern_deps) {
1579 #if !defined(WITH_WABT)
1580 user_error << "Cannot run JITted WebAssembly without configuring a WebAssembly engine.";
1581 return WasmModule();
1582 #endif
1583
1584 WasmModule wasm_module;
1585 wasm_module.contents = new WasmModuleContents(module, arguments, fn_name, jit_externs, extern_deps);
1586 return wasm_module;
1587 }
1588
1589 /** Run generated previously compiled wasm code with a set of arguments. */
run(const void ** args)1590 int WasmModule::run(const void **args) {
1591 internal_assert(contents.defined());
1592 return contents->run(args);
1593 }
1594
1595 } // namespace Internal
1596 } // namespace Halide
1597