1 #include <mutex>
2 #include <set>
3 #include <stdint.h>
4 #include <string>
5 
6 #ifdef _WIN32
7 #ifdef _MSC_VER
8 #define NOMINMAX
9 #endif
10 #include <windows.h>
11 #else
12 #include <dlfcn.h>
13 #include <sys/mman.h>
14 #endif
15 
16 #include "CodeGen_Internal.h"
17 #include "CodeGen_LLVM.h"
18 #include "Debug.h"
19 #include "JITModule.h"
20 #include "LLVM_Headers.h"
21 #include "LLVM_Output.h"
22 #include "LLVM_Runtime_Linker.h"
23 #include "Pipeline.h"
24 
25 namespace Halide {
26 namespace Internal {
27 
28 using std::string;
29 
30 #if defined(__GNUC__) && defined(__i386__)
31 extern "C" unsigned long __udivdi3(unsigned long a, unsigned long b);
32 #endif
33 
#ifdef _WIN32
// Look up `s` in the module returned by GetModuleHandle(nullptr) and hand
// back whatever address GetProcAddress reports, as an untyped pointer.
void *get_symbol_address(const char *s) {
    HMODULE process_module = GetModuleHandle(nullptr);
    return reinterpret_cast<void *>(GetProcAddress(process_module, s));
}
#else
// Look up `s` via dlsym on a handle obtained from dlopen(nullptr, ...).
// The handle is closed again before returning; the address reported by
// dlsym is returned unchanged.
void *get_symbol_address(const char *s) {
    // Mac OS 10.11 does not return a symbol address when nullptr or
    // RTLD_DEFAULT is handed to dlsym directly; going through an explicit
    // dlopen handle seems to work.
    void *const self_handle = dlopen(nullptr, RTLD_LAZY);
    void *const address = dlsym(self_handle, s);
    dlclose(self_handle);
    return address;
}
#endif
48 
49 namespace {
50 
have_symbol(const char * s)51 bool have_symbol(const char *s) {
52     return get_symbol_address(s) != nullptr;
53 }
54 
using CUcontext = struct CUctx_st *;

// Process-wide CUDA context slot plus an integer lock word. Both start
// out zeroed.
struct SharedCudaContext {
    CUctx_st *ptr = nullptr;
    volatile int lock = 0;

    // The context itself is created on first use by a jitted kernel that
    // uses it; construction only zeroes the fields.
    SharedCudaContext() {
    }

    // The context is deliberately never freed: static destructor order is
    // unpredictable, and the context must not be freed before all
    // JITModules are. Users may be stashing Funcs or Images in globals,
    // and those keep JITModules around.
} cuda_ctx;
71 
using cl_context = struct cl_context_st *;
using cl_command_queue = struct cl_command_queue_st *;

// One global OpenCL context and command queue shared between jitted
// functions, plus an integer lock word. Everything starts out zeroed.
struct SharedOpenCLContext {
    cl_context context = nullptr;
    cl_command_queue command_queue = nullptr;
    volatile int lock = 0;

    SharedOpenCLContext() {
    }

    // Never freed, for the same reason the shared CUDA context is never
    // freed.
} cl_ctx;
88 
load_opengl()89 void load_opengl() {
90 #if defined(__linux__)
91     if (have_symbol("glXGetCurrentContext") && have_symbol("glDeleteTextures")) {
92         debug(1) << "OpenGL support code already linked in...\n";
93     } else {
94         debug(1) << "Looking for OpenGL support code...\n";
95         string error;
96         llvm::sys::DynamicLibrary::LoadLibraryPermanently("libGL.so.1", &error);
97         user_assert(error.empty()) << "Could not find libGL.so\n";
98         llvm::sys::DynamicLibrary::LoadLibraryPermanently("libX11.so", &error);
99         user_assert(error.empty()) << "Could not find libX11.so\n";
100     }
101 #elif defined(__APPLE__)
102     if (have_symbol("aglCreateContext") && have_symbol("glDeleteTextures")) {
103         debug(1) << "OpenGL support code already linked in...\n";
104     } else {
105         debug(1) << "Looking for OpenGL support code...\n";
106         string error;
107         llvm::sys::DynamicLibrary::LoadLibraryPermanently("/System/Library/Frameworks/AGL.framework/AGL", &error);
108         user_assert(error.empty()) << "Could not find AGL.framework\n";
109         llvm::sys::DynamicLibrary::LoadLibraryPermanently("/System/Library/Frameworks/OpenGL.framework/OpenGL", &error);
110         user_assert(error.empty()) << "Could not find OpenGL.framework\n";
111     }
112 #else
113     internal_error << "JIT support for OpenGL on anything other than linux or OS X not yet implemented\n";
114 #endif
115 }
116 
load_metal()117 void load_metal() {
118 #if defined(__APPLE__)
119     if (have_symbol("MTLCreateSystemDefaultDevice")) {
120         debug(1) << "Metal framework already linked in...\n";
121     } else {
122         debug(1) << "Looking for Metal framework...\n";
123         string error;
124         llvm::sys::DynamicLibrary::LoadLibraryPermanently("/System/Library/Frameworks/Metal.framework/Metal", &error);
125         user_assert(error.empty()) << "Could not find Metal.framework\n";
126     }
127 #else
128     internal_error << "JIT support for Metal only implemented on OS X\n";
129 #endif
130 }
131 
132 }  // namespace
133 
134 using namespace llvm;
135 
136 class JITModuleContents {
137 public:
138     mutable RefCount ref_count;
139 
140     // Just construct a module with symbols to import into other modules.
JITModuleContents()141     JITModuleContents()
142         : execution_engine(nullptr) {
143     }
144 
~JITModuleContents()145     ~JITModuleContents() {
146         if (execution_engine != nullptr) {
147             execution_engine->runStaticConstructorsDestructors(true);
148             delete execution_engine;
149         }
150     }
151 
152     std::map<std::string, JITModule::Symbol> exports;
153     llvm::LLVMContext context;
154     ExecutionEngine *execution_engine;
155     std::vector<JITModule> dependencies;
156     JITModule::Symbol entrypoint;
157     JITModule::Symbol argv_entrypoint;
158 
159     std::string name;
160 };
161 
162 template<>
ref_count(const JITModuleContents * f)163 RefCount &ref_count<JITModuleContents>(const JITModuleContents *f) noexcept {
164     return f->ref_count;
165 }
166 
167 template<>
destroy(const JITModuleContents * f)168 void destroy<JITModuleContents>(const JITModuleContents *f) {
169     delete f;
170 }
171 
172 namespace {
173 
174 // Retrieve a function pointer from an llvm module, possibly by compiling it.
compile_and_get_function(ExecutionEngine & ee,const string & name)175 JITModule::Symbol compile_and_get_function(ExecutionEngine &ee, const string &name) {
176     debug(2) << "JIT Compiling " << name << "\n";
177     llvm::Function *fn = ee.FindFunctionNamed(name.c_str());
178     internal_assert(fn->getName() == name);
179     void *f = (void *)ee.getFunctionAddress(name);
180     if (!f) {
181         internal_error << "Compiling " << name << " returned nullptr\n";
182     }
183 
184     JITModule::Symbol symbol(f);
185 
186     debug(2) << "Function " << name << " is at " << f << "\n";
187 
188     return symbol;
189 }
190 
191 // Expand LLVM's search for symbols to include code contained in a set of JITModule.
192 class HalideJITMemoryManager : public SectionMemoryManager {
193     std::vector<JITModule> modules;
194     std::vector<std::pair<uint8_t *, size_t>> code_pages;
195 
196 public:
HalideJITMemoryManager(const std::vector<JITModule> & modules)197     HalideJITMemoryManager(const std::vector<JITModule> &modules)
198         : modules(modules) {
199     }
200 
getSymbolAddress(const std::string & name)201     uint64_t getSymbolAddress(const std::string &name) override {
202         for (size_t i = 0; i < modules.size(); i++) {
203             const JITModule &m = modules[i];
204             std::map<std::string, JITModule::Symbol>::const_iterator iter = m.exports().find(name);
205             if (iter == m.exports().end() && starts_with(name, "_")) {
206                 iter = m.exports().find(name.substr(1));
207             }
208             if (iter != m.exports().end()) {
209                 return (uint64_t)iter->second.address;
210             }
211         }
212         uint64_t result = SectionMemoryManager::getSymbolAddress(name);
213 #if defined(__GNUC__) && defined(__i386__)
214         // This is a workaround for an odd corner case (cross-compiling + testing
215         // Python bindings x86-32 on an x86-64 system): __udivdi3 is a helper function
216         // that GCC uses to do u64/u64 division on 32-bit systems; it's usually included
217         // by the linker on these systems as needed. When we JIT, LLVM will include references
218         // to this call; MCJIT fixes up these references by doing (roughly) dlopen(NULL)
219         // to look up the symbol. For normal JIT tests, this works fine, as dlopen(NULL)
220         // finds the test executable, which has the right lookups to locate it inside libHalide.so.
221         // If, however, we are running a JIT-via-Python test, dlopen(NULL) returns the
222         // CPython executable... which apparently *doesn't* include this as an exported
223         // function, so the lookup fails and crashiness ensues. So our workaround here is
224         // a bit icky, but expedient: check for this name if we can't find it elsewhere,
225         // and if so, return the one we know should be present. (Obviously, if other runtime
226         // helper functions of this sort crop up in the future, this should be expanded
227         // into a "builtins map".)
228         if (result == 0 && name == "__udivdi3") {
229             result = (uint64_t)&__udivdi3;
230         }
231 #endif
232         internal_assert(result != 0)
233             << "HalideJITMemoryManager: unable to find address for " << name << "\n";
234         return result;
235     }
236 
allocateCodeSection(uintptr_t size,unsigned alignment,unsigned section_id,StringRef section_name)237     uint8_t *allocateCodeSection(uintptr_t size, unsigned alignment, unsigned section_id, StringRef section_name) override {
238         uint8_t *result = SectionMemoryManager::allocateCodeSection(size, alignment, section_id, section_name);
239         code_pages.emplace_back(result, size);
240         return result;
241     }
242 };
243 
244 }  // namespace
245 
JITModule()246 JITModule::JITModule() {
247     jit_module = new JITModuleContents();
248 }
249 
JITModule(const Module & m,const LoweredFunc & fn,const std::vector<JITModule> & dependencies)250 JITModule::JITModule(const Module &m, const LoweredFunc &fn,
251                      const std::vector<JITModule> &dependencies) {
252     jit_module = new JITModuleContents();
253     std::unique_ptr<llvm::Module> llvm_module(compile_module_to_llvm_module(m, jit_module->context));
254     std::vector<JITModule> deps_with_runtime = dependencies;
255     std::vector<JITModule> shared_runtime = JITSharedRuntime::get(llvm_module.get(), m.target());
256     deps_with_runtime.insert(deps_with_runtime.end(), shared_runtime.begin(), shared_runtime.end());
257     compile_module(std::move(llvm_module), fn.name, m.target(), deps_with_runtime);
258     // If -time-passes is in HL_LLVM_ARGS, this will print llvm passes time statstics otherwise its no-op.
259     llvm::reportAndResetTimings();
260 }
261 
compile_module(std::unique_ptr<llvm::Module> m,const string & function_name,const Target & target,const std::vector<JITModule> & dependencies,const std::vector<std::string> & requested_exports)262 void JITModule::compile_module(std::unique_ptr<llvm::Module> m, const string &function_name, const Target &target,
263                                const std::vector<JITModule> &dependencies,
264                                const std::vector<std::string> &requested_exports) {
265 
266     // Ensure that LLVM is initialized
267     CodeGen_LLVM::initialize_llvm();
268 
269     // Make the execution engine
270     debug(2) << "Creating new execution engine\n";
271     debug(2) << "Target triple: " << m->getTargetTriple() << "\n";
272     string error_string;
273 
274     string mcpu;
275     string mattrs;
276     llvm::TargetOptions options;
277     get_target_options(*m, options, mcpu, mattrs);
278 
279     DataLayout initial_module_data_layout = m->getDataLayout();
280     string module_name = m->getModuleIdentifier();
281 
282     llvm::EngineBuilder engine_builder((std::move(m)));
283     engine_builder.setTargetOptions(options);
284     engine_builder.setErrorStr(&error_string);
285     engine_builder.setEngineKind(llvm::EngineKind::JIT);
286     HalideJITMemoryManager *memory_manager = new HalideJITMemoryManager(dependencies);
287     engine_builder.setMCJITMemoryManager(std::unique_ptr<RTDyldMemoryManager>(memory_manager));
288 
289     engine_builder.setOptLevel(CodeGenOpt::Aggressive);
290     if (!mcpu.empty()) {
291         engine_builder.setMCPU(mcpu);
292     }
293     std::vector<string> mattrs_array = {mattrs};
294     engine_builder.setMAttrs(mattrs_array);
295 
296     TargetMachine *tm = engine_builder.selectTarget();
297     internal_assert(tm) << error_string << "\n";
298     DataLayout target_data_layout(tm->createDataLayout());
299     if (initial_module_data_layout != target_data_layout) {
300         internal_error << "Warning: data layout mismatch between module ("
301                        << initial_module_data_layout.getStringRepresentation()
302                        << ") and what the execution engine expects ("
303                        << target_data_layout.getStringRepresentation() << ")\n";
304     }
305     ExecutionEngine *ee = engine_builder.create(tm);
306 
307     if (!ee) std::cerr << error_string << "\n";
308     internal_assert(ee) << "Couldn't create execution engine\n";
309 
310     // Do any target-specific initialization
311     std::vector<llvm::JITEventListener *> listeners;
312 
313     if (target.arch == Target::X86) {
314         listeners.push_back(llvm::JITEventListener::createIntelJITEventListener());
315     }
316     // TODO: If this ever works in LLVM, this would allow profiling of JIT code with symbols with oprofile.
317     //listeners.push_back(llvm::createOProfileJITEventListener());
318 
319     for (size_t i = 0; i < listeners.size(); i++) {
320         ee->RegisterJITEventListener(listeners[i]);
321     }
322 
323     // Retrieve function pointers from the compiled module (which also
324     // triggers compilation)
325     debug(1) << "JIT compiling " << module_name
326              << " for " << target.to_string() << "\n";
327 
328     std::map<std::string, Symbol> exports;
329 
330     Symbol entrypoint;
331     Symbol argv_entrypoint;
332     if (!function_name.empty()) {
333         entrypoint = compile_and_get_function(*ee, function_name);
334         exports[function_name] = entrypoint;
335         argv_entrypoint = compile_and_get_function(*ee, function_name + "_argv");
336         exports[function_name + "_argv"] = argv_entrypoint;
337     }
338 
339     for (size_t i = 0; i < requested_exports.size(); i++) {
340         exports[requested_exports[i]] = compile_and_get_function(*ee, requested_exports[i]);
341     }
342 
343     debug(2) << "Finalizing object\n";
344     ee->finalizeObject();
345     // Do any target-specific post-compilation module meddling
346     for (size_t i = 0; i < listeners.size(); i++) {
347         ee->UnregisterJITEventListener(listeners[i]);
348         delete listeners[i];
349     }
350     listeners.clear();
351 
352     // TODO: I don't think this is necessary, we shouldn't have any static constructors
353     ee->runStaticConstructorsDestructors(false);
354 
355     // Stash the various objects that need to stay alive behind a reference-counted pointer.
356     jit_module->exports = exports;
357     jit_module->execution_engine = ee;
358     jit_module->dependencies = dependencies;
359     jit_module->entrypoint = entrypoint;
360     jit_module->argv_entrypoint = argv_entrypoint;
361     jit_module->name = function_name;
362 }
363 
364 /*static*/
make_trampolines_module(const Target & target_arg,const std::map<std::string,JITExtern> & externs,const std::string & suffix,const std::vector<JITModule> & deps)365 JITModule JITModule::make_trampolines_module(const Target &target_arg,
366                                              const std::map<std::string, JITExtern> &externs,
367                                              const std::string &suffix,
368                                              const std::vector<JITModule> &deps) {
369     Target target = target_arg;
370     target.set_feature(Target::JIT);
371 
372     JITModule result;
373     std::vector<std::pair<std::string, ExternSignature>> extern_signatures;
374     std::vector<std::string> requested_exports;
375     for (const std::pair<const std::string, JITExtern> &e : externs) {
376         const std::string &callee_name = e.first;
377         const std::string wrapper_name = callee_name + suffix;
378         const ExternCFunction &extern_c = e.second.extern_c_function();
379         result.add_extern_for_export(callee_name, extern_c);
380         requested_exports.push_back(wrapper_name);
381         extern_signatures.emplace_back(callee_name, extern_c.signature());
382     }
383 
384     std::unique_ptr<llvm::Module> llvm_module = CodeGen_LLVM::compile_trampolines(
385         target, result.jit_module->context, suffix, extern_signatures);
386 
387     result.compile_module(std::move(llvm_module), /*function_name*/ "", target, deps, requested_exports);
388 
389     return result;
390 }
391 
exports() const392 const std::map<std::string, JITModule::Symbol> &JITModule::exports() const {
393     return jit_module->exports;
394 }
395 
find_symbol_by_name(const std::string & name) const396 JITModule::Symbol JITModule::find_symbol_by_name(const std::string &name) const {
397     std::map<std::string, JITModule::Symbol>::iterator it = jit_module->exports.find(name);
398     if (it != jit_module->exports.end()) {
399         return it->second;
400     }
401     for (const JITModule &dep : jit_module->dependencies) {
402         JITModule::Symbol s = dep.find_symbol_by_name(name);
403         if (s.address) return s;
404     }
405     return JITModule::Symbol();
406 }
407 
main_function() const408 void *JITModule::main_function() const {
409     return jit_module->entrypoint.address;
410 }
411 
entrypoint_symbol() const412 JITModule::Symbol JITModule::entrypoint_symbol() const {
413     return jit_module->entrypoint;
414 }
415 
argv_function() const416 int (*JITModule::argv_function() const)(const void **) {
417     return (int (*)(const void **))jit_module->argv_entrypoint.address;
418 }
419 
argv_entrypoint_symbol() const420 JITModule::Symbol JITModule::argv_entrypoint_symbol() const {
421     return jit_module->argv_entrypoint;
422 }
423 
module_already_in_graph(const JITModuleContents * start,const JITModuleContents * target,std::set<const JITModuleContents * > & already_seen)424 static bool module_already_in_graph(const JITModuleContents *start, const JITModuleContents *target, std::set<const JITModuleContents *> &already_seen) {
425     if (start == target) {
426         return true;
427     }
428     if (already_seen.count(start) != 0) {
429         return false;
430     }
431     already_seen.insert(start);
432     for (const JITModule &dep_holder : start->dependencies) {
433         const JITModuleContents *dep = dep_holder.jit_module.get();
434         if (module_already_in_graph(dep, target, already_seen)) {
435             return true;
436         }
437     }
438     return false;
439 }
440 
add_dependency(JITModule & dep)441 void JITModule::add_dependency(JITModule &dep) {
442     std::set<const JITModuleContents *> already_seen;
443     internal_assert(!module_already_in_graph(dep.jit_module.get(), jit_module.get(), already_seen)) << "JITModule::add_dependency: creating circular dependency graph.\n";
444     jit_module->dependencies.push_back(dep);
445 }
446 
add_symbol_for_export(const std::string & name,const Symbol & extern_symbol)447 void JITModule::add_symbol_for_export(const std::string &name, const Symbol &extern_symbol) {
448     jit_module->exports[name] = extern_symbol;
449 }
450 
add_extern_for_export(const std::string & name,const ExternCFunction & extern_c_function)451 void JITModule::add_extern_for_export(const std::string &name,
452                                       const ExternCFunction &extern_c_function) {
453     Symbol symbol(extern_c_function.address());
454     jit_module->exports[name] = symbol;
455 }
456 
memoization_cache_set_size(int64_t size) const457 void JITModule::memoization_cache_set_size(int64_t size) const {
458     std::map<std::string, Symbol>::const_iterator f =
459         exports().find("halide_memoization_cache_set_size");
460     if (f != exports().end()) {
461         (reinterpret_bits<void (*)(int64_t)>(f->second.address))(size);
462     }
463 }
464 
reuse_device_allocations(bool b) const465 void JITModule::reuse_device_allocations(bool b) const {
466     std::map<std::string, Symbol>::const_iterator f =
467         exports().find("halide_reuse_device_allocations");
468     if (f != exports().end()) {
469         (reinterpret_bits<int (*)(void *, bool)>(f->second.address))(nullptr, b);
470     }
471 }
472 
compiled() const473 bool JITModule::compiled() const {
474     return jit_module->execution_engine != nullptr;
475 }
476 
477 namespace {
478 
// Filled in by make_module with the values the runtime's hook setters
// return when the *_handler trampolines are installed into the MainShared
// runtime (presumably the runtime's own built-in handlers; confirm).
JITHandlers runtime_internal_handlers;
// NOTE(review): not referenced in the visible part of this file;
// presumably the process-wide defaults that get merged into
// active_handlers. Confirm against the rest of the file.
JITHandlers default_handlers;
// The handlers the *_handler trampolines dispatch to when they are called
// with a null context (and always, for the symbol/library lookups).
JITHandlers active_handlers;
// NOTE(review): not referenced in the visible part of this file;
// presumably the memoization cache size applied to new shared runtimes.
// Confirm.
int64_t default_cache_size;
483 
merge_handlers(JITHandlers & base,const JITHandlers & addins)484 void merge_handlers(JITHandlers &base, const JITHandlers &addins) {
485     if (addins.custom_print) {
486         base.custom_print = addins.custom_print;
487     }
488     if (addins.custom_malloc) {
489         base.custom_malloc = addins.custom_malloc;
490     }
491     if (addins.custom_free) {
492         base.custom_free = addins.custom_free;
493     }
494     if (addins.custom_do_task) {
495         base.custom_do_task = addins.custom_do_task;
496     }
497     if (addins.custom_do_par_for) {
498         base.custom_do_par_for = addins.custom_do_par_for;
499     }
500     if (addins.custom_error) {
501         base.custom_error = addins.custom_error;
502     }
503     if (addins.custom_trace) {
504         base.custom_trace = addins.custom_trace;
505     }
506     if (addins.custom_get_symbol) {
507         base.custom_get_symbol = addins.custom_get_symbol;
508     }
509     if (addins.custom_load_library) {
510         base.custom_load_library = addins.custom_load_library;
511     }
512     if (addins.custom_get_library_symbol) {
513         base.custom_get_library_symbol = addins.custom_get_library_symbol;
514     }
515 }
516 
print_handler(void * context,const char * msg)517 void print_handler(void *context, const char *msg) {
518     if (context) {
519         JITUserContext *jit_user_context = (JITUserContext *)context;
520         (*jit_user_context->handlers.custom_print)(context, msg);
521     } else {
522         return (*active_handlers.custom_print)(context, msg);
523     }
524 }
525 
malloc_handler(void * context,size_t x)526 void *malloc_handler(void *context, size_t x) {
527     if (context) {
528         JITUserContext *jit_user_context = (JITUserContext *)context;
529         return (*jit_user_context->handlers.custom_malloc)(context, x);
530     } else {
531         return (*active_handlers.custom_malloc)(context, x);
532     }
533 }
534 
free_handler(void * context,void * ptr)535 void free_handler(void *context, void *ptr) {
536     if (context) {
537         JITUserContext *jit_user_context = (JITUserContext *)context;
538         (*jit_user_context->handlers.custom_free)(context, ptr);
539     } else {
540         (*active_handlers.custom_free)(context, ptr);
541     }
542 }
543 
do_task_handler(void * context,halide_task f,int idx,uint8_t * closure)544 int do_task_handler(void *context, halide_task f, int idx,
545                     uint8_t *closure) {
546     if (context) {
547         JITUserContext *jit_user_context = (JITUserContext *)context;
548         return (*jit_user_context->handlers.custom_do_task)(context, f, idx, closure);
549     } else {
550         return (*active_handlers.custom_do_task)(context, f, idx, closure);
551     }
552 }
553 
do_par_for_handler(void * context,halide_task f,int min,int size,uint8_t * closure)554 int do_par_for_handler(void *context, halide_task f,
555                        int min, int size, uint8_t *closure) {
556     if (context) {
557         JITUserContext *jit_user_context = (JITUserContext *)context;
558         return (*jit_user_context->handlers.custom_do_par_for)(context, f, min, size, closure);
559     } else {
560         return (*active_handlers.custom_do_par_for)(context, f, min, size, closure);
561     }
562 }
563 
error_handler_handler(void * context,const char * msg)564 void error_handler_handler(void *context, const char *msg) {
565     if (context) {
566         JITUserContext *jit_user_context = (JITUserContext *)context;
567         (*jit_user_context->handlers.custom_error)(context, msg);
568     } else {
569         (*active_handlers.custom_error)(context, msg);
570     }
571 }
572 
trace_handler(void * context,const halide_trace_event_t * e)573 int32_t trace_handler(void *context, const halide_trace_event_t *e) {
574     if (context) {
575         JITUserContext *jit_user_context = (JITUserContext *)context;
576         return (*jit_user_context->handlers.custom_trace)(context, e);
577     } else {
578         return (*active_handlers.custom_trace)(context, e);
579     }
580 }
581 
get_symbol_handler(const char * name)582 void *get_symbol_handler(const char *name) {
583     return (*active_handlers.custom_get_symbol)(name);
584 }
585 
load_library_handler(const char * name)586 void *load_library_handler(const char *name) {
587     return (*active_handlers.custom_load_library)(name);
588 }
589 
get_library_symbol_handler(void * lib,const char * name)590 void *get_library_symbol_handler(void *lib, const char *name) {
591     return (*active_handlers.custom_get_library_symbol)(lib, name);
592 }
593 
594 template<typename function_t>
hook_function(const std::map<std::string,JITModule::Symbol> & exports,const char * hook_name,function_t hook)595 function_t hook_function(const std::map<std::string, JITModule::Symbol> &exports, const char *hook_name, function_t hook) {
596     auto iter = exports.find(hook_name);
597     internal_assert(iter != exports.end()) << "Failed to find function " << hook_name << "\n";
598     function_t (*hook_setter)(function_t) =
599         reinterpret_bits<function_t (*)(function_t)>(iter->second.address);
600     return (*hook_setter)(hook);
601 }
602 
adjust_module_ref_count(void * arg,int32_t count)603 void adjust_module_ref_count(void *arg, int32_t count) {
604     JITModuleContents *module = (JITModuleContents *)arg;
605 
606     debug(2) << "Adjusting refcount for module " << module->name << " by " << count << "\n";
607 
608     if (count > 0) {
609         module->ref_count.increment();
610     } else {
611         module->ref_count.decrement();
612     }
613 }
614 
// Guards the table of shared runtime modules: shared_runtimes() and
// make_module() rely on their callers already holding this mutex.
std::mutex shared_runtimes_mutex;
616 
// The Halide runtime is broken up into pieces so that state can be
// shared across JIT compilations that do not use the same target
// options. At present, the split is into a MainShared module that
// contains most of the runtime except for device API specific code
// (GPU runtimes). There is one shared runtime per device API, and
// the JITModule for a Func depends on all device API modules
// specified in the target when it is JITted. (Instruction set variant
// specific code, such as math routines, is inlined into the module
// produced by compiling a Func so it can be specialized exactly for
// each target.)
//
// Each device API has a plain and a *Debug kind; make_module builds the
// *Debug kinds with Target::Debug enabled. The values also index the
// table returned by shared_runtimes().
enum RuntimeKind {
    MainShared,
    OpenCL,
    Metal,
    CUDA,
    OpenGL,
    OpenGLCompute,
    Hexagon,
    D3D12Compute,
    OpenCLDebug,
    MetalDebug,
    CUDADebug,
    OpenGLDebug,
    OpenGLComputeDebug,
    HexagonDebug,
    D3D12ComputeDebug,
    // Not a runtime: the number of kinds, used to size the table of
    // shared runtime modules.
    MaxRuntimeKind
};
645 
shared_runtimes(RuntimeKind k)646 JITModule &shared_runtimes(RuntimeKind k) {
647     // We're already guarded by the shared_runtimes_mutex
648     static JITModule *m = nullptr;
649     if (!m) {
650         // Note that this is never freed. On windows this would invoke
651         // static destructors that use threading objects, and these
652         // don't work (crash or deadlock) after main exits.
653         m = new JITModule[MaxRuntimeKind];
654     }
655     return m[k];
656 }
657 
make_module(llvm::Module * for_module,Target target,RuntimeKind runtime_kind,const std::vector<JITModule> & deps,bool create)658 JITModule &make_module(llvm::Module *for_module, Target target,
659                        RuntimeKind runtime_kind, const std::vector<JITModule> &deps,
660                        bool create) {
661 
662     JITModule &runtime = shared_runtimes(runtime_kind);
663     if (!runtime.compiled() && create) {
664         // Ensure that JIT feature is set on target as it must be in
665         // order for the right runtime components to be added.
666         target.set_feature(Target::JIT);
667         // msan doesn't work for jit modules
668         target.set_feature(Target::MSAN, false);
669 
670         Target one_gpu(target);
671         one_gpu.set_feature(Target::Debug, false);
672         one_gpu.set_feature(Target::OpenCL, false);
673         one_gpu.set_feature(Target::Metal, false);
674         one_gpu.set_feature(Target::CUDA, false);
675         one_gpu.set_feature(Target::HVX_64, false);
676         one_gpu.set_feature(Target::HVX_128, false);
677         one_gpu.set_feature(Target::OpenGL, false);
678         one_gpu.set_feature(Target::OpenGLCompute, false);
679         one_gpu.set_feature(Target::D3D12Compute, false);
680         string module_name;
681         switch (runtime_kind) {
682         case OpenCLDebug:
683             one_gpu.set_feature(Target::Debug);
684             one_gpu.set_feature(Target::OpenCL);
685             module_name = "debug_opencl";
686             break;
687         case OpenCL:
688             one_gpu.set_feature(Target::OpenCL);
689             module_name += "opencl";
690             break;
691         case MetalDebug:
692             one_gpu.set_feature(Target::Debug);
693             one_gpu.set_feature(Target::Metal);
694             load_metal();
695             module_name = "debug_metal";
696             break;
697         case Metal:
698             one_gpu.set_feature(Target::Metal);
699             module_name += "metal";
700             load_metal();
701             break;
702         case CUDADebug:
703             one_gpu.set_feature(Target::Debug);
704             one_gpu.set_feature(Target::CUDA);
705             module_name = "debug_cuda";
706             break;
707         case CUDA:
708             one_gpu.set_feature(Target::CUDA);
709             module_name += "cuda";
710             break;
711         case OpenGLDebug:
712             one_gpu.set_feature(Target::Debug);
713             one_gpu.set_feature(Target::OpenGL);
714             module_name = "debug_opengl";
715             load_opengl();
716             break;
717         case OpenGL:
718             one_gpu.set_feature(Target::OpenGL);
719             module_name += "opengl";
720             load_opengl();
721             break;
722         case OpenGLComputeDebug:
723             one_gpu.set_feature(Target::Debug);
724             one_gpu.set_feature(Target::OpenGLCompute);
725             module_name = "debug_openglcompute";
726             load_opengl();
727             break;
728         case OpenGLCompute:
729             one_gpu.set_feature(Target::OpenGLCompute);
730             module_name += "openglcompute";
731             load_opengl();
732             break;
733         case HexagonDebug:
734             one_gpu.set_feature(Target::Debug);
735             one_gpu.set_feature(Target::HVX_64);
736             module_name = "debug_hexagon";
737             break;
738         case Hexagon:
739             one_gpu.set_feature(Target::HVX_64);
740             module_name += "hexagon";
741             break;
742         case D3D12ComputeDebug:
743             one_gpu.set_feature(Target::Debug);
744             one_gpu.set_feature(Target::D3D12Compute);
745             module_name = "debug_d3d12compute";
746             break;
747         case D3D12Compute:
748             one_gpu.set_feature(Target::D3D12Compute);
749             module_name += "d3d12compute";
750 #if !defined(_WIN32)
751             internal_error << "JIT support for Direct3D 12 is only implemented on Windows 10 and above.\n";
752 #endif
753             break;
754         default:
755             module_name = "shared runtime";
756             break;
757         }
758 
759         // This function is protected by a mutex so this is thread safe.
760         auto module =
761             get_initial_module_for_target(one_gpu,
762                                           &runtime.jit_module->context,
763                                           true,
764                                           runtime_kind != MainShared);
765         if (for_module) {
766             clone_target_options(*for_module, *module);
767         }
768         module->setModuleIdentifier(module_name);
769 
770         std::set<std::string> halide_exports_unique;
771 
772         // Enumerate the functions.
773         for (auto &f : *module) {
774             // LLVM_Runtime_Linker has marked everything that should be exported as weak
775             if (f.hasWeakLinkage()) {
776                 halide_exports_unique.insert(get_llvm_function_name(f));
777             }
778         }
779 
780         std::vector<std::string> halide_exports(halide_exports_unique.begin(), halide_exports_unique.end());
781 
782         runtime.compile_module(std::move(module), "", target, deps, halide_exports);
783 
784         if (runtime_kind == MainShared) {
785             runtime_internal_handlers.custom_print =
786                 hook_function(runtime.exports(), "halide_set_custom_print", print_handler);
787 
788             runtime_internal_handlers.custom_malloc =
789                 hook_function(runtime.exports(), "halide_set_custom_malloc", malloc_handler);
790 
791             runtime_internal_handlers.custom_free =
792                 hook_function(runtime.exports(), "halide_set_custom_free", free_handler);
793 
794             runtime_internal_handlers.custom_do_task =
795                 hook_function(runtime.exports(), "halide_set_custom_do_task", do_task_handler);
796 
797             runtime_internal_handlers.custom_do_par_for =
798                 hook_function(runtime.exports(), "halide_set_custom_do_par_for", do_par_for_handler);
799 
800             runtime_internal_handlers.custom_error =
801                 hook_function(runtime.exports(), "halide_set_error_handler", error_handler_handler);
802 
803             runtime_internal_handlers.custom_trace =
804                 hook_function(runtime.exports(), "halide_set_custom_trace", trace_handler);
805 
806             runtime_internal_handlers.custom_get_symbol =
807                 hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_symbol", get_symbol_handler);
808 
809             runtime_internal_handlers.custom_load_library =
810                 hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_load_library", load_library_handler);
811 
812             runtime_internal_handlers.custom_get_library_symbol =
813                 hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);
814 
815             active_handlers = runtime_internal_handlers;
816             merge_handlers(active_handlers, default_handlers);
817 
818             if (default_cache_size != 0) {
819                 runtime.memoization_cache_set_size(default_cache_size);
820             }
821 
822             runtime.jit_module->name = "MainShared";
823         } else {
824             runtime.jit_module->name = "GPU";
825         }
826 
827         uint64_t arg_addr =
828             runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_argument");
829 
830         internal_assert(arg_addr != 0);
831         *((void **)arg_addr) = runtime.jit_module.get();
832 
833         uint64_t fun_addr = runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_adjust_ref_count");
834         internal_assert(fun_addr != 0);
835         *(void (**)(void *arg, int32_t count))fun_addr = &adjust_module_ref_count;
836     }
837     return runtime;
838 }
839 
840 }  // anonymous namespace
841 
/* Shared runtimes are stored as global state. The set needed is
 * determined from the target and then retrieved. If one does not
 * exist yet, it is made on the fly from the compiled-in bitcode of
 * the runtime modules. As with all JITModules, the shared runtime is
 * ref counted, but a global keeps one ref alive until shutdown or
 * when JITSharedRuntime::release_all is called. If
 * JITSharedRuntime::release_all is called, the global state is reset
 * and any newly compiled Funcs will get a new runtime. */
get(llvm::Module * for_module,const Target & target,bool create)850 std::vector<JITModule> JITSharedRuntime::get(llvm::Module *for_module, const Target &target, bool create) {
851     std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
852 
853     std::vector<JITModule> result;
854 
855     JITModule m = make_module(for_module, target, MainShared, result, create);
856     if (m.compiled()) {
857         result.push_back(m);
858     }
859 
860     // Add all requested GPU modules, each only depending on the main shared runtime.
861     std::vector<JITModule> gpu_modules;
862     if (target.has_feature(Target::OpenCL)) {
863         auto kind = target.has_feature(Target::Debug) ? OpenCLDebug : OpenCL;
864         JITModule m = make_module(for_module, target, kind, result, create);
865         if (m.compiled()) {
866             result.push_back(m);
867         }
868     }
869     if (target.has_feature(Target::Metal)) {
870         auto kind = target.has_feature(Target::Debug) ? MetalDebug : Metal;
871         JITModule m = make_module(for_module, target, kind, result, create);
872         if (m.compiled()) {
873             result.push_back(m);
874         }
875     }
876     if (target.has_feature(Target::CUDA)) {
877         auto kind = target.has_feature(Target::Debug) ? CUDADebug : CUDA;
878         JITModule m = make_module(for_module, target, kind, result, create);
879         if (m.compiled()) {
880             result.push_back(m);
881         }
882     }
883     if (target.has_feature(Target::OpenGL)) {
884         auto kind = target.has_feature(Target::Debug) ? OpenGLDebug : OpenGL;
885         JITModule m = make_module(for_module, target, kind, result, create);
886         if (m.compiled()) {
887             result.push_back(m);
888         }
889     }
890     if (target.has_feature(Target::OpenGLCompute)) {
891         auto kind = target.has_feature(Target::Debug) ? OpenGLComputeDebug : OpenGLCompute;
892         JITModule m = make_module(for_module, target, kind, result, create);
893         if (m.compiled()) {
894             result.push_back(m);
895         }
896     }
897     if (target.features_any_of({Target::HVX_64, Target::HVX_128})) {
898         auto kind = target.has_feature(Target::Debug) ? HexagonDebug : Hexagon;
899         JITModule m = make_module(for_module, target, kind, result, create);
900         if (m.compiled()) {
901             result.push_back(m);
902         }
903     }
904     if (target.has_feature(Target::D3D12Compute)) {
905         auto kind = target.has_feature(Target::Debug) ? D3D12ComputeDebug : D3D12Compute;
906         JITModule m = make_module(for_module, target, kind, result, create);
907         if (m.compiled()) {
908             result.push_back(m);
909         }
910     }
911 
912     return result;
913 }
914 
// TODO: Either remove the user_context argument or figure out how to
// make a caller-provided user context work with JIT. (At present,
// these cascaded handler calls cannot work with the right context, as
// JITModule needs its context to be passed in case the called handler
// calls another callback which is not overridden by the caller.)
init_jit_user_context(JITUserContext & jit_user_context,void * user_context,const JITHandlers & handlers)920 void JITSharedRuntime::init_jit_user_context(JITUserContext &jit_user_context,
921                                              void *user_context, const JITHandlers &handlers) {
922     jit_user_context.handlers = active_handlers;
923     jit_user_context.user_context = user_context;
924     merge_handlers(jit_user_context.handlers, handlers);
925 }
926 
release_all()927 void JITSharedRuntime::release_all() {
928     std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
929 
930     for (int i = MaxRuntimeKind; i > 0; i--) {
931         shared_runtimes((RuntimeKind)(i - 1)) = JITModule();
932     }
933 }
934 
JITHandlers JITSharedRuntime::set_default_handlers(const JITHandlers &handlers) {
    // Keep the outgoing defaults so they can be handed back to the caller.
    JITHandlers previous_defaults = default_handlers;

    default_handlers = handlers;

    // Rebuild the active set: runtime-internal handlers first, then the
    // new defaults merged into them.
    active_handlers = runtime_internal_handlers;
    merge_handlers(active_handlers, default_handlers);

    return previous_defaults;
}
942 
memoization_cache_set_size(int64_t size)943 void JITSharedRuntime::memoization_cache_set_size(int64_t size) {
944     std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
945 
946     if (size != default_cache_size) {
947         default_cache_size = size;
948         shared_runtimes(MainShared).memoization_cache_set_size(size);
949     }
950 }
951 
reuse_device_allocations(bool b)952 void JITSharedRuntime::reuse_device_allocations(bool b) {
953     std::lock_guard<std::mutex> lock(shared_runtimes_mutex);
954     shared_runtimes(MainShared).reuse_device_allocations(b);
955 }
956 
957 }  // namespace Internal
958 }  // namespace Halide
959