1 #include "WasmExecutor.h"
2 
3 #include "CodeGen_WebAssembly.h"
4 #include "Error.h"
5 #include "Float16.h"
6 #include "Func.h"
7 #include "ImageParam.h"
8 #include "JITModule.h"
9 #if WITH_WABT
10 #include "LLVM_Headers.h"
11 #endif
12 #include "LLVM_Output.h"
13 #include "LLVM_Runtime_Linker.h"
14 #include "Target.h"
15 
16 #include <cmath>
17 #include <csignal>
18 #include <cstdlib>
19 #include <mutex>
20 #include <sstream>
21 #include <unordered_map>
22 #include <vector>
23 
24 #if WITH_WABT
25 #include "wabt-src/src/binary-reader.h"
26 #include "wabt-src/src/error-formatter.h"
27 #include "wabt-src/src/feature.h"
28 #include "wabt-src/src/interp/binary-reader-interp.h"
29 #include "wabt-src/src/interp/interp-util.h"
30 #include "wabt-src/src/interp/interp-wasi.h"
31 #include "wabt-src/src/interp/interp.h"
32 #include "wabt-src/src/option-parser.h"
33 #include "wabt-src/src/stream.h"
34 #endif
35 
36 namespace Halide {
37 namespace Internal {
38 
// Suffix appended to generated trampoline names. "_argv" is deliberately not
// used: a function with that name may already exist, and it returns an int
// instead of taking a trailing pointer argument that receives the result value.
static constexpr char kTrampolineSuffix[] = "_trampoline";
44 
45 #if WITH_WABT
46 
47 namespace {
48 
49 // ---------------------
50 // General debug helpers
51 // ---------------------
52 
// Debugging the WebAssembly JIT support is usually disconnected from the rest of HL_DEBUG_CODEGEN
// At 0, wdebug()/wassert() expand to no-op debug_sink() expressions. At N > 0,
// wdebug(x) with x <= N is routed to Halide::Internal::debug(0), and wassert()
// becomes internal_assert(). At >= 2, the dump_hostbuf()/dump_wasmbuf() buffer
// dumps are also compiled in.
#define WASM_DEBUG_LEVEL 0
55 
// A do-nothing output stream. Anything streamed into it is discarded, and every
// operator<< hands back the same sink so chains like `sink << a << b` still
// compile. The debugging macros expand to this when they are disabled.
struct debug_sink {
    debug_sink() = default;

    template<typename T>
    inline debug_sink &operator<<(T && /* ignored */) {
        return *this;
    }
};
64 
// wdebug(x) starts a debug message at verbosity x; wassert(x) checks an invariant.
// When WASM_DEBUG_LEVEL is 0, both expand to debug_sink(). The macro argument x is
// then dropped, so a wassert() condition is never evaluated. Values streamed after
// the macro with << are still evaluated and then ignored.
#if WASM_DEBUG_LEVEL
#define wdebug(x) Halide::Internal::debug(((x) <= WASM_DEBUG_LEVEL) ? 0 : 255)
#define wassert(x) internal_assert(x)
#else
#define wdebug(x) debug_sink()
#define wassert(x) debug_sink()
#endif
72 
73 // ---------------------
74 // BDMalloc
75 // ---------------------
76 
// Round `p` up to the next multiple of `alignment` (default 32). The mask trick
// below is only correct when `alignment` is a power of two.
template<typename T>
inline T align_up(T p, int alignment = 32) {
    const T mask = static_cast<T>(alignment - 1);
    return (p + mask) & ~mask;
}
81 
// Debugging our Malloc is extremely noisy and usually undesired
// At 0, bddebug(x) expands to a no-op debug_sink(). Otherwise, messages with
// x <= BDMALLOC_DEBUG_LEVEL go to Halide::Internal::debug(0). A nonzero level
// also compiles in the region walk in BDMalloc::validate().

#define BDMALLOC_DEBUG_LEVEL 0
#if BDMALLOC_DEBUG_LEVEL
#define bddebug(x) Halide::Internal::debug(((x) <= BDMALLOC_DEBUG_LEVEL) ? 0 : 255)
#else
#define bddebug(x) debug_sink()
#endif
90 
91 // BDMalloc aka BrainDeadMalloc. This is an *extremely* simple-minded implementation
92 // of malloc/free on top of a WasmMemoryObject, and is intended to be just barely adequate
93 // to allow Halide's JIT-based tests to pass. It is neither memory-efficient nor performant,
94 // nor has it been particularly well-vetted for potential buffer overruns and such.
95 class BDMalloc {
96     struct Region {
97         uint32_t size : 31;
98         uint32_t used : 1;
99     };
100 
101     uint32_t total_size = 0;
102     std::map<uint32_t, Region> regions;
103 
104 public:
105     BDMalloc() = default;
106 
get_total_size() const107     uint32_t get_total_size() const {
108         return total_size;
109     }
110 
inited() const111     bool inited() const {
112         return total_size > 0;
113     }
114 
init(uint32_t total_size,uint32_t heap_start=1)115     void init(uint32_t total_size, uint32_t heap_start = 1) {
116         this->total_size = total_size;
117         regions.clear();
118 
119         internal_assert(heap_start < total_size);
120         // Area before heap_start is permanently off-limits
121         regions[0] = {heap_start, true};
122         // Everything else is free
123         regions[heap_start] = {total_size - heap_start, false};
124     }
125 
reset()126     void reset() {
127         this->total_size = 0;
128         regions.clear();
129     }
130 
alloc_region(uint32_t requested_size)131     uint32_t alloc_region(uint32_t requested_size) {
132         internal_assert(requested_size > 0);
133 
134         bddebug(1) << "begin alloc_region " << requested_size << "\n";
135         validate();
136 
137         // TODO: this would be faster with a basic free list,
138         // but for most Halide test code, there aren't enough allocations
139         // for this to be worthwhile, or at least that's my observation from
140         // a first run-thru. Consider adding it if any of the JIT-based tests
141         // seem unreasonably slow; as it is, a linear search for the first free
142         // block of adequate size is (apparently) performant enough.
143 
144         // alignment and min-block-size are the same for our purposes here.
145         constexpr uint32_t kAlignment = 32;
146         const uint32_t size = std::max(align_up((uint32_t)requested_size, kAlignment), kAlignment);
147 
148         constexpr uint32_t kMaxAllocSize = 0x7fffffff;
149         internal_assert(size <= kMaxAllocSize);
150         bddebug(2) << "size -> " << size << "\n";
151 
152         for (auto it = regions.begin(); it != regions.end(); it++) {
153             const uint32_t start = it->first;
154             Region &r = it->second;
155             if (!r.used && r.size >= size) {
156                 bddebug(2) << "alloc @ " << start << "," << (uint32_t)r.size << "\n";
157                 if (r.size > size + kAlignment) {
158                     // Split the block
159                     const uint32_t r2_start = start + size;
160                     const uint32_t r2_size = r.size - size;
161                     regions[r2_start] = {r2_size, false};
162                     r.size = size;
163                     bddebug(2) << "split: r-> " << start << "," << (uint32_t)r.size << "," << (start + r.size) << "\n";
164                     bddebug(2) << "split: r2-> " << r2_start << "," << r2_size << "," << (r2_start + r2_size) << "\n";
165                 }
166                 // Just return the block
167                 r.used = true;
168                 bddebug(1) << "end alloc_region " << requested_size << "\n";
169                 validate();
170                 return start;
171             }
172         }
173         bddebug(1) << "fail alloc_region " << requested_size << "\n";
174         validate();
175         return 0;
176     }
177 
free_region(uint32_t start)178     void free_region(uint32_t start) {
179         bddebug(1) << "begin free_region " << start << "\n";
180         validate();
181 
182         // Can't free region at zero
183         if (!start) {
184             return;
185         }
186 
187         internal_assert(start > 0);
188         auto it = regions.find(start);
189         internal_assert(it != regions.end());
190         internal_assert(it->second.used);
191         it->second.used = false;
192         // If prev region is free, combine with it
193         if (it != regions.begin()) {
194             auto prev = std::prev(it);
195             if (!prev->second.used) {
196                 bddebug(2) << "combine prev: " << prev->first << " w/ " << it->first << "\n";
197                 prev->second.size += it->second.size;
198                 regions.erase(it);
199                 it = prev;
200             }
201         }
202         // If next region is free, combine with it
203         auto next = std::next(it);
204         if (next != regions.end() && !next->second.used) {
205             bddebug(2) << "combine next: " << next->first << " w/ " << it->first << " "
206                        << "\n";
207             it->second.size += next->second.size;
208             regions.erase(next);
209         }
210 
211         bddebug(1) << "end free_region " << start << "\n";
212         validate();
213     }
214 
grow_total_size(uint32_t new_total_size)215     void grow_total_size(uint32_t new_total_size) {
216         bddebug(1) << "begin grow_total_size " << new_total_size << "\n";
217         validate();
218 
219         internal_assert(new_total_size > total_size);
220         auto it = regions.rbegin();
221         const uint32_t start = it->first;
222         Region &r = it->second;
223         uint32_t r_end = start + r.size;
224         internal_assert(r_end == total_size);
225         uint32_t delta = new_total_size - r_end;
226         if (r.used) {
227             // Add a free region after the last one
228             regions[r_end] = {delta, false};
229         } else {
230             // Just extend the last (free) region
231             r.size += delta;
232         }
233 
234         // bookkeeping
235         total_size = new_total_size;
236 
237         bddebug(1) << "end grow_total_size " << new_total_size << "\n";
238         validate();
239     }
240 
validate() const241     void validate() const {
242         internal_assert(total_size > 0);
243 #if (BDMALLOC_DEBUG_LEVEL >= 1) || (WASM_DEBUG_LEVEL >= 1)
244         uint32_t prev_end = 0;
245         bool prev_used = false;
246         for (auto it : regions) {
247             const uint32_t start = it.first;
248             const Region &r = it.second;
249             bddebug(2) << "R: " << start << ".." << (start + r.size - 1) << "," << r.used << "\n";
250             wassert(start == prev_end) << "start " << start << " prev_end " << prev_end << "\n";
251             // it's OK to have two used regions in a row, but not two free ones
252             wassert(!(!prev_used && !r.used));
253             prev_end = start + r.size;
254             prev_used = r.used;
255         }
256         wassert(prev_end == total_size) << "prev_end " << prev_end << " total_size " << total_size << "\n";
257         bddebug(2) << "\n";
258 #endif
259     }
260 };
261 
262 // ---------------------
263 // General Wasm helpers
264 // ---------------------
265 
// An address inside the wasm module's 32-bit linear memory, used as an offset
// from the host-side memory base.
using wasm32_ptr_t = int32_t;

// Stand-in passed into wasm in place of a real JITUserContext pointer. When wasm
// calls back out to the host, this value is mapped back to the host-side context.
constexpr wasm32_ptr_t kMagicJitUserContextValue = -1;

// TODO: vector codegen can underead allocated buffers; we need to deliberately
// allocate extra and return a pointer partway in to avoid out-of-bounds access
// failures. https://github.com/halide/Halide/issues/3738
constexpr size_t kExtraMallocSlop = 32;
274 
compile_to_wasm(const Module & module,const std::string & fn_name)275 std::vector<char> compile_to_wasm(const Module &module, const std::string &fn_name) {
276     static std::mutex link_lock;
277     std::lock_guard<std::mutex> lock(link_lock);
278 
279     llvm::LLVMContext context;
280     std::unique_ptr<llvm::Module> fn_module;
281 
282     // Default wasm stack size is ~64k, but schedules with lots of
283     // alloca usage (heavily inlined, or tracing enabled) can blow thru
284     // this, which crashes in amusing ways, so ask for extra stack space
285     // for the alloca usage.
286     size_t stack_size = 65536;
287     {
288         std::unique_ptr<CodeGen_WebAssembly> cg(new CodeGen_WebAssembly(module.target()));
289         cg->set_context(context);
290         fn_module = cg->compile(module);
291         stack_size += cg->get_requested_alloca_total();
292     }
293 
294     stack_size = align_up(stack_size);
295     wdebug(1) << "Requesting stack size of " << stack_size << "\n";
296 
297     std::unique_ptr<llvm::Module> llvm_module =
298         link_with_wasm_jit_runtime(&context, module.target(), std::move(fn_module));
299 
300     llvm::SmallVector<char, 4096> object;
301     llvm::raw_svector_ostream object_stream(object);
302     compile_llvm_module_to_object(*llvm_module, object_stream);
303 
304     // TODO: surely there's a better way that doesn't require spooling things
305     // out to temp files
306     TemporaryFile obj_file("", ".o");
307     write_entire_file(obj_file.pathname(), object.data(), object.size());
308 #if WASM_DEBUG_LEVEL
309     obj_file.detach();
310     wdebug(1) << "Dumping obj_file to " << obj_file.pathname() << "\n";
311 #endif
312 
313     TemporaryFile wasm_output("", ".wasm");
314 
315     std::string lld_arg_strs[] = {
316         "HalideJITLinker",
317         // For debugging purposes:
318         // "--verbose",
319         // "-error-limit=0",
320         // "--print-gc-sections",
321         "--export=__heap_base",
322         "--allow-undefined",
323         "-zstack-size=" + std::to_string(stack_size),
324         obj_file.pathname(),
325         "--entry=" + fn_name,
326         "-o",
327         wasm_output.pathname()};
328 
329     constexpr int c = sizeof(lld_arg_strs) / sizeof(lld_arg_strs[0]);
330     const char *lld_args[c];
331     for (int i = 0; i < c; ++i) {
332         lld_args[i] = lld_arg_strs[i].c_str();
333     }
334 
335     // lld will temporarily hijack the signal handlers to ensure that temp files get cleaned up,
336     // but rather than preserving custom handlers in place, it restores the default handlers.
337     // This conflicts with some of our testing infrastructure, which relies on a SIGABRT handler
338     // set at global-ctor time to stay set. Therefore we'll save and restore this ourselves.
339     // Note that we must restore it before using internal_error (and also on the non-error path).
340     auto old_abort_handler = std::signal(SIGABRT, SIG_DFL);
341 
342 #if LLVM_VERSION >= 110
343     if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), llvm::errs())) {
344         std::signal(SIGABRT, old_abort_handler);
345         internal_error << "lld::wasm::link failed\n";
346     }
347 #elif LLVM_VERSION >= 100
348     std::string lld_errs_string;
349     llvm::raw_string_ostream lld_errs(lld_errs_string);
350 
351     if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), llvm::errs())) {
352         std::signal(SIGABRT, old_abort_handler);
353         internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
354     }
355 #else
356     std::string lld_errs_string;
357     llvm::raw_string_ostream lld_errs(lld_errs_string);
358 
359     if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, lld_errs)) {
360         std::signal(SIGABRT, old_abort_handler);
361         internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
362     }
363 #endif
364 
365     std::signal(SIGABRT, old_abort_handler);
366 
367 #if WASM_DEBUG_LEVEL
368     wasm_output.detach();
369     wdebug(1) << "Dumping linked wasm to " << wasm_output.pathname() << "\n";
370 #endif
371 
372     return read_entire_file(wasm_output.pathname());
373 }
374 
halide_type_code(halide_type_code_t code,int bits)375 inline constexpr int halide_type_code(halide_type_code_t code, int bits) {
376     return ((int)code) | (bits << 8);
377 }
378 
379 // dynamic_type_dispatch is a utility for functors that want to be able
380 // to dynamically dispatch a halide_type_t to type-specialized code.
381 // To use it, a functor must be a *templated* class, e.g.
382 //
383 //     template<typename T> class MyFunctor { int operator()(arg1, arg2...); };
384 //
385 // dynamic_type_dispatch() is called with a halide_type_t as the first argument,
386 // followed by the arguments to the Functor's operator():
387 //
388 //     auto result = dynamic_type_dispatch<MyFunctor>(some_halide_type, arg1, arg2);
389 //
390 // Note that this means that the functor must be able to instantiate its
391 // operator() for all the Halide scalar types; it also means that all those
392 // variants *will* be instantiated (increasing code size), so this approach
393 // should only be used when strictly necessary.
394 
395 // clang-format off
396 template<template<typename> class Functor, typename... Args>
dynamic_type_dispatch(const halide_type_t & type,Args &&...args)397 auto dynamic_type_dispatch(const halide_type_t &type, Args &&... args) -> decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...)) {
398 
399 #define HANDLE_CASE(CODE, BITS, TYPE)  \
400     case halide_type_code(CODE, BITS): \
401         return Functor<TYPE>()(std::forward<Args>(args)...);
402 
403     switch (halide_type_code((halide_type_code_t)type.code, type.bits)) {
404         HANDLE_CASE(halide_type_bfloat, 16, bfloat16_t)
405         HANDLE_CASE(halide_type_float, 16, float16_t)
406         HANDLE_CASE(halide_type_float, 32, float)
407         HANDLE_CASE(halide_type_float, 64, double)
408         HANDLE_CASE(halide_type_int, 8, int8_t)
409         HANDLE_CASE(halide_type_int, 16, int16_t)
410         HANDLE_CASE(halide_type_int, 32, int32_t)
411         HANDLE_CASE(halide_type_int, 64, int64_t)
412         HANDLE_CASE(halide_type_uint, 1, bool)
413         HANDLE_CASE(halide_type_uint, 8, uint8_t)
414         HANDLE_CASE(halide_type_uint, 16, uint16_t)
415         HANDLE_CASE(halide_type_uint, 32, uint32_t)
416         HANDLE_CASE(halide_type_uint, 64, uint64_t)
417         HANDLE_CASE(halide_type_handle, 64, void *)
418     default:
419         internal_error;
420         using ReturnType = decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...));
421         return ReturnType();
422     }
423 
424 #undef HANDLE_CASE
425 }
426 // clang-format on
427 
to_string(const wabt::MemoryStream & m)428 std::string to_string(const wabt::MemoryStream &m) {
429     wabt::OutputBuffer &o = const_cast<wabt::MemoryStream *>(&m)->output_buffer();
430     return std::string((const char *)o.data.data(), o.data.size());
431 }
432 
433 struct WabtContext {
434     JITUserContext *const jit_user_context;
435     wabt::interp::Memory &memory;
436     BDMalloc &bdmalloc;
437 
WabtContextHalide::Internal::__anona20e47b00111::WabtContext438     explicit WabtContext(JITUserContext *jit_user_context, wabt::interp::Memory &memory, BDMalloc &bdmalloc)
439         : jit_user_context(jit_user_context), memory(memory), bdmalloc(bdmalloc) {
440     }
441 
442     WabtContext(const WabtContext &) = delete;
443     WabtContext(WabtContext &&) = delete;
444     void operator=(const WabtContext &) = delete;
445     void operator=(WabtContext &&) = delete;
446 };
447 
get_wabt_context(wabt::interp::Thread & thread)448 WabtContext &get_wabt_context(wabt::interp::Thread &thread) {
449     void *host_info = thread.host_info();
450     wassert(host_info);
451     return *(WabtContext *)host_info;
452 }
453 
get_wasm_memory_base(WabtContext & wabt_context)454 uint8_t *get_wasm_memory_base(WabtContext &wabt_context) {
455     return wabt_context.memory.UnsafeData();
456 }
457 
wabt_malloc(WabtContext & wabt_context,size_t size)458 wasm32_ptr_t wabt_malloc(WabtContext &wabt_context, size_t size) {
459     wasm32_ptr_t p = wabt_context.bdmalloc.alloc_region(size);
460     if (!p) {
461         constexpr int kWasmPageSize = 65536;
462         const int32_t pages_needed = (size + kWasmPageSize - 1) / 65536;
463         wdebug(1) << "attempting to grow by pages: " << pages_needed << "\n";
464 
465         wabt::Result r = wabt_context.memory.Grow(pages_needed);
466         internal_assert(Succeeded(r)) << "Memory::Grow() failed";
467 
468         wabt_context.bdmalloc.grow_total_size(wabt_context.memory.ByteSize());
469         p = wabt_context.bdmalloc.alloc_region(size);
470     }
471 
472     wdebug(2) << "allocation of " << size << " at: " << p << "\n";
473     return p;
474 }
475 
wabt_free(WabtContext & wabt_context,wasm32_ptr_t ptr)476 void wabt_free(WabtContext &wabt_context, wasm32_ptr_t ptr) {
477     wdebug(2) << "freeing ptr at: " << ptr << "\n";
478     wabt_context.bdmalloc.free_region(ptr);
479 }
480 
481 // Some internal code can call halide_error(null, ...), so this needs to be resilient to that.
482 // Callers must expect null and not crash.
get_jit_user_context(WabtContext & wabt_context,const wabt::interp::Value & arg)483 JITUserContext *get_jit_user_context(WabtContext &wabt_context, const wabt::interp::Value &arg) {
484     int32_t ucon_magic = arg.Get<int32_t>();
485     if (ucon_magic == 0) {
486         return nullptr;
487     }
488     wassert(ucon_magic == kMagicJitUserContextValue);
489     JITUserContext *jit_user_context = wabt_context.jit_user_context;
490     wassert(jit_user_context);
491     return jit_user_context;
492 }
493 
494 // -----------------------
495 // halide_buffer_t <-> wasm_halide_buffer_t helpers
496 // -----------------------
497 
// wasm32-side view of a halide_buffer_t. It has the same fields, but each pointer
// is narrowed to a wasm32_ptr_t offset into wasm linear memory, with 0 meaning null.
// Instances are read and written in place inside wasm memory. Presumably this must
// match the wasm32 layout of halide_buffer_t exactly (TODO confirm); the
// static_asserts that follow pin its size to 40 bytes.
struct wasm_halide_buffer_t {
    uint64_t device;
    wasm32_ptr_t device_interface;  // halide_device_interface_t*
    wasm32_ptr_t host;              // uint8_t*
    uint64_t flags;
    halide_type_t type;
    int32_t dimensions;
    wasm32_ptr_t dim;      // halide_dimension_t*
    wasm32_ptr_t padding;  // always zero
};
508 
// Debug helper: print every field of a host-side halide_buffer_t, plus its first
// four host bytes and its dim array, under `label`. It compiles to an empty
// function unless WASM_DEBUG_LEVEL >= 2. `wabt_context` is unused.
// NOTE(review): the four host bytes are printed unconditionally, which reads past
// the end of any buffer smaller than 4 bytes.
void dump_hostbuf(WabtContext &wabt_context, const halide_buffer_t *buf, const std::string &label) {
#if WASM_DEBUG_LEVEL >= 2
    const halide_dimension_t *dim = buf->dim;
    const uint8_t *host = buf->host;

    wdebug(1) << label << " = " << (void *)buf << " = {\n";
    wdebug(1) << "  device = " << buf->device << "\n";
    wdebug(1) << "  device_interface = " << buf->device_interface << "\n";
    wdebug(1) << "  host = " << (void *)host << " = {\n";
    if (host) {
        wdebug(1) << "    " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
    }
    wdebug(1) << "  }\n";
    wdebug(1) << "  flags = " << buf->flags << "\n";
    wdebug(1) << "  type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
    wdebug(1) << "  dimensions = " << buf->dimensions << "\n";
    wdebug(1) << "  dim = " << (void *)buf->dim << " = {\n";
    for (int i = 0; i < buf->dimensions; i++) {
        const auto &d = dim[i];
        wdebug(1) << "    {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
    }
    wdebug(1) << "  }\n";
    wdebug(1) << "  padding = " << buf->padding << "\n";
    wdebug(1) << "}\n";
#endif
}
535 
// Debug helper: print a wasm_halide_buffer_t that lives at offset `buf_ptr` in
// wasm memory. Each wasm-side offset is shown next to the host address it resolves
// to, relative to the memory base. It compiles to an empty function unless
// WASM_DEBUG_LEVEL >= 2.
// NOTE(review): like dump_hostbuf(), it prints four host bytes unconditionally.
void dump_wasmbuf(WabtContext &wabt_context, wasm32_ptr_t buf_ptr, const std::string &label) {
#if WASM_DEBUG_LEVEL >= 2
    wassert(buf_ptr);

    uint8_t *base = get_wasm_memory_base(wabt_context);
    wasm_halide_buffer_t *buf = (wasm_halide_buffer_t *)(base + buf_ptr);
    halide_dimension_t *dim = buf->dim ? (halide_dimension_t *)(base + buf->dim) : nullptr;
    uint8_t *host = buf->host ? (base + buf->host) : nullptr;

    wdebug(1) << label << " = " << buf_ptr << " -> " << (void *)buf << " = {\n";
    wdebug(1) << "  device = " << buf->device << "\n";
    wdebug(1) << "  device_interface = " << buf->device_interface << "\n";
    wdebug(1) << "  host = " << buf->host << " -> " << (void *)host << " = {\n";
    if (host) {
        wdebug(1) << "    " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
    }
    wdebug(1) << "  }\n";
    wdebug(1) << "  flags = " << buf->flags << "\n";
    wdebug(1) << "  type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
    wdebug(1) << "  dimensions = " << buf->dimensions << "\n";
    wdebug(1) << "  dim = " << buf->dim << " -> " << (void *)dim << " = {\n";
    for (int i = 0; i < buf->dimensions; i++) {
        const auto &d = dim[i];
        wdebug(1) << "    {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
    }
    wdebug(1) << "  }\n";
    wdebug(1) << "  padding = " << buf->padding << "\n";
    wdebug(1) << "}\n";
#endif
}
566 
// wasm_halide_buffer_t and the dim arrays placed after it are read and written in
// wasm memory through raw offset arithmetic (see hostbuf_to_wasmbuf), so pin the
// sizes that layout depends on.
static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");
570 
571 // Given a halide_buffer_t on the host, allocate a wasm_halide_buffer_t in wasm
572 // memory space and copy all relevant data. The resulting buf is laid out in
573 // contiguous memory, and can be free with a single free().
hostbuf_to_wasmbuf(WabtContext & wabt_context,const halide_buffer_t * src)574 wasm32_ptr_t hostbuf_to_wasmbuf(WabtContext &wabt_context, const halide_buffer_t *src) {
575     static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
576     static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
577     static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");
578 
579     wdebug(2) << "\nhostbuf_to_wasmbuf:\n";
580     if (!src) {
581         return 0;
582     }
583 
584     dump_hostbuf(wabt_context, src, "src");
585 
586     wassert(src->device == 0);
587     wassert(src->device_interface == nullptr);
588 
589     // Assume our malloc() has everything 32-byte aligned,
590     // and insert enough padding for host to also be 32-byte aligned.
591     const size_t dims_size_in_bytes = sizeof(halide_dimension_t) * src->dimensions;
592     const size_t dims_offset = sizeof(wasm_halide_buffer_t);
593     const size_t mem_needed_base = sizeof(wasm_halide_buffer_t) + dims_size_in_bytes;
594     const size_t host_offset = align_up(mem_needed_base);
595     const size_t host_size_in_bytes = src->size_in_bytes();
596     const size_t mem_needed = host_offset + host_size_in_bytes;
597 
598     const wasm32_ptr_t dst_ptr = wabt_malloc(wabt_context, mem_needed);
599     wassert(dst_ptr);
600 
601     uint8_t *base = get_wasm_memory_base(wabt_context);
602 
603     wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(base + dst_ptr);
604     dst->device = 0;
605     dst->device_interface = 0;
606     dst->host = src->host ? (dst_ptr + host_offset) : 0;
607     dst->flags = src->flags;
608     dst->type = src->type;
609     dst->dimensions = src->dimensions;
610     dst->dim = src->dimensions ? (dst_ptr + dims_offset) : 0;
611     dst->padding = 0;
612 
613     if (src->dim) {
614         memcpy(base + dst->dim, src->dim, dims_size_in_bytes);
615     }
616     if (src->host) {
617         memcpy(base + dst->host, src->host, host_size_in_bytes);
618     }
619 
620     dump_wasmbuf(wabt_context, dst_ptr, "dst");
621 
622     return dst_ptr;
623 }
624 
625 // Given a pointer to a wasm_halide_buffer_t in wasm memory space,
626 // allocate a Buffer<> on the host and copy all relevant data.
wasmbuf_to_hostbuf(WabtContext & wabt_context,wasm32_ptr_t src_ptr,Halide::Runtime::Buffer<> & dst)627 void wasmbuf_to_hostbuf(WabtContext &wabt_context, wasm32_ptr_t src_ptr, Halide::Runtime::Buffer<> &dst) {
628     wdebug(2) << "\nwasmbuf_to_hostbuf:\n";
629 
630     dump_wasmbuf(wabt_context, src_ptr, "src");
631 
632     wassert(src_ptr);
633 
634     uint8_t *base = get_wasm_memory_base(wabt_context);
635 
636     wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(base + src_ptr);
637 
638     wassert(src->device == 0);
639     wassert(src->device_interface == 0);
640 
641     halide_buffer_t dst_tmp;
642     dst_tmp.device = 0;
643     dst_tmp.device_interface = 0;
644     dst_tmp.host = nullptr;  // src->host ? (base + src->host) : nullptr;
645     dst_tmp.flags = src->flags;
646     dst_tmp.type = src->type;
647     dst_tmp.dimensions = src->dimensions;
648     dst_tmp.dim = src->dim ? (halide_dimension_t *)(base + src->dim) : nullptr;
649     dst_tmp.padding = 0;
650 
651     dump_hostbuf(wabt_context, &dst_tmp, "dst_tmp");
652 
653     dst = Halide::Runtime::Buffer<>(dst_tmp);
654     if (src->host) {
655         // Don't use dst.copy(); it can tweak strides in ways that matter.
656         dst.allocate();
657         const size_t host_size_in_bytes = dst.raw_buffer()->size_in_bytes();
658         memcpy(dst.raw_buffer()->host, base + src->host, host_size_in_bytes);
659     }
660     dump_hostbuf(wabt_context, dst.raw_buffer(), "dst");
661 }
662 
// Given a wasm_halide_buffer_t, copy possibly-changed data into a halide_buffer_t.
// Both buffers are asserted to match in type and dimensions.
//
// This copies the dim array, the host data (if the wasm buffer has any) and the
// flags into `dst`, and resets dst's device fields to 0. `dst->host` itself is not
// changed, so it must already point at storage big enough for the data.
// The wassert() checks compile away unless WASM_DEBUG_LEVEL > 0.
void copy_wasmbuf_to_existing_hostbuf(WabtContext &wabt_context, wasm32_ptr_t src_ptr, halide_buffer_t *dst) {
    wassert(src_ptr && dst);

    wdebug(2) << "\ncopy_wasmbuf_to_existing_hostbuf:\n";
    dump_wasmbuf(wabt_context, src_ptr, "src");

    uint8_t *base = get_wasm_memory_base(wabt_context);

    wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(base + src_ptr);
    wassert(src->device == 0);
    wassert(src->device_interface == 0);
    wassert(src->dimensions == dst->dimensions);
    wassert(src->type == dst->type);

    dump_hostbuf(wabt_context, dst, "dst_pre");

    // The dims are copied first, so the size computed below reflects the shape
    // coming back from the wasm side.
    if (src->dimensions) {
        memcpy(dst->dim, base + src->dim, sizeof(halide_dimension_t) * src->dimensions);
    }
    if (src->host) {
        // NOTE(review): assumes dst->host is non-null and large enough for this
        // shape; neither is checked here.
        size_t host_size_in_bytes = dst->size_in_bytes();
        memcpy(dst->host, base + src->host, host_size_in_bytes);
    }

    dst->device = 0;
    dst->device_interface = 0;
    dst->flags = src->flags;

    dump_hostbuf(wabt_context, dst, "dst_post");
}
695 
696 // Given a halide_buffer_t, copy possibly-changed data into a wasm_halide_buffer_t.
697 // Both buffers are asserted to match in type and dimensions.
copy_hostbuf_to_existing_wasmbuf(WabtContext & wabt_context,const halide_buffer_t * src,wasm32_ptr_t dst_ptr)698 void copy_hostbuf_to_existing_wasmbuf(WabtContext &wabt_context, const halide_buffer_t *src, wasm32_ptr_t dst_ptr) {
699     wassert(src && dst_ptr);
700 
701     wdebug(1) << "\ncopy_hostbuf_to_existing_wasmbuf:\n";
702     dump_hostbuf(wabt_context, src, "src");
703 
704     uint8_t *base = get_wasm_memory_base(wabt_context);
705 
706     wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(base + dst_ptr);
707     wassert(src->device == 0);
708     wassert(src->device_interface == 0);
709     wassert(src->dimensions == dst->dimensions);
710     wassert(src->type == dst->type);
711 
712     dump_wasmbuf(wabt_context, dst_ptr, "dst_pre");
713 
714     if (src->dimensions) {
715         memcpy(base + dst->dim, src->dim, sizeof(halide_dimension_t) * src->dimensions);
716     }
717     if (src->host) {
718         size_t host_size_in_bytes = src->size_in_bytes();
719         memcpy(base + dst->host, src->host, host_size_in_bytes);
720     }
721 
722     dst->device = 0;
723     dst->device_interface = 0;
724     dst->flags = src->flags;
725 
726     dump_wasmbuf(wabt_context, dst_ptr, "dst_post");
727 }
728 
729 // --------------------------------------------------
730 // Helpers for converting to/from wabt::interp::Value
731 // --------------------------------------------------
732 
733 template<typename T>
734 struct LoadValue {
operator ()Halide::Internal::__anona20e47b00111::LoadValue735     inline wabt::interp::Value operator()(const void *src) {
736         const T val = *(const T *)(src);
737         return wabt::interp::Value::Make(val);
738     }
739 };
740 
741 template<>
operator ()(const void * src)742 inline wabt::interp::Value LoadValue<bool>::operator()(const void *src) {
743     // WABT doesn't do bools. Stash as u8 for now.
744     const uint8_t val = *(const uint8_t *)src;
745     return wabt::interp::Value::Make(val);
746 }
747 
748 template<>
operator ()(const void * src)749 inline wabt::interp::Value LoadValue<void *>::operator()(const void *src) {
750     // Halide 'handle' types are always uint64, even on 32-bit systems
751     const uint64_t val = *(const uint64_t *)src;
752     return wabt::interp::Value::Make(val);
753 }
754 
755 template<>
operator ()(const void * src)756 inline wabt::interp::Value LoadValue<float16_t>::operator()(const void *src) {
757     const uint16_t val = *(const uint16_t *)src;
758     return wabt::interp::Value::Make(val);
759 }
760 
761 template<>
operator ()(const void * src)762 inline wabt::interp::Value LoadValue<bfloat16_t>::operator()(const void *src) {
763     const uint16_t val = *(const uint16_t *)src;
764     return wabt::interp::Value::Make(val);
765 }
766 
load_value(const Type & t,const void * src)767 inline wabt::interp::Value load_value(const Type &t, const void *src) {
768     return dynamic_type_dispatch<LoadValue>(t, src);
769 }
770 
771 template<typename T>
load_value(const T & val)772 inline wabt::interp::Value load_value(const T &val) {
773     return LoadValue<T>()(&val);
774 }
775 
776 // -----
777 
778 template<typename T>
779 struct StoreValue {
operator ()Halide::Internal::__anona20e47b00111::StoreValue780     inline void operator()(const wabt::interp::Value &src, void *dst) {
781         *(T *)dst = src.Get<T>();
782     }
783 };
784 
785 template<>
operator ()(const wabt::interp::Value & src,void * dst)786 inline void StoreValue<bool>::operator()(const wabt::interp::Value &src, void *dst) {
787     // WABT doesn't do bools. Stash as u8 for now.
788     *(uint8_t *)dst = src.Get<uint8_t>();
789 }
790 
791 template<>
operator ()(const wabt::interp::Value & src,void * dst)792 inline void StoreValue<void *>::operator()(const wabt::interp::Value &src, void *dst) {
793     // Halide 'handle' types are always uint64, even on 32-bit systems
794     *(uint64_t *)dst = src.Get<uint64_t>();
795 }
796 
797 template<>
operator ()(const wabt::interp::Value & src,void * dst)798 inline void StoreValue<float16_t>::operator()(const wabt::interp::Value &src, void *dst) {
799     *(uint16_t *)dst = src.Get<uint16_t>();
800 }
801 
802 template<>
operator ()(const wabt::interp::Value & src,void * dst)803 inline void StoreValue<bfloat16_t>::operator()(const wabt::interp::Value &src, void *dst) {
804     *(uint16_t *)dst = src.Get<uint16_t>();
805 }
806 
store_value(const Type & t,const wabt::interp::Value & src,void * dst)807 inline void store_value(const Type &t, const wabt::interp::Value &src, void *dst) {
808     dynamic_type_dispatch<StoreValue>(t, src, dst);
809 }
810 
811 template<typename T>
store_value(const wabt::interp::Value & src,T * dst)812 inline void store_value(const wabt::interp::Value &src, T *dst) {
813     StoreValue<T>()(src, dst);
814 }
815 
816 // --------------------------------------------------
817 // Host Callback Functions
818 // --------------------------------------------------
819 
820 template<typename T, T some_func(T)>
wabt_posix_math_1(wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)821 wabt::Result wabt_posix_math_1(wabt::interp::Thread &thread,
822                                const wabt::interp::Values &args,
823                                wabt::interp::Values &results,
824                                wabt::interp::Trap::Ptr *trap) {
825     wassert(args.size() == 1);
826     const T in = args[0].Get<T>();
827     const T out = some_func(in);
828     results[0] = wabt::interp::Value::Make(out);
829     return wabt::Result::Ok;
830 }
831 
832 template<typename T, T some_func(T, T)>
wabt_posix_math_2(wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)833 wabt::Result wabt_posix_math_2(wabt::interp::Thread &thread,
834                                const wabt::interp::Values &args,
835                                wabt::interp::Values &results,
836                                wabt::interp::Trap::Ptr *trap) {
837     wassert(args.size() == 2);
838     const T in1 = args[0].Get<T>();
839     const T in2 = args[1].Get<T>();
840     const T out = some_func(in1, in2);
841     results[0] = wabt::interp::Value::Make(out);
842     return wabt::Result::Ok;
843 }
844 
// Expands to the signature of the host function `wabt_jit_<x>_callback`, with
// the parameter list used for wabt::interp::HostFunc::Callback entries.
// Follow the macro with the function body.
#define WABT_HOST_CALLBACK(x)                                  \
    wabt::Result wabt_jit_##x##_callback(                      \
        wabt::interp::Thread &thread,                          \
        const wabt::interp::Values &args,                      \
        wabt::interp::Values &results,                         \
        wabt::interp::Trap::Ptr *trap)

// Defines a complete host callback for `x` whose only action is to raise an
// internal error saying the WebAssembly JIT does not support that call yet.
#define WABT_HOST_CALLBACK_UNIMPLEMENTED(x)                                              \
    WABT_HOST_CALLBACK(x) {                                                              \
        internal_error << "WebAssembly JIT does not yet support the " #x "() call."; \
        return wabt::Result::Ok;                                                         \
    }
856 
// Registrations are ignored: nothing is recorded, so the host never invokes
// the handlers the module tries to register.
WABT_HOST_CALLBACK(__cxa_atexit) {
    return wabt::Result::Ok;
}

// Widens an IEEE half (received as its raw 16-bit pattern) to a float.
WABT_HOST_CALLBACK(__extendhfsf2) {
    const float16_t half = float16_t::make_from_bits(args[0].Get<uint16_t>());
    results[0] = wabt::interp::Value::Make(static_cast<float>(half));
    return wabt::Result::Ok;
}

// Narrows a float to an IEEE half and returns its raw 16-bit pattern.
WABT_HOST_CALLBACK(__truncsfhf2) {
    const uint16_t half_bits = float16_t(args[0].Get<float>()).to_bits();
    results[0] = wabt::interp::Value::Make(half_bits);
    return wabt::Result::Ok;
}
875 
// The module requested an abort; terminate the host process as well.
WABT_HOST_CALLBACK(abort) {
    std::abort();
    return wabt::Result::Ok;
}

// File-handle entry points are not yet supported; calling any of them raises
// an internal error.
WABT_HOST_CALLBACK_UNIMPLEMENTED(fclose)

WABT_HOST_CALLBACK_UNIMPLEMENTED(fileno)

WABT_HOST_CALLBACK_UNIMPLEMENTED(fopen)
886 
887 WABT_HOST_CALLBACK(free) {
888     WabtContext &wabt_context = get_wabt_context(thread);
889 
890     wasm32_ptr_t p = args[0].Get<int32_t>();
891     if (p) p -= kExtraMallocSlop;
892     wabt_free(wabt_context, p);
893     return wabt::Result::Ok;
894 }
895 
896 WABT_HOST_CALLBACK_UNIMPLEMENTED(fwrite)
897 
// getenv(name): looks up `name` (a wasm pointer to a NUL-terminated string) in
// the host environment. On a hit, the value is copied into freshly allocated
// wasm memory and that wasm pointer is returned. Returns 0 when the variable is
// unset or when the copy cannot be allocated.
WABT_HOST_CALLBACK(getenv) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const int32_t s = args[0].Get<int32_t>();

    uint8_t *base = get_wasm_memory_base(wabt_context);
    const char *e = getenv((const char *)base + s);

    // TODO: this string is leaked
    wasm32_ptr_t r = 0;
    if (e) {
        r = wabt_malloc(wabt_context, strlen(e) + 1);
        // wabt_malloc reports failure as 0; copying anyway would scribble over
        // whatever lives at wasm address 0.
        if (r) {
            strcpy((char *)base + r, e);
        }
    }
    results[0] = wabt::interp::Value::Make(r);
    return wabt::Result::Ok;
}
916 
// halide_print(user_context, str): forwards a message from the module to the
// JIT user's custom print handler, or to stdout when none is installed.
WABT_HOST_CALLBACK(halide_print) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 2);

    JITUserContext *ctx = get_jit_user_context(wabt_context, args[0]);
    const int32_t msg_offset = args[1].Get<int32_t>();
    const char *msg = (const char *)get_wasm_memory_base(wabt_context) + msg_offset;

    const bool has_custom_print = ctx != nullptr && ctx->handlers.custom_print != nullptr;
    if (!has_custom_print) {
        std::cout << msg;
    } else {
        (*ctx->handlers.custom_print)(ctx, msg);
    }
    return wabt::Result::Ok;
}
935 
// halide_trace_helper(user_context, func, value, coords, type_code, type_bits,
// type_lanes, code, parent_id, value_index, dimensions, trace_tag): rebuilds a
// halide_trace_event_t from its flattened wasm arguments (translating wasm
// pointers into host pointers) and passes it to the user's custom trace
// handler. Returns the handler's result, or 0 if no handler is installed.
WABT_HOST_CALLBACK(halide_trace_helper) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 12);

    uint8_t *base = get_wasm_memory_base(wabt_context);

    JITUserContext *ctx = get_jit_user_context(wabt_context, args[0]);

    // Every argument after the user context is an i32: a wasm pointer or an integer field.
    const auto arg_i32 = [&args](size_t index) -> int32_t {
        return args[index].Get<int32_t>();
    };
    // A null wasm pointer maps to a null host pointer.
    const auto to_host = [base](wasm32_ptr_t wasm_ptr) -> void * {
        return wasm_ptr ? (void *)(base + wasm_ptr) : nullptr;
    };

    const int dimensions = arg_i32(10);
    wassert(dimensions >= 0 && dimensions < 1024);  // not a hard limit, just a sanity check

    halide_trace_event_t event;
    event.func = (const char *)(base + arg_i32(1));
    event.value = to_host(arg_i32(2));
    event.coordinates = (int32_t *)to_host(arg_i32(3));
    event.trace_tag = (const char *)(base + arg_i32(11));
    event.type.code = (halide_type_code_t)arg_i32(4);
    event.type.bits = (uint8_t)arg_i32(5);
    event.type.lanes = (uint16_t)arg_i32(6);
    event.event = (halide_trace_event_code_t)arg_i32(7);
    event.parent_id = arg_i32(8);
    event.value_index = arg_i32(9);
    event.dimensions = dimensions;

    int32_t result = 0;
    const bool has_custom_trace = ctx != nullptr && ctx->handlers.custom_trace != nullptr;
    if (!has_custom_trace) {
        debug(0) << "Dropping trace event due to lack of trace handler.\n";
    } else {
        result = (*ctx->handlers.custom_trace)(ctx, &event);
    }

    results[0] = wabt::interp::Value::Make(result);
    return wabt::Result::Ok;
}
982 
// halide_error(user_context, str): forwards an error message from the module
// to the JIT user's custom error handler, or raises it as a Halide runtime
// error when none is installed.
WABT_HOST_CALLBACK(halide_error) {
    WabtContext &wabt_context = get_wabt_context(thread);

    wassert(args.size() == 2);

    JITUserContext *ctx = get_jit_user_context(wabt_context, args[0]);
    const int32_t msg_offset = args[1].Get<int32_t>();
    const char *msg = (const char *)get_wasm_memory_base(wabt_context) + msg_offset;

    if (ctx == nullptr || ctx->handlers.custom_error == nullptr) {
        halide_runtime_error << msg;
    } else {
        (*ctx->handlers.custom_error)(ctx, msg);
    }
    return wabt::Result::Ok;
}
1001 
// malloc(size): allocates `size` bytes of wasm heap through wabt_malloc.
// Every block is over-allocated by kExtraMallocSlop, and the returned pointer is
// advanced past that slop (the free callback undoes the offset).
// Returns 0 on failure.
WABT_HOST_CALLBACK(malloc) {
    WabtContext &wabt_context = get_wabt_context(thread);

    // On wasm32, size_t is an unsigned 32-bit value. Reading it as int32_t made
    // requests >= 2GB negative, and requests within kExtraMallocSlop of 4GB
    // wrapped around to a tiny size once the slop was added, which then
    // "succeeded". Read it unsigned and widen it before adding the slop.
    const size_t requested = static_cast<size_t>(args[0].Get<uint32_t>());

    wasm32_ptr_t p = 0;
    // Guard the addition as well, for hosts whose size_t is only 32 bits wide.
    if (requested <= static_cast<size_t>(-1) - kExtraMallocSlop) {
        p = wabt_malloc(wabt_context, requested + kExtraMallocSlop);
        if (p) p += kExtraMallocSlop;
    }
    results[0] = wabt::interp::Value::Make(p);
    return wabt::Result::Ok;
}
1011 
// memcpy(dst, src, n): copies between two regions of wasm memory and returns
// `dst`, matching the C contract.
WABT_HOST_CALLBACK(memcpy) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const int32_t dst_offset = args[0].Get<int32_t>();
    const int32_t src_offset = args[1].Get<int32_t>();
    const int32_t byte_count = args[2].Get<int32_t>();

    uint8_t *mem = get_wasm_memory_base(wabt_context);
    memcpy(mem + dst_offset, mem + src_offset, byte_count);

    results[0] = wabt::interp::Value::Make(dst_offset);
    return wabt::Result::Ok;
}

// memset(s, c, n): fills a region of wasm memory and returns `s`.
WABT_HOST_CALLBACK(memset) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const int32_t dst_offset = args[0].Get<int32_t>();
    const int32_t fill_value = args[1].Get<int32_t>();
    const int32_t byte_count = args[2].Get<int32_t>();

    uint8_t *mem = get_wasm_memory_base(wabt_context);
    memset(mem + dst_offset, fill_value, byte_count);

    results[0] = wabt::interp::Value::Make(dst_offset);
    return wabt::Result::Ok;
}

// memcmp(s1, s2, n): compares two regions of wasm memory and returns the host
// memcmp result.
WABT_HOST_CALLBACK(memcmp) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const int32_t lhs_offset = args[0].Get<int32_t>();
    const int32_t rhs_offset = args[1].Get<int32_t>();
    const int32_t byte_count = args[2].Get<int32_t>();

    uint8_t *mem = get_wasm_memory_base(wabt_context);
    const int32_t ordering = memcmp(mem + lhs_offset, mem + rhs_offset, byte_count);

    results[0] = wabt::interp::Value::Make(ordering);
    return wabt::Result::Ok;
}
1055 
// strlen(s): length of the NUL-terminated string at wasm address `s`.
WABT_HOST_CALLBACK(strlen) {
    WabtContext &wabt_context = get_wabt_context(thread);

    const char *str = (const char *)get_wasm_memory_base(wabt_context) + args[0].Get<int32_t>();
    const int32_t length = strlen(str);

    results[0] = wabt::interp::Value::Make(length);
    return wabt::Result::Ok;
}

// Writing to file descriptors is not yet supported.
WABT_HOST_CALLBACK_UNIMPLEMENTED(write)
1068 
1069 using HostCallbackMap = std::unordered_map<std::string, wabt::interp::HostFunc::Callback>;
1070 
1071 // clang-format off
get_host_callback_map()1072 const HostCallbackMap &get_host_callback_map() {
1073 
1074     static HostCallbackMap m = {
1075         // General runtime functions.
1076 
1077         #define DEFINE_CALLBACK(f) { #f, wabt_jit_##f##_callback },
1078 
1079         DEFINE_CALLBACK(__cxa_atexit)
1080         DEFINE_CALLBACK(__extendhfsf2)
1081         DEFINE_CALLBACK(__truncsfhf2)
1082         DEFINE_CALLBACK(abort)
1083         DEFINE_CALLBACK(fclose)
1084         DEFINE_CALLBACK(fileno)
1085         DEFINE_CALLBACK(fopen)
1086         DEFINE_CALLBACK(free)
1087         DEFINE_CALLBACK(fwrite)
1088         DEFINE_CALLBACK(getenv)
1089         DEFINE_CALLBACK(halide_error)
1090         DEFINE_CALLBACK(halide_print)
1091         DEFINE_CALLBACK(halide_trace_helper)
1092         DEFINE_CALLBACK(malloc)
1093         DEFINE_CALLBACK(memcmp)
1094         DEFINE_CALLBACK(memcpy)
1095         DEFINE_CALLBACK(memset)
1096         DEFINE_CALLBACK(strlen)
1097         DEFINE_CALLBACK(write)
1098 
1099         #undef DEFINE_CALLBACK
1100 
1101         // Posix math.
1102         #define DEFINE_POSIX_MATH_CALLBACK(t, f) { #f, wabt_posix_math_1<t, ::f> },
1103 
1104         DEFINE_POSIX_MATH_CALLBACK(double, acos)
1105         DEFINE_POSIX_MATH_CALLBACK(double, acosh)
1106         DEFINE_POSIX_MATH_CALLBACK(double, asin)
1107         DEFINE_POSIX_MATH_CALLBACK(double, asinh)
1108         DEFINE_POSIX_MATH_CALLBACK(double, atan)
1109         DEFINE_POSIX_MATH_CALLBACK(double, atanh)
1110         DEFINE_POSIX_MATH_CALLBACK(double, cos)
1111         DEFINE_POSIX_MATH_CALLBACK(double, cosh)
1112         DEFINE_POSIX_MATH_CALLBACK(double, exp)
1113         DEFINE_POSIX_MATH_CALLBACK(double, log)
1114         DEFINE_POSIX_MATH_CALLBACK(double, round)
1115         DEFINE_POSIX_MATH_CALLBACK(double, sin)
1116         DEFINE_POSIX_MATH_CALLBACK(double, sinh)
1117         DEFINE_POSIX_MATH_CALLBACK(double, tan)
1118         DEFINE_POSIX_MATH_CALLBACK(double, tanh)
1119 
1120         DEFINE_POSIX_MATH_CALLBACK(float, acosf)
1121         DEFINE_POSIX_MATH_CALLBACK(float, acoshf)
1122         DEFINE_POSIX_MATH_CALLBACK(float, asinf)
1123         DEFINE_POSIX_MATH_CALLBACK(float, asinhf)
1124         DEFINE_POSIX_MATH_CALLBACK(float, atanf)
1125         DEFINE_POSIX_MATH_CALLBACK(float, atanhf)
1126         DEFINE_POSIX_MATH_CALLBACK(float, cosf)
1127         DEFINE_POSIX_MATH_CALLBACK(float, coshf)
1128         DEFINE_POSIX_MATH_CALLBACK(float, expf)
1129         DEFINE_POSIX_MATH_CALLBACK(float, logf)
1130         DEFINE_POSIX_MATH_CALLBACK(float, roundf)
1131         DEFINE_POSIX_MATH_CALLBACK(float, sinf)
1132         DEFINE_POSIX_MATH_CALLBACK(float, sinhf)
1133         DEFINE_POSIX_MATH_CALLBACK(float, tanf)
1134         DEFINE_POSIX_MATH_CALLBACK(float, tanhf)
1135 
1136         #undef DEFINE_POSIX_MATH_CALLBACK
1137 
1138         #define DEFINE_POSIX_MATH_CALLBACK2(t, f) { #f, wabt_posix_math_2<t, ::f> },
1139 
1140         DEFINE_POSIX_MATH_CALLBACK2(float, atan2f)
1141         DEFINE_POSIX_MATH_CALLBACK2(double, atan2)
1142         DEFINE_POSIX_MATH_CALLBACK2(float, fminf)
1143         DEFINE_POSIX_MATH_CALLBACK2(double, fmin)
1144         DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf)
1145         DEFINE_POSIX_MATH_CALLBACK2(double, fmax)
1146         DEFINE_POSIX_MATH_CALLBACK2(float, powf)
1147         DEFINE_POSIX_MATH_CALLBACK2(double, pow)
1148 
1149         #undef DEFINE_POSIX_MATH_CALLBACK2
1150     };
1151 
1152     return m;
1153 }
1154 // clang-format on
1155 
// --------------------------------------------------
// Extern Callback Functions (define_extern support)
// --------------------------------------------------
1159 
1160 struct ExternArgType {
1161     halide_type_t type;
1162     bool is_void;
1163     bool is_buffer;
1164 };
1165 
1166 using TrampolineFn = void (*)(void **);
1167 
extern_callback_wrapper(const std::vector<ExternArgType> & arg_types,TrampolineFn trampoline_fn,wabt::interp::Thread & thread,const wabt::interp::Values & args,wabt::interp::Values & results,wabt::interp::Trap::Ptr * trap)1168 wabt::Result extern_callback_wrapper(const std::vector<ExternArgType> &arg_types,
1169                                      TrampolineFn trampoline_fn,
1170                                      wabt::interp::Thread &thread,
1171                                      const wabt::interp::Values &args,
1172                                      wabt::interp::Values &results,
1173                                      wabt::interp::Trap::Ptr *trap) {
1174     WabtContext &wabt_context = get_wabt_context(thread);
1175 
1176     wassert(arg_types.size() >= 1);
1177     const size_t arg_types_len = arg_types.size() - 1;
1178     const ExternArgType &ret_type = arg_types[0];
1179 
1180     // There's wasted space here, but that's ok.
1181     std::vector<Halide::Runtime::Buffer<>> buffers(arg_types_len);
1182     std::vector<uint64_t> scalars(arg_types_len, 0);
1183     std::vector<void *> trampoline_args(arg_types_len, nullptr);
1184 
1185     for (size_t i = 0; i < arg_types_len; ++i) {
1186         const auto &a = arg_types[i + 1];
1187         if (a.is_buffer) {
1188             const wasm32_ptr_t buf_ptr = args[i].Get<int32_t>();
1189             wasmbuf_to_hostbuf(wabt_context, buf_ptr, buffers[i]);
1190             trampoline_args[i] = buffers[i].raw_buffer();
1191         } else {
1192             store_value(a.type, args[i], &scalars[i]);
1193             trampoline_args[i] = &scalars[i];
1194         }
1195     }
1196 
1197     // The return value (if any) is always scalar.
1198     uint64_t ret_val = 0;
1199     const bool has_retval = !ret_type.is_void;
1200     internal_assert(!ret_type.is_buffer);
1201     if (has_retval) {
1202         trampoline_args.push_back(&ret_val);
1203     }
1204     (*trampoline_fn)(trampoline_args.data());
1205 
1206     if (has_retval) {
1207         results[0] = dynamic_type_dispatch<LoadValue>(ret_type.type, (void *)&ret_val);
1208     }
1209 
1210     // Progagate buffer data backwards. Note that for arbitrary extern functions,
1211     // we have no idea which buffers might be "input only", so we copy all data for all of them.
1212     for (size_t i = 0; i < arg_types_len; ++i) {
1213         const auto &a = arg_types[i + 1];
1214         if (a.is_buffer) {
1215             const wasm32_ptr_t buf_ptr = args[i].Get<int32_t>();
1216             copy_hostbuf_to_existing_wasmbuf(wabt_context, buffers[i], buf_ptr);
1217         }
1218     }
1219 
1220     return wabt::Result::Ok;
1221 }
1222 
// Returns true for import names that must never be bound to a define_extern
// trampoline. halide_print and halide_error have dedicated entries in the
// host callback table.
bool should_skip_extern_symbol(const std::string &name) {
    static const std::set<std::string> kSkippedSymbols = {"halide_print", "halide_error"};
    return kSkippedSymbols.find(name) != kSkippedSymbols.end();
}
1229 
make_extern_callback(wabt::interp::Store & store,const std::map<std::string,Halide::JITExtern> & jit_externs,const JITModule & trampolines,const wabt::interp::ImportDesc & import)1230 wabt::interp::HostFunc::Ptr make_extern_callback(wabt::interp::Store &store,
1231                                                  const std::map<std::string, Halide::JITExtern> &jit_externs,
1232                                                  const JITModule &trampolines,
1233                                                  const wabt::interp::ImportDesc &import) {
1234     const std::string &fn_name = import.type.name;
1235     if (should_skip_extern_symbol(fn_name)) {
1236         wdebug(1) << "Skipping extern symbol: " << fn_name << "\n";
1237         return wabt::interp::HostFunc::Ptr();
1238     }
1239 
1240     const auto it = jit_externs.find(fn_name);
1241     if (it == jit_externs.end()) {
1242         wdebug(1) << "Extern symbol not found in JIT Externs: " << fn_name << "\n";
1243         return wabt::interp::HostFunc::Ptr();
1244     }
1245     const ExternSignature &sig = it->second.extern_c_function().signature();
1246 
1247     const auto &tramp_it = trampolines.exports().find(fn_name + kTrampolineSuffix);
1248     if (tramp_it == trampolines.exports().end()) {
1249         wdebug(1) << "Extern symbol not found in trampolines: " << fn_name << "\n";
1250         return wabt::interp::HostFunc::Ptr();
1251     }
1252     TrampolineFn trampoline_fn = (TrampolineFn)tramp_it->second.address;
1253 
1254     const size_t arg_count = sig.arg_types().size();
1255 
1256     std::vector<ExternArgType> arg_types;
1257 
1258     if (sig.is_void_return()) {
1259         const bool is_void = true;
1260         const bool is_buffer = false;
1261         // Specifying a type here with bits == 0 should trigger a proper 'void' return type
1262         arg_types.push_back(ExternArgType{{halide_type_int, 0, 0}, is_void, is_buffer});
1263     } else {
1264         const Type &t = sig.ret_type();
1265         const bool is_void = false;
1266         const bool is_buffer = (t == type_of<halide_buffer_t *>());
1267         user_assert(t.lanes() == 1) << "Halide Extern functions cannot return vector values.";
1268         user_assert(!is_buffer) << "Halide Extern functions cannot return halide_buffer_t.";
1269         arg_types.push_back(ExternArgType{t, is_void, is_buffer});
1270     }
1271     for (size_t i = 0; i < arg_count; ++i) {
1272         const Type &t = sig.arg_types()[i];
1273         const bool is_void = false;
1274         const bool is_buffer = (t == type_of<halide_buffer_t *>());
1275         user_assert(t.lanes() == 1) << "Halide Extern functions cannot accept vector values as arguments.";
1276         arg_types.push_back(ExternArgType{t, is_void, is_buffer});
1277     }
1278 
1279     const auto callback_wrapper =
1280         [arg_types, trampoline_fn](wabt::interp::Thread &thread,
1281                                    const wabt::interp::Values &args,
1282                                    wabt::interp::Values &results,
1283                                    wabt::interp::Trap::Ptr *trap) -> wabt::Result {
1284         return extern_callback_wrapper(arg_types, trampoline_fn, thread, args, results, trap);
1285     };
1286 
1287     auto func_type = *wabt::cast<wabt::interp::FuncType>(import.type.type.get());
1288     auto host_func = wabt::interp::HostFunc::New(store, func_type, callback_wrapper);
1289 
1290     return host_func;
1291 }
1292 
calc_features(const Target & target)1293 wabt::Features calc_features(const Target &target) {
1294     wabt::Features f;
1295     if (target.has_feature(Target::WasmSignExt)) {
1296         f.enable_sign_extension();
1297     }
1298     if (target.has_feature(Target::WasmSimd128)) {
1299         f.enable_simd();
1300     }
1301     if (target.has_feature(Target::WasmSatFloatToInt)) {
1302         f.enable_sat_float_to_int();
1303     }
1304     return f;
1305 }
1306 
1307 }  // namespace
1308 
1309 #endif  // WITH_WABT
1310 
// State for one Halide pipeline compiled to WebAssembly. It always holds the
// pipeline's arguments, its define_extern table, and the host-side trampolines
// for those externs. When WITH_WABT is enabled, it also owns the wabt
// interpreter objects needed to run the module.
struct WasmModuleContents {
    // Reference count for shared ownership (presumably managed by an intrusive
    // pointer elsewhere in the file; confirm).
    mutable RefCount ref_count;

    // Target the module was compiled for (taken from the Halide Module).
    const Target target;
    // Pipeline arguments. run() indexes its `args` array in parallel with this vector.
    const std::vector<Argument> arguments;
    std::map<std::string, Halide::JITExtern> jit_externs;
    std::vector<JITModule> extern_deps;
    // Host-compiled trampolines (<name> + kTrampolineSuffix) for jit_externs.
    JITModule trampolines;

#if WITH_WABT
    // NOTE(review): member declaration order also fixes destruction order. The
    // wabt handles below (module, instance, memory) are presumably required to
    // be released before `store`, so keep them declared after it.

    // Allocator for the wasm heap above __heap_base (initialized by the constructor).
    BDMalloc bdmalloc;
    wabt::interp::Store store;
    wabt::interp::Module::Ptr module;
    wabt::interp::Instance::Ptr instance;
    // NOTE(review): run() builds its own local Thread::Options and does not
    // read this member.
    wabt::interp::Thread::Options thread_options;
    // The module's exported "memory".
    wabt::interp::Memory::Ptr memory;
#endif

    WasmModuleContents(
        const Module &halide_module,
        const std::vector<Argument> &arguments,
        const std::string &fn_name,
        const std::map<std::string, Halide::JITExtern> &jit_externs,
        const std::vector<JITModule> &extern_deps);

    // Invokes the module's single exported function. args[i] points at the
    // value for arguments[i]. Returns that function's int32 result.
    int run(const void **args);

    ~WasmModuleContents();
};
1340 
WasmModuleContents(const Module & halide_module,const std::vector<Argument> & arguments,const std::string & fn_name,const std::map<std::string,Halide::JITExtern> & jit_externs,const std::vector<JITModule> & extern_deps)1341 WasmModuleContents::WasmModuleContents(
1342     const Module &halide_module,
1343     const std::vector<Argument> &arguments,
1344     const std::string &fn_name,
1345     const std::map<std::string, Halide::JITExtern> &jit_externs,
1346     const std::vector<JITModule> &extern_deps)
1347     : target(halide_module.target()),
1348       arguments(arguments),
1349       jit_externs(jit_externs),
1350       extern_deps(extern_deps),
1351       trampolines(JITModule::make_trampolines_module(get_host_target(), jit_externs, kTrampolineSuffix, extern_deps)) {
1352 
1353 #if WITH_WABT
1354     user_assert(LLVM_VERSION >= 110) << "Using the WebAssembly JIT is only supported under LLVM 11+.";
1355 
1356     wdebug(1) << "Compiling wasm function " << fn_name << "\n";
1357 
1358     // Compile halide into wasm bytecode.
1359     std::vector<char> final_wasm = compile_to_wasm(halide_module, fn_name);
1360 
1361     store = wabt::interp::Store(calc_features(halide_module.target()));
1362 
1363     // Create a wabt Module for it.
1364     wabt::MemoryStream log_stream;
1365     constexpr bool kReadDebugNames = true;
1366     constexpr bool kStopOnFirstError = true;
1367     constexpr bool kFailOnCustomSectionError = true;
1368     wabt::ReadBinaryOptions options(store.features(),
1369                                     &log_stream,
1370                                     kReadDebugNames,
1371                                     kStopOnFirstError,
1372                                     kFailOnCustomSectionError);
1373     wabt::Errors errors;
1374     wabt::interp::ModuleDesc module_desc;
1375     wabt::Result r = wabt::interp::ReadBinaryInterp(final_wasm.data(),
1376                                                     final_wasm.size(),
1377                                                     options,
1378                                                     &errors,
1379                                                     &module_desc);
1380     internal_assert(Succeeded(r))
1381         << "ReadBinaryInterp failed:\n"
1382         << wabt::FormatErrorsToString(errors, wabt::Location::Type::Binary) << "\n"
1383         << "  log: " << to_string(log_stream) << "\n";
1384 
1385     if (WASM_DEBUG_LEVEL >= 2) {
1386         wabt::MemoryStream dis_stream;
1387         module_desc.istream.Disassemble(&dis_stream);
1388         wdebug(WASM_DEBUG_LEVEL) << "Disassembly:\n"
1389                                  << to_string(dis_stream) << "\n";
1390     }
1391 
1392     module = wabt::interp::Module::New(store, module_desc);
1393 
1394     // Bind all imports to our callbacks.
1395     wabt::interp::RefVec imports;
1396     const HostCallbackMap &host_callback_map = get_host_callback_map();
1397     for (const auto &import : module->desc().imports) {
1398         wdebug(1) << "import=" << import.type.module << "." << import.type.name << "\n";
1399         if (import.type.type->kind == wabt::interp::ExternKind::Func && import.type.module == "env") {
1400             auto it = host_callback_map.find(import.type.name);
1401             if (it != host_callback_map.end()) {
1402                 auto func_type = *wabt::cast<wabt::interp::FuncType>(import.type.type.get());
1403                 auto host_func = wabt::interp::HostFunc::New(store, func_type, it->second);
1404                 imports.push_back(host_func.ref());
1405                 continue;
1406             }
1407 
1408             // If it's not one of the standard host callbacks, assume it must be
1409             // a define_extern, and look for it in the jit_externs.
1410             auto host_func = make_extern_callback(store, jit_externs, trampolines, import);
1411             imports.push_back(host_func.ref());
1412             continue;
1413         }
1414         // By default, just push a null reference. This won't resolve, and
1415         // instantiation will fail.
1416         imports.push_back(wabt::interp::Ref::Null);
1417     }
1418 
1419     wabt::interp::RefPtr<wabt::interp::Trap> trap;
1420     instance = wabt::interp::Instance::Instantiate(store, module.ref(), imports, &trap);
1421     internal_assert(instance) << "Error initializing module: " << trap->message() << "\n";
1422 
1423     int32_t heap_base = -1;
1424 
1425     for (const auto &e : module_desc.exports) {
1426         if (e.type.name == "__heap_base") {
1427             internal_assert(e.type.type->kind == wabt::ExternalKind::Global);
1428             heap_base = store.UnsafeGet<wabt::interp::Global>(instance->globals()[e.index])->Get().Get<int32_t>();
1429             wdebug(1) << "__heap_base is " << heap_base << "\n";
1430             continue;
1431         }
1432         if (e.type.name == "memory") {
1433             internal_assert(e.type.type->kind == wabt::ExternalKind::Memory);
1434             internal_assert(!memory.get()) << "Expected exactly one memory object but saw " << (void *)memory.get();
1435             memory = store.UnsafeGet<wabt::interp::Memory>(instance->memories()[e.index]);
1436             wdebug(1) << "heap_size is " << memory->ByteSize() << "\n";
1437             continue;
1438         }
1439     }
1440     internal_assert(heap_base >= 0) << "__heap_base not found";
1441     internal_assert(memory->ByteSize() > 0) << "memory size is unlikely";
1442 
1443     bdmalloc.init(memory->ByteSize(), heap_base);
1444 
1445 #endif  // WITH_WABT
1446 }
1447 
run(const void ** args)1448 int WasmModuleContents::run(const void **args) {
1449 #if WITH_WABT
1450     const auto &module_desc = module->desc();
1451 
1452     wabt::interp::FuncType *func_type = nullptr;
1453     wabt::interp::RefPtr<wabt::interp::Func> func;
1454     std::string func_name;
1455 
1456     for (const auto &e : module_desc.exports) {
1457         if (e.type.type->kind == wabt::ExternalKind::Func) {
1458             wdebug(1) << "Selecting export '" << e.type.name << "'\n";
1459             internal_assert(!func_type && !func) << "Multiple exported funcs found";
1460             func_type = wabt::cast<wabt::interp::FuncType>(e.type.type.get());
1461             func = store.UnsafeGet<wabt::interp::Func>(instance->funcs()[e.index]);
1462             func_name = e.type.name;
1463             continue;
1464         }
1465     }
1466 
1467     JITUserContext *jit_user_context = nullptr;
1468     for (size_t i = 0; i < arguments.size(); i++) {
1469         const Argument &arg = arguments[i];
1470         const void *arg_ptr = args[i];
1471         if (arg.name == "__user_context") {
1472             jit_user_context = *(JITUserContext **)const_cast<void *>(arg_ptr);
1473         }
1474     }
1475 
1476     WabtContext wabt_context(jit_user_context, *memory, bdmalloc);
1477 
1478     wabt::interp::Values wabt_args;
1479     wabt::interp::Values wabt_results;
1480     wabt::interp::Trap::Ptr trap;
1481 
1482     std::vector<wasm32_ptr_t> wbufs(arguments.size(), 0);
1483 
1484     for (size_t i = 0; i < arguments.size(); i++) {
1485         const Argument &arg = arguments[i];
1486         const void *arg_ptr = args[i];
1487         if (arg.is_buffer()) {
1488             halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
1489             // It's OK for this to be null (let Halide asserts handle it)
1490             wasm32_ptr_t wbuf = hostbuf_to_wasmbuf(wabt_context, buf);
1491             wbufs[i] = wbuf;
1492             wabt_args.push_back(load_value(wbuf));
1493         } else {
1494             if (arg.name == "__user_context") {
1495                 wabt_args.push_back(wabt::interp::Value::Make(kMagicJitUserContextValue));
1496             } else {
1497                 wabt_args.push_back(load_value(arg.type, arg_ptr));
1498             }
1499         }
1500     }
1501 
1502     wabt::interp::Thread::Options options;
1503     wabt::interp::Thread::Ptr thread = wabt::interp::Thread::New(store, options);
1504     thread->set_host_info(&wabt_context);
1505 
1506     auto r = func->Call(*thread, wabt_args, wabt_results, &trap);
1507     if (WASM_DEBUG_LEVEL >= 2) {
1508         wabt::MemoryStream call_stream;
1509         WriteCall(&call_stream, func_name, *func_type, wabt_args, wabt_results, trap);
1510         wdebug(WASM_DEBUG_LEVEL) << to_string(call_stream) << "\n";
1511     }
1512     internal_assert(Succeeded(r)) << "Func::Call failed: " << trap->message() << "\n";
1513     internal_assert(wabt_results.size() == 1);
1514     int32_t result = wabt_results[0].Get<int32_t>();
1515 
1516     wdebug(1) << "Result is " << result << "\n";
1517 
1518     if (result == 0) {
1519         // Update any output buffers
1520         for (size_t i = 0; i < arguments.size(); i++) {
1521             const Argument &arg = arguments[i];
1522             const void *arg_ptr = args[i];
1523             if (arg.is_buffer()) {
1524                 halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
1525                 copy_wasmbuf_to_existing_hostbuf(wabt_context, wbufs[i], buf);
1526             }
1527         }
1528     }
1529 
1530     for (wasm32_ptr_t p : wbufs) {
1531         wabt_free(wabt_context, p);
1532     }
1533 
1534     // Don't do this: things allocated by Halide runtime might need to persist
1535     // between multiple invocations of the same function.
1536     // bdmalloc.reset();
1537 
1538     return result;
1539 
1540 #endif
1541 
1542     internal_error << "WasmExecutor is not configured correctly";
1543     return -1;
1544 }
1545 
~WasmModuleContents()1546 WasmModuleContents::~WasmModuleContents() {
1547 #if WITH_WABT
1548     // nothing
1549 #endif
1550 }
1551 
1552 template<>
ref_count(const WasmModuleContents * p)1553 RefCount &ref_count<WasmModuleContents>(const WasmModuleContents *p) noexcept {
1554     return p->ref_count;
1555 }
1556 
1557 template<>
destroy(const WasmModuleContents * p)1558 void destroy<WasmModuleContents>(const WasmModuleContents *p) {
1559     delete p;
1560 }
1561 
1562 /*static*/
can_jit_target(const Target & target)1563 bool WasmModule::can_jit_target(const Target &target) {
1564 #if WITH_WABT
1565     if (target.arch == Target::WebAssembly) {
1566         return true;
1567     }
1568 #endif
1569     return false;
1570 }
1571 
1572 /*static*/
compile(const Module & module,const std::vector<Argument> & arguments,const std::string & fn_name,const std::map<std::string,Halide::JITExtern> & jit_externs,const std::vector<JITModule> & extern_deps)1573 WasmModule WasmModule::compile(
1574     const Module &module,
1575     const std::vector<Argument> &arguments,
1576     const std::string &fn_name,
1577     const std::map<std::string, Halide::JITExtern> &jit_externs,
1578     const std::vector<JITModule> &extern_deps) {
1579 #if !defined(WITH_WABT)
1580     user_error << "Cannot run JITted WebAssembly without configuring a WebAssembly engine.";
1581     return WasmModule();
1582 #endif
1583 
1584     WasmModule wasm_module;
1585     wasm_module.contents = new WasmModuleContents(module, arguments, fn_name, jit_externs, extern_deps);
1586     return wasm_module;
1587 }
1588 
1589 /** Run generated previously compiled wasm code with a set of arguments. */
run(const void ** args)1590 int WasmModule::run(const void **args) {
1591     internal_assert(contents.defined());
1592     return contents->run(args);
1593 }
1594 
1595 }  // namespace Internal
1596 }  // namespace Halide
1597