1 #pragma once
2 
3 #include "lucet_sandbox.h"
4 
5 #include <cstdint>
6 #include <iostream>
7 #include <limits>
8 #include <map>
9 #include <memory>
10 #include <mutex>
11 // RLBox allows applications to provide a custom shared lock implementation
12 #ifndef RLBOX_USE_CUSTOM_SHARED_LOCK
13 #  include <shared_mutex>
14 #endif
15 #include <type_traits>
16 #include <utility>
17 #include <vector>
18 
// Casts its argument expression to void so that a value which is
// intentionally left unused does not trigger an unused-variable warning.
#define RLBOX_LUCET_UNUSED(...) (void)__VA_ARGS__
20 
// Use the same convention as rlbox to allow applications to customize the
// shared lock.
//
// Unless RLBOX_USE_CUSTOM_SHARED_LOCK is defined, the three macros below are
// implemented with std::shared_timed_mutex:
//   RLBOX_SHARED_LOCK(name)               declares a lock object called `name`
//   RLBOX_ACQUIRE_SHARED_GUARD(name, ...) declares a std::shared_lock guard
//   RLBOX_ACQUIRE_UNIQUE_GUARD(name, ...) declares a std::unique_lock guard
// When RLBOX_USE_CUSTOM_SHARED_LOCK is defined, the application must have
// defined all three macros before this header is included; otherwise
// compilation stops with the #error below.
#ifndef RLBOX_USE_CUSTOM_SHARED_LOCK
#  define RLBOX_SHARED_LOCK(name) std::shared_timed_mutex name
#  define RLBOX_ACQUIRE_SHARED_GUARD(name, ...)                                \
    std::shared_lock<std::shared_timed_mutex> name(__VA_ARGS__)
#  define RLBOX_ACQUIRE_UNIQUE_GUARD(name, ...)                                \
    std::unique_lock<std::shared_timed_mutex> name(__VA_ARGS__)
#else
#  if !defined(RLBOX_SHARED_LOCK) || !defined(RLBOX_ACQUIRE_SHARED_GUARD) ||   \
    !defined(RLBOX_ACQUIRE_UNIQUE_GUARD)
#    error                                                                     \
      "RLBOX_USE_CUSTOM_SHARED_LOCK defined but missing definitions for RLBOX_SHARED_LOCK, RLBOX_ACQUIRE_SHARED_GUARD, RLBOX_ACQUIRE_UNIQUE_GUARD"
#  endif
#endif
36 
37 namespace rlbox {
38 
namespace detail {
  // Forward declaration only; the definition is expected to come from the
  // rlbox library. We rely on the dynamic check settings (exception vs abort)
  // configured there. Throughout this file it is called with the condition
  // that must hold and the message describing the failure.
  inline void dynamic_check(bool check, const char* const msg);
}
43 
44 namespace lucet_detail {
45 
  // Always-false constant whose value formally depends on T.
  // NOTE(review): presumably intended for static_assert in discarded
  // `if constexpr` branches; no use is visible in this part of the file.
  template<typename T>
  constexpr bool false_v = false;
48 
49   // https://stackoverflow.com/questions/6512019/can-we-get-the-type-of-a-lambda-argument
50   namespace return_argument_detail {
51     template<typename Ret, typename... Rest>
52     Ret helper(Ret (*)(Rest...));
53 
54     template<typename Ret, typename F, typename... Rest>
55     Ret helper(Ret (F::*)(Rest...));
56 
57     template<typename Ret, typename F, typename... Rest>
58     Ret helper(Ret (F::*)(Rest...) const);
59 
60     template<typename F>
61     decltype(helper(&F::operator())) helper(F);
62   } // namespace return_argument_detail
63 
64   template<typename T>
65   using return_argument =
66     decltype(return_argument_detail::helper(std::declval<T>()));
67 
68   ///////////////////////////////////////////////////////////////
69 
70   // https://stackoverflow.com/questions/37602057/why-isnt-a-for-loop-a-compile-time-expression
71   namespace compile_time_for_detail {
72     template<std::size_t N>
73     struct num
74     {
75       static const constexpr auto value = N;
76     };
77 
78     template<class F, std::size_t... Is>
compile_time_for_helper(F func,std::index_sequence<Is...>)79     inline void compile_time_for_helper(F func, std::index_sequence<Is...>)
80     {
81       (func(num<Is>{}), ...);
82     }
83   } // namespace compile_time_for_detail
84 
85   template<std::size_t N, typename F>
compile_time_for(F func)86   inline void compile_time_for(F func)
87   {
88     compile_time_for_detail::compile_time_for_helper(
89       func, std::make_index_sequence<N>());
90   }
91 
92   ///////////////////////////////////////////////////////////////
93 
  // Maps a host C/C++ type T to the representation used at the wasm function
  // boundary:
  //   type        the C++ type used in the converted function signature
  //   lucet_type  the matching LucetValueType tag
  //
  // The primary template only accepts void; every other supported category
  // is handled by one of the enable_if specializations below, and anything
  // unmatched trips the static_assert.
  template<typename T, typename = void>
  struct convert_type_to_wasm_type
  {
    static_assert(std::is_void_v<T>, "Missing specialization");
    using type = void;
    static constexpr enum LucetValueType lucet_type = LucetValueType_Void;
  };

  // Integers and enums of at most 32 bits travel as I32.
  template<typename T>
  struct convert_type_to_wasm_type<
    T,
    std::enable_if_t<(std::is_integral_v<T> || std::is_enum_v<T>)&&sizeof(T) <=
                     sizeof(uint32_t)>>
  {
    using type = uint32_t;
    static constexpr enum LucetValueType lucet_type = LucetValueType_I32;
  };

  // Integers and enums wider than 32 bits and up to 64 bits travel as I64.
  template<typename T>
  struct convert_type_to_wasm_type<
    T,
    std::enable_if_t<(std::is_integral_v<T> ||
                      std::is_enum_v<T>)&&sizeof(uint32_t) < sizeof(T) &&
                     sizeof(T) <= sizeof(uint64_t)>>
  {
    using type = uint64_t;
    static constexpr enum LucetValueType lucet_type = LucetValueType_I64;
  };

  // float is passed unchanged as F32.
  template<typename T>
  struct convert_type_to_wasm_type<T,
                                   std::enable_if_t<std::is_same_v<T, float>>>
  {
    using type = T;
    static constexpr enum LucetValueType lucet_type = LucetValueType_F32;
  };

  // double is passed unchanged as F64.
  template<typename T>
  struct convert_type_to_wasm_type<T,
                                   std::enable_if_t<std::is_same_v<T, double>>>
  {
    using type = T;
    static constexpr enum LucetValueType lucet_type = LucetValueType_F64;
  };

  template<typename T>
  struct convert_type_to_wasm_type<
    T,
    std::enable_if_t<std::is_pointer_v<T> || std::is_class_v<T>>>
  {
    // pointers are 32 bit indexes in wasm
    // class parameters are passed as a pointer to an object in the stack or
    // heap
    using type = uint32_t;
    static constexpr enum LucetValueType lucet_type = LucetValueType_I32;
  };
149 
150   ///////////////////////////////////////////////////////////////
151 
152   namespace prepend_arg_type_detail {
153     template<typename T, typename T_ArgNew>
154     struct helper;
155 
156     template<typename T_ArgNew, typename T_Ret, typename... T_Args>
157     struct helper<T_Ret(T_Args...), T_ArgNew>
158     {
159       using type = T_Ret(T_ArgNew, T_Args...);
160     };
161   }
162 
163   template<typename T_Func, typename T_ArgNew>
164   using prepend_arg_type =
165     typename prepend_arg_type_detail::helper<T_Func, T_ArgNew>::type;
166 
167   ///////////////////////////////////////////////////////////////
168 
169   namespace change_return_type_detail {
170     template<typename T, typename T_RetNew>
171     struct helper;
172 
173     template<typename T_RetNew, typename T_Ret, typename... T_Args>
174     struct helper<T_Ret(T_Args...), T_RetNew>
175     {
176       using type = T_RetNew(T_Args...);
177     };
178   }
179 
180   template<typename T_Func, typename T_RetNew>
181   using change_return_type =
182     typename change_return_type_detail::helper<T_Func, T_RetNew>::type;
183 
184   ///////////////////////////////////////////////////////////////
185 
186   namespace change_class_arg_types_detail {
187     template<typename T, typename T_ArgNew>
188     struct helper;
189 
190     template<typename T_ArgNew, typename T_Ret, typename... T_Args>
191     struct helper<T_Ret(T_Args...), T_ArgNew>
192     {
193       using type =
194         T_Ret(std::conditional_t<std::is_class_v<T_Args>, T_ArgNew, T_Args>...);
195     };
196   }
197 
198   template<typename T_Func, typename T_ArgNew>
199   using change_class_arg_types =
200     typename change_class_arg_types_detail::helper<T_Func, T_ArgNew>::type;
201 
202 } // namespace lucet_detail
203 
class rlbox_lucet_sandbox;

// Per-thread bookkeeping shared between sandbox invocations and the callback
// interceptors.
struct rlbox_lucet_sandbox_thread_data
{
  // Sandbox that most recently started an invocation on this thread (set by
  // impl_invoke_with_func_ptr, read by the callback interceptors).
  rlbox_lucet_sandbox* sandbox;
  // Index of the callback slot most recently entered on this thread.
  uint32_t last_callback_invoked;
};
211 
#ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES

// When the embedder provides the thread-local storage, the plugin reaches its
// per-thread data through this accessor instead of a class-level thread_local.
rlbox_lucet_sandbox_thread_data* get_rlbox_lucet_sandbox_thread_data();

// Expands to the thread_local variable and the definition of the accessor
// declared above. The trailing static_assert forces the user to terminate
// the macro invocation with a semicolon.
#  define RLBOX_LUCET_SANDBOX_STATIC_VARIABLES()                               \
    thread_local rlbox::rlbox_lucet_sandbox_thread_data                        \
      rlbox_lucet_sandbox_thread_info{ 0, 0 };                                 \
    namespace rlbox {                                                          \
      rlbox_lucet_sandbox_thread_data* get_rlbox_lucet_sandbox_thread_data()   \
      {                                                                        \
        return &rlbox_lucet_sandbox_thread_info;                               \
      }                                                                        \
    }                                                                          \
    static_assert(true, "Enforce semi-colon")

#endif
227 
228 class rlbox_lucet_sandbox
229 {
230 public:
  // Fixed-width types this plugin exposes for the corresponding C types
  // inside the sandbox; pointers are 32-bit values.
  using T_LongLongType = int64_t;
  using T_LongType = int32_t;
  using T_IntType = int32_t;
  using T_PointerType = uint32_t;
  using T_ShortType = int16_t;

private:
  // Handle returned by lucet_load_module; nullptr until impl_create_sandbox
  // has run.
  LucetSandboxInstance* sandbox = nullptr;
  // Host address of the start of the sandbox heap (set in
  // impl_create_sandbox).
  uintptr_t heap_base;
  // Cached lookups of the sandbox's "malloc" and "free" symbols.
  void* malloc_index = 0;
  void* free_index = 0;
  // Sandbox allocation, and its size, used to receive class return values;
  // grown on demand by ensure_return_slot_size.
  size_t return_slot_size = 0;
  T_PointerType return_slot = 0;

  static const size_t MAX_CALLBACKS = 128;
  // Guards `callbacks`; the interceptors take it in shared mode while reading
  // the host function pointer for their slot.
  RLBOX_SHARED_LOCK(callback_mutex);
  // NOTE(review): callback_unique_keys and callback_slot_assignment are not
  // used in the visible part of the file; presumably maintained by the
  // callback registration code further down.
  void* callback_unique_keys[MAX_CALLBACKS]{ 0 };
  // Host function pointers, indexed by callback slot.
  void* callbacks[MAX_CALLBACKS]{ 0 };
  uint32_t callback_slot_assignment[MAX_CALLBACKS]{ 0 };

  using TableElementRef = LucetFunctionTableElement*;
  // For each reserved callback slot: the function table element backing it
  // and that element's index in the module's function pointer table.
  struct FunctionTable
  {
    TableElementRef elements[MAX_CALLBACKS];
    uint32_t slot_number[MAX_CALLBACKS];
  };
  // Guards shared_callback_slots and writes to the function pointer table.
  inline static std::mutex callback_table_mutex;
  // We need to share the callback slot info across multiple sandbox instances
  // that may load the same sandboxed library. Thus if the sandboxed library is
  // already in the memory space, we should just use the previously saved info,
  // as the load destroys the callback info. Once all instances of the
  // library are unloaded, the sandboxed library is removed from the address
  // space and thus we can "reset" our state. The semantics of shared and weak
  // pointers ensure this and will automatically release the memory after all
  // instances are released.
  inline static std::map<void*, std::weak_ptr<FunctionTable>>
    shared_callback_slots;
  std::shared_ptr<FunctionTable> callback_slots = nullptr;
  // However, if the library is also loaded externally in the application, then
  // we don't know when we can ever "reset". In such scenarios, we are better
  // off never throwing away the callback info, rather than figuring out
  // what/why/when the application is loading or unloading the sandboxed
  // library. An extra reference to the shared_ptr will ensure this.
  inline static std::vector<std::shared_ptr<FunctionTable>>
    saved_callback_slot_info;

#ifndef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
  // Per-thread state used when the embedder does not supply its own TLS.
  thread_local static inline rlbox_lucet_sandbox_thread_data thread_data{ 0,
                                                                          0 };
#endif
281 
282   template<typename T_FormalRet, typename T_ActualRet>
serialize_to_sandbox(T_ActualRet arg)283   inline auto serialize_to_sandbox(T_ActualRet arg)
284   {
285     if constexpr (std::is_class_v<T_FormalRet>) {
286       // structs returned as pointers into wasm memory/wasm stack
287       auto ptr = reinterpret_cast<T_FormalRet*>(
288         impl_get_unsandboxed_pointer<T_FormalRet*>(arg));
289       T_FormalRet ret = *ptr;
290       return ret;
291     } else {
292       return arg;
293     }
294   }
295 
get_callback_ref_data(LucetFunctionTable & functionPointerTable)296   inline std::shared_ptr<FunctionTable> get_callback_ref_data(
297     LucetFunctionTable& functionPointerTable)
298   {
299     auto callback_slots = std::make_shared<FunctionTable>();
300 
301     for (size_t i = 0; i < MAX_CALLBACKS; i++) {
302       uintptr_t reservedVal =
303         lucet_get_reserved_callback_slot_val(sandbox, i + 1);
304 
305       bool found = false;
306       for (size_t j = 0; j < functionPointerTable.length; j++) {
307         if (functionPointerTable.data[j].rf == reservedVal) {
308           functionPointerTable.data[j].rf = 0;
309           callback_slots->elements[i] = &(functionPointerTable.data[j]);
310           callback_slots->slot_number[i] = static_cast<uint32_t>(j);
311           found = true;
312           break;
313         }
314       }
315 
316       detail::dynamic_check(found, "Unable to intialize callback tables");
317     }
318 
319     return callback_slots;
320   }
321 
reinit_callback_ref_data(LucetFunctionTable & functionPointerTable,std::shared_ptr<FunctionTable> & callback_slots)322   inline void reinit_callback_ref_data(
323     LucetFunctionTable& functionPointerTable,
324     std::shared_ptr<FunctionTable>& callback_slots)
325   {
326     for (size_t i = 0; i < MAX_CALLBACKS; i++) {
327       uintptr_t reservedVal =
328         lucet_get_reserved_callback_slot_val(sandbox, i + 1);
329 
330       for (size_t j = 0; j < functionPointerTable.length; j++) {
331         if (functionPointerTable.data[j].rf == reservedVal) {
332           functionPointerTable.data[j].rf = 0;
333 
334           detail::dynamic_check(
335             callback_slots->elements[i] == &(functionPointerTable.data[j]) &&
336               callback_slots->slot_number[i] == static_cast<uint32_t>(j),
337             "Sandbox creation error: Error when checking the values of "
338             "callback slot data");
339 
340           break;
341         }
342       }
343     }
344   }
345 
set_callbacks_slots_ref(bool external_loads_exist)346   inline void set_callbacks_slots_ref(bool external_loads_exist)
347   {
348     LucetFunctionTable functionPointerTable =
349       lucet_get_function_pointer_table(sandbox);
350     void* key = functionPointerTable.data;
351 
352     std::lock_guard<std::mutex> lock(callback_table_mutex);
353     std::weak_ptr<FunctionTable> slots = shared_callback_slots[key];
354 
355     if (auto shared_slots = slots.lock()) {
356       // pointer exists
357       callback_slots = shared_slots;
358       // Sometimes, dlopen and process forking seem to act a little weird.
359       // Writes to the writable page of the dynamic lib section seem to not
360       // always be propagated (possibly when the dynamic library is opened
361       // externally - "external_loads_exist")). This occurred in when RLBox was
362       // used in ASAN builds of Firefox. In general, we take the precaution of
363       // rechecking this on each sandbox creation.
364       reinit_callback_ref_data(functionPointerTable, callback_slots);
365       return;
366     }
367 
368     callback_slots = get_callback_ref_data(functionPointerTable);
369     shared_callback_slots[key] = callback_slots;
370     if (external_loads_exist) {
371       saved_callback_slot_info.push_back(callback_slots);
372     }
373   }
374 
375   template<uint32_t N, typename T_Ret, typename... T_Args>
376   static typename lucet_detail::convert_type_to_wasm_type<T_Ret>::type
callback_interceptor(void *,typename lucet_detail::convert_type_to_wasm_type<T_Args>::type...params)377   callback_interceptor(
378     void* /* vmContext */,
379     typename lucet_detail::convert_type_to_wasm_type<T_Args>::type... params)
380   {
381 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
382     auto& thread_data = *get_rlbox_lucet_sandbox_thread_data();
383 #endif
384     thread_data.last_callback_invoked = N;
385     using T_Func = T_Ret (*)(T_Args...);
386     T_Func func;
387     {
388       RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex);
389       func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]);
390     }
391     // Callbacks are invoked through function pointers, cannot use std::forward
392     // as we don't have caller context for T_Args, which means they are all
393     // effectively passed by value
394     return func(thread_data.sandbox->serialize_to_sandbox<T_Args>(params)...);
395   }
396 
397   template<uint32_t N, typename T_Ret, typename... T_Args>
callback_interceptor_promoted(void *,typename lucet_detail::convert_type_to_wasm_type<T_Ret>::type ret,typename lucet_detail::convert_type_to_wasm_type<T_Args>::type...params)398   static void callback_interceptor_promoted(
399     void* /* vmContext */,
400     typename lucet_detail::convert_type_to_wasm_type<T_Ret>::type ret,
401     typename lucet_detail::convert_type_to_wasm_type<T_Args>::type... params)
402   {
403 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
404     auto& thread_data = *get_rlbox_lucet_sandbox_thread_data();
405 #endif
406     thread_data.last_callback_invoked = N;
407     using T_Func = T_Ret (*)(T_Args...);
408     T_Func func;
409     {
410       RLBOX_ACQUIRE_SHARED_GUARD(lock, thread_data.sandbox->callback_mutex);
411       func = reinterpret_cast<T_Func>(thread_data.sandbox->callbacks[N]);
412     }
413     // Callbacks are invoked through function pointers, cannot use std::forward
414     // as we don't have caller context for T_Args, which means they are all
415     // effectively passed by value
416     auto ret_val =
417       func(thread_data.sandbox->serialize_to_sandbox<T_Args>(params)...);
418     // Copy the return value back
419     auto ret_ptr = reinterpret_cast<T_Ret*>(
420       thread_data.sandbox->template impl_get_unsandboxed_pointer<T_Ret*>(ret));
421     *ret_ptr = ret_val;
422   }
423 
424   template<typename T_Ret, typename... T_Args>
get_lucet_type_index(T_Ret (*)(T_Args...)=nullptr) const425   inline T_PointerType get_lucet_type_index(
426     T_Ret (*/* dummy for template inference */)(T_Args...) = nullptr) const
427   {
428     // Class return types as promoted to args
429     constexpr bool promoted = std::is_class_v<T_Ret>;
430     int32_t type_index;
431 
432     if constexpr (promoted) {
433       LucetValueType ret_type = LucetValueType::LucetValueType_Void;
434       LucetValueType param_types[] = {
435         lucet_detail::convert_type_to_wasm_type<T_Ret>::lucet_type,
436         lucet_detail::convert_type_to_wasm_type<T_Args>::lucet_type...
437       };
438       LucetFunctionSignature signature{ ret_type,
439                                         sizeof(param_types) /
440                                           sizeof(LucetValueType),
441                                         &(param_types[0]) };
442       type_index = lucet_get_function_type_index(sandbox, signature);
443     } else {
444       LucetValueType ret_type =
445         lucet_detail::convert_type_to_wasm_type<T_Ret>::lucet_type;
446       LucetValueType param_types[] = {
447         lucet_detail::convert_type_to_wasm_type<T_Args>::lucet_type...
448       };
449       LucetFunctionSignature signature{ ret_type,
450                                         sizeof(param_types) /
451                                           sizeof(LucetValueType),
452                                         &(param_types[0]) };
453       type_index = lucet_get_function_type_index(sandbox, signature);
454     }
455 
456     return type_index;
457   }
458 
ensure_return_slot_size(size_t size)459   void ensure_return_slot_size(size_t size)
460   {
461     if (size > return_slot_size) {
462       if (return_slot_size) {
463         impl_free_in_sandbox(return_slot);
464       }
465       return_slot = impl_malloc_in_sandbox(size);
466       detail::dynamic_check(
467         return_slot != 0,
468         "Error initializing return slot. Sandbox may be out of memory!");
469       return_slot_size = size;
470     }
471   }
472 
473 protected:
474   // Set external_loads_exist to true, if the host application loads the
475   // library lucet_module_path outside of rlbox_lucet_sandbox such as via dlopen
476   // or the Windows equivalent
impl_create_sandbox(const char * lucet_module_path,bool external_loads_exist,bool allow_stdio)477   inline void impl_create_sandbox(const char* lucet_module_path,
478                                   bool external_loads_exist,
479                                   bool allow_stdio)
480   {
481     detail::dynamic_check(sandbox == nullptr, "Sandbox already initialized");
482     sandbox = lucet_load_module(lucet_module_path, allow_stdio);
483     detail::dynamic_check(sandbox != nullptr, "Sandbox could not be created");
484 
485     heap_base = reinterpret_cast<uintptr_t>(impl_get_memory_location());
486     // Check that the address space is larger than the sandbox heap i.e. 4GB
487     // sandbox heap, host has to have more than 4GB
488     static_assert(sizeof(uintptr_t) > sizeof(T_PointerType));
489     // Check that the heap is aligned to the pointer size i.e. 32-bit pointer =>
490     // aligned to 4GB. The implementations of
491     // impl_get_unsandboxed_pointer_no_ctx and impl_get_sandboxed_pointer_no_ctx
492     // below rely on this.
493     uintptr_t heap_offset_mask = std::numeric_limits<T_PointerType>::max();
494     detail::dynamic_check((heap_base & heap_offset_mask) == 0,
495                           "Sandbox heap not aligned to 4GB");
496 
497     // cache these for performance
498     malloc_index = impl_lookup_symbol("malloc");
499     free_index = impl_lookup_symbol("free");
500 
501     set_callbacks_slots_ref(external_loads_exist);
502   }
503 
impl_create_sandbox(const char * lucet_module_path)504   inline void impl_create_sandbox(const char* lucet_module_path)
505   {
506     // Default is to assume that no external code will load the wasm library as
507     // this is usually the case
508     const bool external_loads_exist = false;
509     const bool allow_stdio = true;
510     impl_create_sandbox(lucet_module_path, external_loads_exist, allow_stdio);
511   }
512 
impl_destroy_sandbox()513   inline void impl_destroy_sandbox() {
514     if (return_slot_size) {
515       impl_free_in_sandbox(return_slot);
516     }
517     lucet_drop_module(sandbox);
518   }
519 
520   template<typename T>
impl_get_unsandboxed_pointer(T_PointerType p) const521   inline void* impl_get_unsandboxed_pointer(T_PointerType p) const
522   {
523     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
524       LucetFunctionTable functionPointerTable =
525         lucet_get_function_pointer_table(sandbox);
526       if (p >= functionPointerTable.length) {
527         // Received out of range function pointer
528         return nullptr;
529       }
530       auto ret = functionPointerTable.data[p].rf;
531       return reinterpret_cast<void*>(static_cast<uintptr_t>(ret));
532     } else {
533       return reinterpret_cast<void*>(heap_base + p);
534     }
535   }
536 
537   template<typename T>
impl_get_sandboxed_pointer(const void * p) const538   inline T_PointerType impl_get_sandboxed_pointer(const void* p) const
539   {
540     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
541       // p is a pointer to a function internal to the lucet module
542       // we need to either
543       // 1) find the indirect function slot this is registered and return the
544       // slot number. For this we need to scan the full indirect function table,
545       // not just the portion we have reserved for callbacks.
546       // 2) in the scenario this function has not ever been listed as an
547       // indirect function, we need to register this like a normal callback.
548       // However, unlike callbacks, we will not require the user to unregister
549       // this. Instead, this permenantly takes up a callback slot.
550       LucetFunctionTable functionPointerTable =
551         lucet_get_function_pointer_table(sandbox);
552       std::lock_guard<std::mutex> lock(callback_table_mutex);
553 
554       // Scenario 1 described above
555       ssize_t empty_slot = -1;
556       for (size_t i = 0; i < functionPointerTable.length; i++) {
557         if (functionPointerTable.data[i].rf == reinterpret_cast<uintptr_t>(p)) {
558           return static_cast<T_PointerType>(i);
559         } else if (functionPointerTable.data[i].rf == 0 && empty_slot == -1) {
560           // found an empty slot. Save it, as we may use it later.
561           empty_slot = i;
562         }
563       }
564 
565       // Scenario 2 described above
566       detail::dynamic_check(
567         empty_slot != -1,
568         "Could not find an empty slot in sandbox function table. This would "
569         "happen if you have registered too many callbacks, or unsandboxed "
570         "too many function pointers. You can file a bug if you want to "
571         "increase the maximum allowed callbacks or unsadnboxed functions "
572         "pointers");
573       T dummy = nullptr;
574       int32_t type_index = get_lucet_type_index(dummy);
575       functionPointerTable.data[empty_slot].ty = type_index;
576       functionPointerTable.data[empty_slot].rf = reinterpret_cast<uintptr_t>(p);
577       return empty_slot;
578 
579     } else {
580       return static_cast<T_PointerType>(reinterpret_cast<uintptr_t>(p));
581     }
582   }
583 
584   template<typename T>
impl_get_unsandboxed_pointer_no_ctx(T_PointerType p,const void * example_unsandboxed_ptr,rlbox_lucet_sandbox * (* expensive_sandbox_finder)(const void * example_unsandboxed_ptr))585   static inline void* impl_get_unsandboxed_pointer_no_ctx(
586     T_PointerType p,
587     const void* example_unsandboxed_ptr,
588     rlbox_lucet_sandbox* (*expensive_sandbox_finder)(
589       const void* example_unsandboxed_ptr))
590   {
591     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
592       // swizzling function pointers needs access to the function pointer tables
593       // and thus cannot be done without context
594       auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
595       return sandbox->impl_get_unsandboxed_pointer<T>(p);
596     } else {
597       // grab the memory base from the example_unsandboxed_ptr
598       uintptr_t heap_base_mask =
599         std::numeric_limits<uintptr_t>::max() &
600         ~(static_cast<uintptr_t>(std::numeric_limits<T_PointerType>::max()));
601       uintptr_t computed_heap_base =
602         reinterpret_cast<uintptr_t>(example_unsandboxed_ptr) & heap_base_mask;
603       uintptr_t ret = computed_heap_base | p;
604       return reinterpret_cast<void*>(ret);
605     }
606   }
607 
608   template<typename T>
impl_get_sandboxed_pointer_no_ctx(const void * p,const void * example_unsandboxed_ptr,rlbox_lucet_sandbox * (* expensive_sandbox_finder)(const void * example_unsandboxed_ptr))609   static inline T_PointerType impl_get_sandboxed_pointer_no_ctx(
610     const void* p,
611     const void* example_unsandboxed_ptr,
612     rlbox_lucet_sandbox* (*expensive_sandbox_finder)(
613       const void* example_unsandboxed_ptr))
614   {
615     if constexpr (std::is_function_v<std::remove_pointer_t<T>>) {
616       // swizzling function pointers needs access to the function pointer tables
617       // and thus cannot be done without context
618       auto sandbox = expensive_sandbox_finder(example_unsandboxed_ptr);
619       return sandbox->impl_get_sandboxed_pointer<T>(p);
620     } else {
621       // Just clear the memory base to leave the offset
622       RLBOX_LUCET_UNUSED(example_unsandboxed_ptr);
623       uintptr_t ret = reinterpret_cast<uintptr_t>(p) &
624                       std::numeric_limits<T_PointerType>::max();
625       return static_cast<T_PointerType>(ret);
626     }
627   }
628 
impl_is_in_same_sandbox(const void * p1,const void * p2)629   static inline bool impl_is_in_same_sandbox(const void* p1, const void* p2)
630   {
631     uintptr_t heap_base_mask = std::numeric_limits<uintptr_t>::max() &
632                                ~(std::numeric_limits<T_PointerType>::max());
633     return (reinterpret_cast<uintptr_t>(p1) & heap_base_mask) ==
634            (reinterpret_cast<uintptr_t>(p2) & heap_base_mask);
635   }
636 
impl_is_pointer_in_sandbox_memory(const void * p)637   inline bool impl_is_pointer_in_sandbox_memory(const void* p)
638   {
639     size_t length = impl_get_total_memory();
640     uintptr_t p_val = reinterpret_cast<uintptr_t>(p);
641     return p_val >= heap_base && p_val < (heap_base + length);
642   }
643 
impl_is_pointer_in_app_memory(const void * p)644   inline bool impl_is_pointer_in_app_memory(const void* p)
645   {
646     return !(impl_is_pointer_in_sandbox_memory(p));
647   }
648 
impl_get_total_memory()649   inline size_t impl_get_total_memory() { return lucet_get_heap_size(sandbox); }
650 
impl_get_memory_location()651   inline void* impl_get_memory_location()
652   {
653     return lucet_get_heap_base(sandbox);
654   }
655 
impl_lookup_symbol(const char * func_name)656   void* impl_lookup_symbol(const char* func_name)
657   {
658     return lucet_lookup_function(sandbox, func_name);
659   }
660 
661   template<typename T, typename T_Converted, typename... T_Args>
impl_invoke_with_func_ptr(T_Converted * func_ptr,T_Args &&...params)662   auto impl_invoke_with_func_ptr(T_Converted* func_ptr, T_Args&&... params)
663   {
664 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
665     auto& thread_data = *get_rlbox_lucet_sandbox_thread_data();
666 #endif
667     thread_data.sandbox = this;
668     lucet_set_curr_instance(sandbox);
669 
670     // WASM functions are mangled in the following manner
671     // 1. All primitive types are left as is and follow an LP32 machine model
672     // (as opposed to the possibly 64-bit application)
673     // 2. All pointers are changed to u32 types
674     // 3. Returned class are returned as an out parameter before the actual
675     // function parameters
676     // 4. All class parameters are passed as pointers (u32 types)
677     // 5. The heap address is passed in as the first argument to the function
678     //
679     // RLBox accounts for the first 2 differences in T_Converted type, but we
680     // need to handle the rest
681 
682     // Handle point 3
683     using T_Ret = lucet_detail::return_argument<T_Converted>;
684     if constexpr (std::is_class_v<T_Ret>) {
685       using T_Conv1 = lucet_detail::change_return_type<T_Converted, void>;
686       using T_Conv2 = lucet_detail::prepend_arg_type<T_Conv1, T_PointerType>;
687       auto func_ptr_conv =
688         reinterpret_cast<T_Conv2*>(reinterpret_cast<uintptr_t>(func_ptr));
689       ensure_return_slot_size(sizeof(T_Ret));
690       impl_invoke_with_func_ptr<T>(func_ptr_conv, return_slot, params...);
691 
692       auto ptr = reinterpret_cast<T_Ret*>(
693         impl_get_unsandboxed_pointer<T_Ret*>(return_slot));
694       T_Ret ret = *ptr;
695       return ret;
696     }
697 
698     // Handle point 4
699     constexpr size_t alloc_length = [&] {
700       if constexpr (sizeof...(params) > 0) {
701         return ((std::is_class_v<T_Args> ? 1 : 0) + ...);
702       } else {
703         return 0;
704       }
705     }();
706 
707     // 0 arg functions create 0 length arrays which is not allowed
708     T_PointerType allocations_buff[alloc_length == 0 ? 1 : alloc_length];
709     T_PointerType* allocations = allocations_buff;
710 
711     auto serialize_class_arg =
712       [&](auto arg) -> std::conditional_t<std::is_class_v<decltype(arg)>,
713                                           T_PointerType,
714                                           decltype(arg)> {
715       using T_Arg = decltype(arg);
716       if constexpr (std::is_class_v<T_Arg>) {
717         auto slot = impl_malloc_in_sandbox(sizeof(T_Arg));
718         auto ptr =
719           reinterpret_cast<T_Arg*>(impl_get_unsandboxed_pointer<T_Arg*>(slot));
720         *ptr = arg;
721         allocations[0] = slot;
722         allocations++;
723         return slot;
724       } else {
725         return arg;
726       }
727     };
728 
729     // 0 arg functions don't use serialize
730     RLBOX_LUCET_UNUSED(serialize_class_arg);
731 
732     using T_ConvNoClass =
733       lucet_detail::change_class_arg_types<T_Converted, T_PointerType>;
734 
735     // Handle Point 5
736     using T_ConvHeap = lucet_detail::prepend_arg_type<T_ConvNoClass, uint64_t>;
737 
738     // Function invocation
739     auto func_ptr_conv =
740       reinterpret_cast<T_ConvHeap*>(reinterpret_cast<uintptr_t>(func_ptr));
741 
742     using T_NoVoidRet =
743       std::conditional_t<std::is_void_v<T_Ret>, uint32_t, T_Ret>;
744     T_NoVoidRet ret;
745 
746     if constexpr (std::is_void_v<T_Ret>) {
747       RLBOX_LUCET_UNUSED(ret);
748       func_ptr_conv(heap_base, serialize_class_arg(params)...);
749     } else {
750       ret = func_ptr_conv(heap_base, serialize_class_arg(params)...);
751     }
752 
753     for (size_t i = 0; i < alloc_length; i++) {
754       impl_free_in_sandbox(allocations_buff[i]);
755     }
756 
757     if constexpr (!std::is_void_v<T_Ret>) {
758       return ret;
759     }
760   }
761 
impl_malloc_in_sandbox(size_t size)762   inline T_PointerType impl_malloc_in_sandbox(size_t size)
763   {
764     detail::dynamic_check(size <= std::numeric_limits<uint32_t>::max(),
765                           "Attempting to malloc more than the heap size");
766     using T_Func = void*(size_t);
767     using T_Converted = T_PointerType(uint32_t);
768     T_PointerType ret = impl_invoke_with_func_ptr<T_Func, T_Converted>(
769       reinterpret_cast<T_Converted*>(malloc_index),
770       static_cast<uint32_t>(size));
771     return ret;
772   }
773 
impl_free_in_sandbox(T_PointerType p)774   inline void impl_free_in_sandbox(T_PointerType p)
775   {
776     using T_Func = void(void*);
777     using T_Converted = void(T_PointerType);
778     impl_invoke_with_func_ptr<T_Func, T_Converted>(
779       reinterpret_cast<T_Converted*>(free_index), p);
780   }
781 
782   template<typename T_Ret, typename... T_Args>
impl_register_callback(void * key,void * callback)783   inline T_PointerType impl_register_callback(void* key, void* callback)
784   {
785     int32_t type_index = get_lucet_type_index<T_Ret, T_Args...>();
786 
787     detail::dynamic_check(
788       type_index != -1,
789       "Could not find lucet type for callback signature. This can "
790       "happen if you tried to register a callback whose signature "
791       "does not correspond to any callbacks used in the library.");
792 
793     bool found = false;
794     uint32_t found_loc = 0;
795     uint32_t slot_number = 0;
796 
797     {
798       std::lock_guard<std::mutex> lock(callback_table_mutex);
799 
800       // need a compile time for loop as we we need I to be a compile time value
801       // this is because we are setting the I'th callback ineterceptor
802       lucet_detail::compile_time_for<MAX_CALLBACKS>([&](auto I) {
803         constexpr auto i = I.value;
804         if (!found && callback_slots->elements[i]->rf == 0) {
805           found = true;
806           found_loc = i;
807           slot_number = callback_slots->slot_number[i];
808 
809           void* chosen_interceptor;
810           if constexpr (std::is_class_v<T_Ret>) {
811             chosen_interceptor = reinterpret_cast<void*>(
812               callback_interceptor_promoted<i, T_Ret, T_Args...>);
813           } else {
814             chosen_interceptor = reinterpret_cast<void*>(
815               callback_interceptor<i, T_Ret, T_Args...>);
816           }
817           callback_slots->elements[i]->ty = type_index;
818           callback_slots->elements[i]->rf =
819             reinterpret_cast<uintptr_t>(chosen_interceptor);
820         }
821       });
822     }
823 
824     detail::dynamic_check(
825       found,
826       "Could not find an empty slot in sandbox function table. This would "
827       "happen if you have registered too many callbacks, or unsandboxed "
828       "too many function pointers. You can file a bug if you want to "
829       "increase the maximum allowed callbacks or unsadnboxed functions "
830       "pointers");
831 
832     {
833       RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
834       callback_unique_keys[found_loc] = key;
835       callbacks[found_loc] = callback;
836       callback_slot_assignment[found_loc] = slot_number;
837     }
838 
839     return static_cast<T_PointerType>(slot_number);
840   }
841 
842   static inline std::pair<rlbox_lucet_sandbox*, void*>
impl_get_executed_callback_sandbox_and_key()843   impl_get_executed_callback_sandbox_and_key()
844   {
845 #ifdef RLBOX_EMBEDDER_PROVIDES_TLS_STATIC_VARIABLES
846     auto& thread_data = *get_rlbox_lucet_sandbox_thread_data();
847 #endif
848     auto sandbox = thread_data.sandbox;
849     auto callback_num = thread_data.last_callback_invoked;
850     void* key = sandbox->callback_unique_keys[callback_num];
851     return std::make_pair(sandbox, key);
852   }
853 
854   template<typename T_Ret, typename... T_Args>
impl_unregister_callback(void * key)855   inline void impl_unregister_callback(void* key)
856   {
857     bool found = false;
858     uint32_t i = 0;
859     {
860       RLBOX_ACQUIRE_UNIQUE_GUARD(lock, callback_mutex);
861       for (; i < MAX_CALLBACKS; i++) {
862         if (callback_unique_keys[i] == key) {
863           callback_unique_keys[i] = nullptr;
864           callbacks[i] = nullptr;
865           callback_slot_assignment[i] = 0;
866           found = true;
867           break;
868         }
869       }
870     }
871 
872     detail::dynamic_check(
873       found, "Internal error: Could not find callback to unregister");
874 
875     std::lock_guard<std::mutex> shared_lock(callback_table_mutex);
876     callback_slots->elements[i]->rf = 0;
877     return;
878   }
879 };
880 
881 } // namespace rlbox