1 #pragma once
2 
3 #include "wasm-rt.h"
4 
5 // Pull the helper header from the main repo for dynamic_check and scope_exit
6 #include "rlbox_helpers.hpp"
7 
8 #include <cstdint>
9 #include <iostream>
10 #include <limits>
11 #include <map>
12 #include <memory>
13 #include <mutex>
14 // RLBox allows applications to provide a custom shared lock implementation
15 #ifndef RLBOX_USE_CUSTOM_SHARED_LOCK
16 #  include <shared_mutex>
17 #endif
18 #include <string>
19 #include <type_traits>
20 #include <utility>
21 #include <vector>
22 
23 #if defined(_WIN32)
24 // Ensure the min/max macro in the header doesn't collide with functions in std::
25 #ifndef NOMINMAX
26 #define NOMINMAX
27 #endif
28 #include <windows.h>
29 #else
30 #include <dlfcn.h>
31 #endif
32 
// Marks its argument as intentionally unused. The whole argument list is
// parenthesised so an expression such as `a + b` is discarded as a unit;
// without the parentheses it would expand to `(void)a + b`, which does not
// compile.
#define RLBOX_WASM2C_UNUSED(...) (void)(__VA_ARGS__)
34 
// Use the same convention as rlbox to allow applications to customize the
// shared lock.
//
// RLBOX_SHARED_LOCK(name) declares a reader/writer lock called `name`.
// RLBOX_ACQUIRE_SHARED_GUARD(name, lock) and
// RLBOX_ACQUIRE_UNIQUE_GUARD(name, lock) declare a guard variable `name` that
// locks `lock` for shared (reader) or exclusive (writer) access respectively.
#ifndef RLBOX_USE_CUSTOM_SHARED_LOCK
// Default: std::shared_timed_mutex with std::shared_lock / std::unique_lock
// guards, which release the mutex when the guard goes out of scope.
#  define RLBOX_SHARED_LOCK(name) std::shared_timed_mutex name
#  define RLBOX_ACQUIRE_SHARED_GUARD(name, ...)                                \
    std::shared_lock<std::shared_timed_mutex> name(__VA_ARGS__)
#  define RLBOX_ACQUIRE_UNIQUE_GUARD(name, ...)                                \
    std::unique_lock<std::shared_timed_mutex> name(__VA_ARGS__)
#else
// An embedder defining RLBOX_USE_CUSTOM_SHARED_LOCK must already have defined
// all three macros at this point (i.e. before this header is included).
#  if !defined(RLBOX_SHARED_LOCK) || !defined(RLBOX_ACQUIRE_SHARED_GUARD) ||   \
    !defined(RLBOX_ACQUIRE_UNIQUE_GUARD)
#    error                                                                     \
      "RLBOX_USE_CUSTOM_SHARED_LOCK defined but missing definitions for RLBOX_SHARED_LOCK, RLBOX_ACQUIRE_SHARED_GUARD, RLBOX_ACQUIRE_UNIQUE_GUARD"
#  endif
#endif
50 
51 namespace rlbox {
52 
53 namespace wasm2c_detail {
54 
55   template<typename T>
56   constexpr bool false_v = false;
57 
58   // https://stackoverflow.com/questions/6512019/can-we-get-the-type-of-a-lambda-argument
59   namespace return_argument_detail {
60     template<typename Ret, typename... Rest>
61     Ret helper(Ret (*)(Rest...));
62 
63     template<typename Ret, typename F, typename... Rest>
64     Ret helper(Ret (F::*)(Rest...));
65 
66     template<typename Ret, typename F, typename... Rest>
67     Ret helper(Ret (F::*)(Rest...) const);
68 
69     template<typename F>
70     decltype(helper(&F::operator())) helper(F);
71   } // namespace return_argument_detail
72 
73   template<typename T>
74   using return_argument =
75     decltype(return_argument_detail::helper(std::declval<T>()));
76 
77   ///////////////////////////////////////////////////////////////
78 
79   // https://stackoverflow.com/questions/37602057/why-isnt-a-for-loop-a-compile-time-expression
80   namespace compile_time_for_detail {
81     template<std::size_t N>
82     struct num
83     {
84       static const constexpr auto value = N;
85     };
86 
87     template<class F, std::size_t... Is>
compile_time_for_helper(F func,std::index_sequence<Is...>)88     inline void compile_time_for_helper(F func, std::index_sequence<Is...>)
89     {
90       (func(num<Is>{}), ...);
91     }
92   } // namespace compile_time_for_detail
93 
94   template<std::size_t N, typename F>
compile_time_for(F func)95   inline void compile_time_for(F func)
96   {
97     compile_time_for_detail::compile_time_for_helper(
98       func, std::make_index_sequence<N>());
99   }
100 
101   ///////////////////////////////////////////////////////////////
102 
103   template<typename T, typename = void>
104   struct convert_type_to_wasm_type
105   {
106     static_assert(std::is_void_v<T>, "Missing specialization");
107     using type = void;
108     // wasm2c has no void type so use i32 for now
109     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32;
110   };
111 
112   template<typename T>
113   struct convert_type_to_wasm_type<
114     T,
115     std::enable_if_t<(std::is_integral_v<T> || std::is_enum_v<T>)&&sizeof(T) <=
116                      sizeof(uint32_t)>>
117   {
118     using type = uint32_t;
119     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32;
120   };
121 
122   template<typename T>
123   struct convert_type_to_wasm_type<
124     T,
125     std::enable_if_t<(std::is_integral_v<T> ||
126                       std::is_enum_v<T>)&&sizeof(uint32_t) < sizeof(T) &&
127                      sizeof(T) <= sizeof(uint64_t)>>
128   {
129     using type = uint64_t;
130     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I64;
131   };
132 
133   template<typename T>
134   struct convert_type_to_wasm_type<T,
135                                    std::enable_if_t<std::is_same_v<T, float>>>
136   {
137     using type = T;
138     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_F32;
139   };
140 
141   template<typename T>
142   struct convert_type_to_wasm_type<T,
143                                    std::enable_if_t<std::is_same_v<T, double>>>
144   {
145     using type = T;
146     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_F64;
147   };
148 
149   template<typename T>
150   struct convert_type_to_wasm_type<
151     T,
152     std::enable_if_t<std::is_pointer_v<T> || std::is_class_v<T>>>
153   {
154     // pointers are 32 bit indexes in wasm
155     // class paramters are passed as a pointer to an object in the stack or heap
156     using type = uint32_t;
157     static constexpr wasm_rt_type_t wasm2c_type = WASM_RT_I32;
158   };
159 
160   ///////////////////////////////////////////////////////////////
161 
162   namespace prepend_arg_type_detail {
163     template<typename T, typename T_ArgNew>
164     struct helper;
165 
166     template<typename T_ArgNew, typename T_Ret, typename... T_Args>
167     struct helper<T_Ret(T_Args...), T_ArgNew>
168     {
169       using type = T_Ret(T_ArgNew, T_Args...);
170     };
171   }
172 
173   template<typename T_Func, typename T_ArgNew>
174   using prepend_arg_type =
175     typename prepend_arg_type_detail::helper<T_Func, T_ArgNew>::type;
176 
177   ///////////////////////////////////////////////////////////////
178 
179   namespace change_return_type_detail {
180     template<typename T, typename T_RetNew>
181     struct helper;
182 
183     template<typename T_RetNew, typename T_Ret, typename... T_Args>
184     struct helper<T_Ret(T_Args...), T_RetNew>
185     {
186       using type = T_RetNew(T_Args...);
187     };
188   }
189 
190   template<typename T_Func, typename T_RetNew>
191   using change_return_type =
192     typename change_return_type_detail::helper<T_Func, T_RetNew>::type;
193 
194   ///////////////////////////////////////////////////////////////
195 
196   namespace change_class_arg_types_detail {
197     template<typename T, typename T_ArgNew>
198     struct helper;
199 
200     template<typename T_ArgNew, typename T_Ret, typename... T_Args>
201     struct helper<T_Ret(T_Args...), T_ArgNew>
202     {
203       using type =
204         T_Ret(std::conditional_t<std::is_class_v<T_Args>, T_ArgNew, T_Args>...);
205     };
206   }
207 
208   template<typename T_Func, typename T_ArgNew>
209   using change_class_arg_types =
210     typename change_class_arg_types_detail::helper<T_Func, T_ArgNew>::type;
211 
212 } // namespace wasm2c_detail
213 
214 class rlbox_wasm2c_sandbox;
215 
216 struct rlbox_wasm2c_sandbox_thread_data
217 {
218   rlbox_wasm2c_sandbox* sandbox;
219   uint32_t last_callback_invoked;
220 };
221 
#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES

// With RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES the per-thread
// rlbox_wasm2c_sandbox_thread_data is not an inline thread_local member of
// rlbox_wasm2c_sandbox; code reaches it through this accessor instead, and the
// embedder supplies its definition by expanding
// RLBOX_WASM2C_SANDBOX_STATIC_VARIABLES() below.
rlbox_wasm2c_sandbox_thread_data* get_rlbox_wasm2c_sandbox_thread_data();

// Defines the zero-initialized thread_local storage and the
// rlbox::get_rlbox_wasm2c_sandbox_thread_data() accessor. Expand it once, at
// global namespace scope, in a single translation unit: it opens
// `namespace rlbox` and defines a non-inline function, so expanding it in
// several translation units would produce duplicate definitions. The trailing
// static_assert makes callers terminate the expansion with a semicolon.
#  define RLBOX_WASM2C_SANDBOX_STATIC_VARIABLES()                                \
    thread_local rlbox::rlbox_wasm2c_sandbox_thread_data                         \
      rlbox_wasm2c_sandbox_thread_info{ 0, 0 };                                  \
    namespace rlbox {                                                          \
      rlbox_wasm2c_sandbox_thread_data* get_rlbox_wasm2c_sandbox_thread_data()     \
      {                                                                        \
        return &rlbox_wasm2c_sandbox_thread_info;                                \
      }                                                                        \
    }                                                                          \
    static_assert(true, "Enforce semi-colon")

#endif
237 
238 class rlbox_wasm2c_sandbox
239 {
240 public:
241   using T_LongLongType = int64_t;
242   using T_LongType = int32_t;
243   using T_IntType = int32_t;
244   using T_PointerType = uint32_t;
245   using T_ShortType = int16_t;
246 
247 private:
248   void* sandbox = nullptr;
249   wasm2c_sandbox_funcs_t sandbox_info;
250 #if !defined(_MSC_VER)
251 __attribute__((weak))
252 #endif
253   static std::once_flag wasm2c_runtime_initialized;
254   wasm_rt_memory_t* sandbox_memory_info = nullptr;
255 #ifndef RLBOX_USE_STATIC_CALLS
256   void* library = nullptr;
257 #endif
258   uintptr_t heap_base;
259   void* exec_env = 0;
260   void* malloc_index = 0;
261   void* free_index = 0;
262   size_t return_slot_size = 0;
263   T_PointerType return_slot = 0;
264 
265   static const size_t MAX_CALLBACKS = 128;
266   mutable RLBOX_SHARED_LOCK(callback_mutex);
267   void* callback_unique_keys[MAX_CALLBACKS]{ 0 };
268   void* callbacks[MAX_CALLBACKS]{ 0 };
269   uint32_t callback_slot_assignment[MAX_CALLBACKS]{ 0 };
270   mutable std::map<const void*, uint32_t> internal_callbacks;
271   mutable std::map<uint32_t, const void*> slot_assignments;
272 
273 #ifndef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
274   thread_local static inline rlbox_wasm2c_sandbox_thread_data thread_data{ 0, 0 };
275 #endif
276 
277   template<typename T_FormalRet, typename T_ActualRet>
serialize_to_sandbox(T_ActualRet arg)278   inline auto serialize_to_sandbox(T_ActualRet arg)
279   {
280     if constexpr (std::is_class_v<T_FormalRet>) {
281       // structs returned as pointers into wasm memory/wasm stack
282       auto ptr = reinterpret_cast<T_FormalRet*>(
283         impl_get_unsandboxed_pointer<T_FormalRet*>(arg));
284       T_FormalRet ret = *ptr;
285       return ret;
286     } else {
287       return arg;
288     }
289   }
290 
291   template<uint32_t N, typename T_Ret, typename... T_Args>
292   static typename wasm2c_detail::convert_type_to_wasm_type<T_Ret>::type
callback_interceptor(void *,typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type...params)293   callback_interceptor(
294     void* /* vmContext */,
295     typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type... params)
296   {
297 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
298     auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data();
299 #endif
300     thread_data.last_callback_invoked = N;
301     using T_Func = T_Ret (*)(T_Args...);
302     T_Func func;
303     {
304 #ifndef RLBOX_SINGLE_THREADED_INVOCATIONS
305       RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex);
306 #endif
307       func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]);
308     }
309     // Callbacks are invoked through function pointers, cannot use std::forward
310     // as we don't have caller context for T_Args, which means they are all
311     // effectively passed by value
312     return func(thread_data.sandbox->serialize_to_sandbox<T_Args>(params)...);
313   }
314 
315   template<uint32_t N, typename T_Ret, typename... T_Args>
callback_interceptor_promoted(void *,typename wasm2c_detail::convert_type_to_wasm_type<T_Ret>::type ret,typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type...params)316   static void callback_interceptor_promoted(
317     void* /* vmContext */,
318     typename wasm2c_detail::convert_type_to_wasm_type<T_Ret>::type ret,
319     typename wasm2c_detail::convert_type_to_wasm_type<T_Args>::type... params)
320   {
321 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
322     auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data();
323 #endif
324     thread_data.last_callback_invoked = N;
325     using T_Func = T_Ret (*)(T_Args...);
326     T_Func func;
327     {
328 #ifndef RLBOX_SINGLE_THREADED_INVOCATIONS
329       RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex);
330 #endif
331       func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]);
332     }
333     // Callbacks are invoked through function pointers, cannot use std::forward
334     // as we don't have caller context for T_Args, which means they are all
335     // effectively passed by value
336     auto ret_val =
337       func(thread_data.sandbox->serialize_to_sandbox<T_Args>(params)...);
338     // Copy the return value back
339     auto ret_ptr = reinterpret_cast<T_Ret*>(
340       thread_data.sandbox->template impl_get_unsandboxed_pointer<T_Ret*>(ret));
341     *ret_ptr = ret_val;
342   }
343 
344   template<typename T_Ret, typename... T_Args>
get_wasm2c_func_index(T_Ret (*)(T_Args...)=nullptr) const345   inline uint32_t get_wasm2c_func_index(
346     // dummy for template inference
347     T_Ret (*)(T_Args...) = nullptr
348   ) const
349   {
350     // Class return types as promoted to args
351     constexpr bool promoted = std::is_class_v<T_Ret>;
352 
353     // If return type is void, then there is no return type
354     // But it is fine if we add it anyway as it as at the end of the array
355     // and we pass in counts to lookup_wasm2c_func_index that would result in this
356     // element not being accessed
357     wasm_rt_type_t ret_param_types[] = {
358       wasm2c_detail::convert_type_to_wasm_type<T_Args>::wasm2c_type...,
359       wasm2c_detail::convert_type_to_wasm_type<T_Ret>::wasm2c_type
360     };
361 
362     uint32_t param_count = 0;
363     uint32_t ret_count = 0;
364 
365     if constexpr (promoted) {
366       param_count = sizeof...(T_Args) + 1;
367       ret_count = 0;
368     } else {
369       param_count = sizeof...(T_Args);
370       ret_count = std::is_void_v<T_Ret>? 0 : 1;
371     }
372 
373     auto ret = sandbox_info.lookup_wasm2c_func_index(sandbox, param_count, ret_count, ret_param_types);
374     return ret;
375   }
376 
ensure_return_slot_size(size_t size)377   void ensure_return_slot_size(size_t size)
378   {
379     if (size > return_slot_size) {
380       if (return_slot_size) {
381         impl_free_in_sandbox(return_slot);
382       }
383       return_slot = impl_malloc_in_sandbox(size);
384       detail::dynamic_check(
385         return_slot != 0,
386         "Error initializing return slot. Sandbox may be out of memory!");
387       return_slot_size = size;
388     }
389   }
390 
391 #ifndef RLBOX_USE_STATIC_CALLS
symbol_lookup(std::string prefixed_name)392   inline void* symbol_lookup(std::string prefixed_name) {
393     #if defined(_WIN32)
394       void* ret = (void*) GetProcAddress((HMODULE) library, prefixed_name.c_str());
395     #else
396       void* ret = dlsym(library, prefixed_name.c_str());
397     #endif
398     if (ret == nullptr) {
399       // Some lookups such as globals are not exposed as shared library symbols
400       uint32_t* heap_index_pointer = (uint32_t*) sandbox_info.lookup_wasm2c_nonfunc_export(sandbox, prefixed_name.c_str());
401       if (heap_index_pointer != nullptr) {
402         uint32_t heap_index = *heap_index_pointer;
403         ret = &(reinterpret_cast<char*>(heap_base)[heap_index]);
404       }
405     }
406     return ret;
407   }
408 #endif
409 
410   // function takes a 32-bit value and returns the next power of 2
411   // return is a 64-bit value as large 32-bit values will return 2^32
next_power_of_two(uint32_t value)412   static inline uint64_t next_power_of_two(uint32_t value) {
413     uint64_t power = 1;
414     while(power < value) {
415       power *= 2;
416     }
417     return power;
418   }
419 
420 public:
421 
422 #define WASM_PAGE_SIZE 65536
423 #define WASM_HEAP_MAX_ALLOWED_PAGES 65536
424 #define WASM_MAX_HEAP (static_cast<uint64_t>(1) << 32)
rlbox_wasm2c_get_adjusted_heap_size(uint64_t heap_size)425   static uint64_t rlbox_wasm2c_get_adjusted_heap_size(uint64_t heap_size)
426   {
427     if (heap_size == 0){
428       return 0;
429     }
430 
431     if(heap_size <= WASM_PAGE_SIZE) {
432       return WASM_PAGE_SIZE;
433     } else if (heap_size >= WASM_MAX_HEAP) {
434       return WASM_MAX_HEAP;
435     }
436 
437     return next_power_of_two(static_cast<uint32_t>(heap_size));
438   }
439 
rlbox_wasm2c_get_heap_page_count(uint64_t heap_size)440   static uint64_t rlbox_wasm2c_get_heap_page_count(uint64_t heap_size)
441   {
442     const uint64_t pages = heap_size / WASM_PAGE_SIZE;
443     return pages;
444   }
445 #undef WASM_MAX_HEAP
446 #undef WASM_HEAP_MAX_ALLOWED_PAGES
447 #undef WASM_PAGE_SIZE
448 
449 protected:
450 
451 #ifndef RLBOX_USE_STATIC_CALLS
impl_lookup_symbol(const char * func_name)452   void* impl_lookup_symbol(const char* func_name)
453   {
454     std::string prefixed_name = "w2c_";
455     prefixed_name += func_name;
456     void* ret = symbol_lookup(prefixed_name);
457     return ret;
458   }
459 #else
460 
461   #define rlbox_wasm2c_sandbox_lookup_symbol(func_name)                            \
462   reinterpret_cast<void*>(&w2c_##func_name) /* NOLINT */
463 
464   // adding a template so that we can use static_assert to fire only if this
465   // function is invoked
466   template<typename T = void>
467   void* impl_lookup_symbol(const char* func_name)
468   {
469     constexpr bool fail = std::is_same_v<T, void>;
470     static_assert(
471       !fail,
472       "The wasm2c_sandbox uses static calls and thus developers should add\n\n"
473       "#define RLBOX_USE_STATIC_CALLS() rlbox_wasm2c_sandbox_lookup_symbol\n\n"
474       "to their code, to ensure that static calls are handled correctly.");
475     return nullptr;
476   }
477 #endif
478 
  // Native type of the module path accepted by impl_create_sandbox: a wide
  // string on Windows (passed to LoadLibraryW) and a narrow C string
  // elsewhere (passed to dlopen).
  #if defined(_WIN32)
  using path_buf = const LPCWSTR;
  #else
  using path_buf = const char*;
  #endif
484 
// Checks `cond` during sandbox creation. When `infallible` is true the check
// is delegated to detail::dynamic_check (the creation doc comment describes
// this mode as aborting on failure). Otherwise a failed check tears down the
// partially built sandbox with impl_destroy_sandbox() and makes the enclosing
// function return false, so it is only usable inside a bool-returning member.
//
// The body is wrapped in do { } while (0) so each use is exactly one
// statement. The bare if / else-if form broke when the macro was the body of
// an unbraced `if` that has an `else`.
#define FALLIBLE_DYNAMIC_CHECK(infallible, cond, msg) \
  do {                                                \
    if (infallible) {                                 \
      detail::dynamic_check(cond, msg);               \
    } else if (!(cond)) {                             \
      impl_destroy_sandbox();                         \
      return false;                                   \
    }                                                 \
  } while (0)
492 
493   /**
494    * @brief creates the Wasm sandbox from the given shared library
495    *
496    * @param wasm2c_module_path path to shared library compiled with wasm2c. This param is not specified if you are creating a statically linked sandbox.
497    * @param infallible if set to true, the sandbox aborts on failure. If false, the sandbox returns creation status as a return value
498    * @param override_max_heap_size optional override of the maximum size of the wasm heap allowed for this sandbox instance. When the value is zero, platform defaults are used. Non-zero values are rounded to max(64k, next power of 2).
499    * @param wasm_module_name optional module name used when compiling with wasm2c
500    * @return true when sandbox is successfully created
501    * @return false when infallible if set to false and sandbox was not successfully created. If infallible is set to true, this function will never return false.
502    */
impl_create_sandbox(path_buf wasm2c_module_path,bool infallible=true,uint64_t override_max_heap_size=0,const char * wasm_module_name="")503   inline bool impl_create_sandbox(
504 #ifndef RLBOX_USE_STATIC_CALLS
505     path_buf wasm2c_module_path,
506 #endif
507     bool infallible = true, uint64_t override_max_heap_size = 0, const char* wasm_module_name = "")
508   {
509     FALLIBLE_DYNAMIC_CHECK(infallible, sandbox == nullptr, "Sandbox already initialized");
510 
511 #ifndef RLBOX_USE_STATIC_CALLS
512     #if defined(_WIN32)
513     library = (void*) LoadLibraryW(wasm2c_module_path);
514     #else
515     library = dlopen(wasm2c_module_path, RTLD_LAZY);
516     #endif
517 
518     if (!library) {
519       std::string error_msg = "Could not load wasm2c dynamic library: ";
520       #if defined(_WIN32)
521         DWORD errorMessageID  = GetLastError();
522         if (errorMessageID != 0) {
523           LPSTR messageBuffer = nullptr;
524           //The api creates the buffer that holds the message
525           size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
526                                       NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL);
527           //Copy the error message into a std::string.
528           std::string message(messageBuffer, size);
529           error_msg += message;
530           LocalFree(messageBuffer);
531         }
532       #else
533         error_msg += dlerror();
534       #endif
535       FALLIBLE_DYNAMIC_CHECK(infallible, false, error_msg.c_str());
536     }
537 #endif
538 
539 #ifndef RLBOX_USE_STATIC_CALLS
540     std::string info_func_name = wasm_module_name;
541     info_func_name += "get_wasm2c_sandbox_info";
542     auto get_info_func = reinterpret_cast<wasm2c_sandbox_funcs_t(*)()>(symbol_lookup(info_func_name));
543 #else
544     // only permitted if there is no custom module name
545     std::string wasm_module_name_str = wasm_module_name;
546     FALLIBLE_DYNAMIC_CHECK(infallible, wasm_module_name_str.empty(), "Static calls not supported with non empty module names");
547     auto get_info_func = reinterpret_cast<wasm2c_sandbox_funcs_t(*)()>(get_wasm2c_sandbox_info);
548 #endif
549     FALLIBLE_DYNAMIC_CHECK(infallible, get_info_func != nullptr, "wasm2c could not find <MODULE_NAME>get_wasm2c_sandbox_info");
550     sandbox_info = get_info_func();
551 
552     std::call_once(wasm2c_runtime_initialized, [&](){
553       sandbox_info.wasm_rt_sys_init();
554     });
555 
556     override_max_heap_size = rlbox_wasm2c_get_adjusted_heap_size(override_max_heap_size);
557     const uint64_t override_max_wasm_pages = rlbox_wasm2c_get_heap_page_count(override_max_heap_size);
558     FALLIBLE_DYNAMIC_CHECK(infallible, override_max_wasm_pages <= 65536, "Wasm allows a max heap size of 4GB");
559 
560     sandbox = sandbox_info.create_wasm2c_sandbox(static_cast<uint32_t>(override_max_wasm_pages));
561     FALLIBLE_DYNAMIC_CHECK(infallible, sandbox != nullptr, "Sandbox could not be created");
562 
563     sandbox_memory_info = (wasm_rt_memory_t*) sandbox_info.lookup_wasm2c_nonfunc_export(sandbox, "w2c_memory");
564     FALLIBLE_DYNAMIC_CHECK(infallible, sandbox_memory_info != nullptr, "Could not get wasm2c sandbox memory info");
565 
566     heap_base = reinterpret_cast<uintptr_t>(impl_get_memory_location());
567 
568     if constexpr (sizeof(uintptr_t) != sizeof(uint32_t)) {
569       // On larger platforms, check that the heap is aligned to the pointer size
570       // i.e. 32-bit pointer => aligned to 4GB. The implementations of
571       // impl_get_unsandboxed_pointer_no_ctx and impl_get_sandboxed_pointer_no_ctx
572       // below rely on this.
573       uintptr_t heap_offset_mask = std::numeric_limits<T_PointerType>::max();
574       FALLIBLE_DYNAMIC_CHECK(infallible, (heap_base & heap_offset_mask) == 0,
575                             "Sandbox heap not aligned to 4GB");
576     }
577 
578     // cache these for performance
579     exec_env = sandbox;
580 #ifndef RLBOX_USE_STATIC_CALLS
581     malloc_index = impl_lookup_symbol("malloc");
582     free_index = impl_lookup_symbol("free");
583 #else
584     malloc_index = rlbox_wasm2c_sandbox_lookup_symbol(malloc);
585     free_index = rlbox_wasm2c_sandbox_lookup_symbol(free);
586 #endif
587     return true;
588   }
589 
590 #undef FALLIBLE_DYNAMIC_CHECK
591 
impl_destroy_sandbox()592   inline void impl_destroy_sandbox()
593   {
594     if (return_slot_size) {
595       impl_free_in_sandbox(return_slot);
596     }
597 
598     if (sandbox != nullptr) {
599       sandbox_info.destroy_wasm2c_sandbox(sandbox);
600       sandbox = nullptr;
601     }
602 
603 #ifndef RLBOX_USE_STATIC_CALLS
604     if (library != nullptr) {
605       #if defined(_WIN32)
606         FreeLibrary((HMODULE) library);
607       #else
608         dlclose(library);
609       #endif
610       library = nullptr;
611     }
612 #endif
613   }
614 
615   template<typename T>
impl_get_unsandboxed_pointer(T_PointerType p) const616   inline void* impl_get_unsandboxed_pointer(T_PointerType p) const
617   {
618     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
619       RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
620       auto found = slot_assignments.find(p);
621       if (found != slot_assignments.end()) {
622         auto ret = found->second;
623         return const_cast<void*>(ret);
624       } else {
625         return nullptr;
626       }
627     } else {
628       return reinterpret_cast<void*>(heap_base + p);
629     }
630   }
631 
632   template<typename T>
impl_get_sandboxed_pointer(const void * p) const633   inline T_PointerType impl_get_sandboxed_pointer(const void* p) const
634   {
635     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
636       RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
637 
638       uint32_t slot_number = 0;
639       auto found = internal_callbacks.find(p);
640       if (found != internal_callbacks.end()) {
641         slot_number = found->second;
642       } else {
643 
644         auto func_type_idx = get_wasm2c_func_index(static_cast<T>(nullptr));
645         slot_number =
646           sandbox_info.add_wasm2c_callback(sandbox, func_type_idx, const_cast<void*>(p), WASM_RT_INTERNAL_FUNCTION);
647         internal_callbacks[p] = slot_number;
648         slot_assignments[slot_number] = p;
649       }
650       return static_cast<T_PointerType>(slot_number);
651     } else {
652       if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) {
653         return static_cast<T_PointerType>(reinterpret_cast<uintptr_t>(p) - heap_base);
654       } else {
655         return static_cast<T_PointerType>(reinterpret_cast<uintptr_t>(p));
656       }
657     }
658   }
659 
660   template<typename T>
impl_get_unsandboxed_pointer_no_ctx(T_PointerType p,const void * example_unsandboxed_ptr,rlbox_wasm2c_sandbox * (* expensive_sandbox_finder)(const void * example_unsandboxed_ptr))661   static inline void* impl_get_unsandboxed_pointer_no_ctx(
662     T_PointerType p,
663     const void* example_unsandboxed_ptr,
664     rlbox_wasm2c_sandbox* (*expensive_sandbox_finder)(
665       const void* example_unsandboxed_ptr))
666   {
667     // on 32-bit platforms we don't assume the heap is aligned
668     if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) {
669       auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
670       return sandbox->impl_get_unsandboxed_pointer<T>(p);
671     } else {
672       if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
673         // swizzling function pointers needs access to the function pointer tables
674         // and thus cannot be done without context
675         auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
676         return sandbox->impl_get_unsandboxed_pointer<T>(p);
677       } else {
678         // grab the memory base from the example_unsandboxed_ptr
679         uintptr_t heap_base_mask =
680           std::numeric_limits<uintptr_t>::max() &
681           ~(static_cast<uintptr_t>(std::numeric_limits<T_PointerType>::max()));
682         uintptr_t computed_heap_base =
683           reinterpret_cast<uintptr_t>(example_unsandboxed_ptr) & heap_base_mask;
684         uintptr_t ret = computed_heap_base | p;
685         return reinterpret_cast<void*>(ret);
686       }
687     }
688   }
689 
690   template<typename T>
impl_get_sandboxed_pointer_no_ctx(const void * p,const void * example_unsandboxed_ptr,rlbox_wasm2c_sandbox * (* expensive_sandbox_finder)(const void * example_unsandboxed_ptr))691   static inline T_PointerType impl_get_sandboxed_pointer_no_ctx(
692     const void* p,
693     const void* example_unsandboxed_ptr,
694     rlbox_wasm2c_sandbox* (*expensive_sandbox_finder)(
695       const void* example_unsandboxed_ptr))
696   {
697     // on 32-bit platforms we don't assume the heap is aligned
698     if constexpr (sizeof(uintptr_t) == sizeof(uint32_t)) {
699       auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
700       return sandbox->impl_get_sandboxed_pointer<T>(p);
701     } else {
702       if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
703         // swizzling function pointers needs access to the function pointer tables
704         // and thus cannot be done without context
705         auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
706         return sandbox->impl_get_sandboxed_pointer<T>(p);
707       } else {
708         // Just clear the memory base to leave the offset
709         RLBOX_WASM2C_UNUSED(example_unsandboxed_ptr);
710         uintptr_t ret = reinterpret_cast<uintptr_t>(p) &
711                         std::numeric_limits<T_PointerType>::max();
712         return static_cast<T_PointerType>(ret);
713       }
714     }
715   }
716 
impl_is_in_same_sandbox(const void * p1,const void * p2)717   static inline bool impl_is_in_same_sandbox(const void* p1, const void* p2)
718   {
719     uintptr_t heap_base_mask = std::numeric_limits<uintptr_t>::max() &
720                                ~(std::numeric_limits<T_PointerType>::max());
721     return (reinterpret_cast<uintptr_t>(p1) & heap_base_mask) ==
722            (reinterpret_cast<uintptr_t>(p2) & heap_base_mask);
723   }
724 
impl_is_pointer_in_sandbox_memory(const void * p)725   inline bool impl_is_pointer_in_sandbox_memory(const void* p)
726   {
727     size_t length = impl_get_total_memory();
728     uintptr_t p_val = reinterpret_cast<uintptr_t>(p);
729     return p_val >= heap_base && p_val < (heap_base + length);
730   }
731 
impl_is_pointer_in_app_memory(const void * p)732   inline bool impl_is_pointer_in_app_memory(const void* p)
733   {
734     return !(impl_is_pointer_in_sandbox_memory(p));
735   }
736 
impl_get_total_memory()737   inline size_t impl_get_total_memory() { return sandbox_memory_info->size; }
738 
impl_get_memory_location() const739   inline void* impl_get_memory_location() const
740   {
741     return sandbox_memory_info->data;
742   }
743 
744   template<typename T, typename T_Converted, typename... T_Args>
impl_invoke_with_func_ptr(T_Converted * func_ptr,T_Args &&...params)745   auto impl_invoke_with_func_ptr(T_Converted* func_ptr, T_Args&&... params)
746   {
747 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
748     auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data();
749 #endif
750     auto old_sandbox = thread_data.sandbox;
751     thread_data.sandbox = this;
752     auto on_exit = detail::make_scope_exit([&] {
753       thread_data.sandbox = old_sandbox;
754     });
755 
756     // WASM functions are mangled in the following manner
757     // 1. All primitive types are left as is and follow an LP32 machine model
758     // (as opposed to the possibly 64-bit application)
759     // 2. All pointers are changed to u32 types
760     // 3. Returned class are returned as an out parameter before the actual
761     // function parameters
762     // 4. All class parameters are passed as pointers (u32 types)
763     // 5. The heap address is passed in as the first argument to the function
764     //
765     // RLBox accounts for the first 2 differences in T_Converted type, but we
766     // need to handle the rest
767 
768     // Handle point 3
769     using T_Ret = wasm2c_detail::return_argument<T_Converted>;
770     if constexpr (std::is_class_v<T_Ret>) {
771       using T_Conv1 = wasm2c_detail::change_return_type<T_Converted, void>;
772       using T_Conv2 = wasm2c_detail::prepend_arg_type<T_Conv1, T_PointerType>;
773       auto func_ptr_conv =
774         reinterpret_cast<T_Conv2*>(reinterpret_cast<uintptr_t>(func_ptr));
775       ensure_return_slot_size(sizeof(T_Ret));
776       impl_invoke_with_func_ptr<T>(func_ptr_conv, return_slot, params...);
777 
778       auto ptr = reinterpret_cast<T_Ret*>(
779         impl_get_unsandboxed_pointer<T_Ret*>(return_slot));
780       T_Ret ret = *ptr;
781       return ret;
782     }
783 
784     // Handle point 4
785     constexpr size_t alloc_length = [&] {
786       if constexpr (sizeof...(params) > 0) {
787         return ((std::is_class_v<T_Args> ? 1 : 0) + ...);
788       } else {
789         return 0;
790       }
791     }();
792 
793     // 0 arg functions create 0 length arrays which is not allowed
794     T_PointerType allocations_buff[alloc_length == 0 ? 1 : alloc_length];
795     T_PointerType* allocations = allocations_buff;
796 
797     auto serialize_class_arg =
798       [&](auto arg) -> std::conditional_t<std::is_class_v<decltype(arg)>,
799                                           T_PointerType,
800                                           decltype(arg)> {
801       using T_Arg = decltype(arg);
802       if constexpr (std::is_class_v<T_Arg>) {
803         auto slot = impl_malloc_in_sandbox(sizeof(T_Arg));
804         auto ptr =
805           reinterpret_cast<T_Arg*>(impl_get_unsandboxed_pointer<T_Arg*>(slot));
806         *ptr = arg;
807         allocations[0] = slot;
808         allocations++;
809         return slot;
810       } else {
811         return arg;
812       }
813     };
814 
815     // 0 arg functions don't use serialize
816     RLBOX_WASM2C_UNUSED(serialize_class_arg);
817 
818     using T_ConvNoClass =
819       wasm2c_detail::change_class_arg_types<T_Converted, T_PointerType>;
820 
821     // Handle Point 5
822     using T_ConvHeap = wasm2c_detail::prepend_arg_type<T_ConvNoClass, void*>;
823 
824     // Function invocation
825     auto func_ptr_conv =
826       reinterpret_cast<T_ConvHeap*>(reinterpret_cast<uintptr_t>(func_ptr));
827 
828     using T_NoVoidRet =
829       std::conditional_t<std::is_void_v<T_Ret>, uint32_t, T_Ret>;
830     T_NoVoidRet ret;
831 
832     if constexpr (std::is_void_v<T_Ret>) {
833       RLBOX_WASM2C_UNUSED(ret);
834       func_ptr_conv(exec_env, serialize_class_arg(params)...);
835     } else {
836       ret = func_ptr_conv(exec_env, serialize_class_arg(params)...);
837     }
838 
839     for (size_t i = 0; i < alloc_length; i++) {
840       impl_free_in_sandbox(allocations_buff[i]);
841     }
842 
843     if constexpr (!std::is_void_v<T_Ret>) {
844       return ret;
845     }
846   }
847 
impl_malloc_in_sandbox(size_t size)848   inline T_PointerType impl_malloc_in_sandbox(size_t size)
849   {
850     if constexpr(sizeof(size) > sizeof(uint32_t)) {
851       detail::dynamic_check(size <= std::numeric_limits<uint32_t>::max(),
852                             "Attempting to malloc more than the heap size");
853     }
854     using T_Func = void*(size_t);
855     using T_Converted = T_PointerType(uint32_t);
856     T_PointerType ret = impl_invoke_with_func_ptr<T_Func, T_Converted>(
857       reinterpret_cast<T_Converted*>(malloc_index),
858       static_cast<uint32_t>(size));
859     return ret;
860   }
861 
impl_free_in_sandbox(T_PointerType p)862   inline void impl_free_in_sandbox(T_PointerType p)
863   {
864     using T_Func = void(void*);
865     using T_Converted = void(T_PointerType);
866     impl_invoke_with_func_ptr<T_Func, T_Converted>(
867       reinterpret_cast<T_Converted*>(free_index), p);
868   }
869 
870   template<typename T_Ret, typename... T_Args>
impl_register_callback(void * key,void * callback)871   inline T_PointerType impl_register_callback(void* key, void* callback)
872   {
873     bool found = false;
874     uint32_t found_loc = 0;
875     void* chosen_interceptor = nullptr;
876 
877     RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
878 
879     // need a compile time for loop as we we need I to be a compile time value
880     // this is because we are setting the I'th callback ineterceptor
881     wasm2c_detail::compile_time_for<MAX_CALLBACKS>([&](auto I) {
882       constexpr auto i = I.value;
883       if (!found && callbacks[i] == nullptr) {
884         found = true;
885         found_loc = i;
886 
887         if constexpr (std::is_class_v<T_Ret>) {
888           chosen_interceptor = reinterpret_cast<void*>(
889             callback_interceptor_promoted<i, T_Ret, T_Args...>);
890         } else {
891           chosen_interceptor =
892             reinterpret_cast<void*>(callback_interceptor<i, T_Ret, T_Args...>);
893         }
894       }
895     });
896 
897     detail::dynamic_check(
898       found,
899       "Could not find an empty slot in sandbox function table. This would "
900       "happen if you have registered too many callbacks, or unsandboxed "
901       "too many function pointers. You can file a bug if you want to "
902       "increase the maximum allowed callbacks or unsadnboxed functions "
903       "pointers");
904 
905     auto func_type_idx = get_wasm2c_func_index<T_Ret, T_Args...>();
906     uint32_t slot_number =
907       sandbox_info.add_wasm2c_callback(sandbox, func_type_idx, chosen_interceptor, WASM_RT_EXTERNAL_FUNCTION);
908 
909     callback_unique_keys[found_loc] = key;
910     callbacks[found_loc] = callback;
911     callback_slot_assignment[found_loc] = slot_number;
912     slot_assignments[slot_number] = callback;
913 
914     return static_cast<T_PointerType>(slot_number);
915   }
916 
917   static inline std::pair<rlbox_wasm2c_sandbox*, void*>
impl_get_executed_callback_sandbox_and_key()918   impl_get_executed_callback_sandbox_and_key()
919   {
920 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
921     auto& thread_data = *get_rlbox_wasm2c_sandbox_thread_data();
922 #endif
923     auto sandbox = thread_data.sandbox;
924     auto callback_num = thread_data.last_callback_invoked;
925     void* key = sandbox->callback_unique_keys[callback_num];
926     return std::make_pair(sandbox, key);
927   }
928 
929   template<typename T_Ret, typename... T_Args>
impl_unregister_callback(void * key)930   inline void impl_unregister_callback(void* key)
931   {
932     bool found = false;
933     uint32_t i = 0;
934     {
935       RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
936       for (; i < MAX_CALLBACKS; i++) {
937         if (callback_unique_keys[i] == key) {
938           sandbox_info.remove_wasm2c_callback(sandbox, callback_slot_assignment[i]);
939           callback_unique_keys[i] = nullptr;
940           callbacks[i] = nullptr;
941           callback_slot_assignment[i] = 0;
942           found = true;
943           break;
944         }
945       }
946     }
947 
948     detail::dynamic_check(
949       found, "Internal error: Could not find callback to unregister");
950 
951     return;
952   }
953 };
954 
955 // declare the static symbol with weak linkage to keep this header only
956 #if defined(_MSC_VER)
957 __declspec(selectany)
958 #else
959 __attribute__((weak))
960 #endif
961 std::once_flag rlbox_wasm2c_sandbox::wasm2c_runtime_initialized;
962 
963 } // namespace rlbox
964