1 #include <iostream>
2 #include <limits>
3 #include <mutex>
4 #include <sstream>
5 
6 #include "CPlusPlusMangle.h"
7 #include "CSE.h"
8 #include "CodeGen_ARM.h"
9 #include "CodeGen_GPU_Host.h"
10 #include "CodeGen_Hexagon.h"
11 #include "CodeGen_Internal.h"
12 #include "CodeGen_LLVM.h"
13 #include "CodeGen_MIPS.h"
14 #include "CodeGen_PowerPC.h"
15 #include "CodeGen_RISCV.h"
16 #include "CodeGen_WebAssembly.h"
17 #include "CodeGen_X86.h"
18 #include "CompilerLogger.h"
19 #include "Debug.h"
20 #include "Deinterleave.h"
21 #include "EmulateFloat16Math.h"
22 #include "ExprUsesVar.h"
23 #include "IROperator.h"
24 #include "IRPrinter.h"
25 #include "IntegerDivisionTable.h"
26 #include "JITModule.h"
27 #include "LLVM_Headers.h"
28 #include "LLVM_Runtime_Linker.h"
29 #include "Lerp.h"
30 #include "MatlabWrapper.h"
31 #include "Pipeline.h"
32 #include "Simplify.h"
33 #include "Util.h"
34 
35 #if !(__cplusplus > 199711L || _MSC_VER >= 1800)
36 
37 // VS2013 isn't fully C++11 compatible, but it supports enough of what Halide
38 // needs for now to be an acceptable minimum for Windows.
39 #error "Halide requires C++11 or VS2013+; please upgrade your compiler."
40 
41 #endif
42 
43 namespace Halide {
44 
codegen_llvm(const Module & module,llvm::LLVMContext & context)45 std::unique_ptr<llvm::Module> codegen_llvm(const Module &module, llvm::LLVMContext &context) {
46     std::unique_ptr<Internal::CodeGen_LLVM> cg(Internal::CodeGen_LLVM::new_for_target(module.target(), context));
47     return cg->compile(module);
48 }
49 
50 namespace Internal {
51 
52 using namespace llvm;
53 using std::map;
54 using std::ostringstream;
55 using std::pair;
56 using std::string;
57 using std::vector;
58 
59 // Define a local empty inline function for each target
60 // to disable initialization.
61 #define LLVM_TARGET(target)                    \
62     inline void Initialize##target##Target() { \
63     }
64 #include <llvm/Config/Targets.def>
65 #undef LLVM_TARGET
66 
67 #define LLVM_ASM_PARSER(target)                   \
68     inline void Initialize##target##AsmParser() { \
69     }
70 #include <llvm/Config/AsmParsers.def>
71 #undef LLVM_ASM_PARSER
72 
73 #define LLVM_ASM_PRINTER(target)                   \
74     inline void Initialize##target##AsmPrinter() { \
75     }
76 #include <llvm/Config/AsmPrinters.def>
77 #undef LLVM_ASM_PRINTER
78 
79 #define InitializeTarget(target)          \
80     LLVMInitialize##target##Target();     \
81     LLVMInitialize##target##TargetInfo(); \
82     LLVMInitialize##target##TargetMC();   \
83     llvm_##target##_enabled = true;
84 
85 #define InitializeAsmParser(target) \
86     LLVMInitialize##target##AsmParser();
87 
88 #define InitializeAsmPrinter(target) \
89     LLVMInitialize##target##AsmPrinter();
90 
91 // Override above empty init function with macro for supported targets.
92 #ifdef WITH_ARM
93 #define InitializeARMTarget() InitializeTarget(ARM)
94 #define InitializeARMAsmParser() InitializeAsmParser(ARM)
95 #define InitializeARMAsmPrinter() InitializeAsmPrinter(ARM)
96 #endif
97 
98 #ifdef WITH_NVPTX
99 #define InitializeNVPTXTarget() InitializeTarget(NVPTX)
100 #define InitializeNVPTXAsmParser() InitializeAsmParser(NVPTX)
101 #define InitializeNVPTXAsmPrinter() InitializeAsmPrinter(NVPTX)
102 #endif
103 
104 #ifdef WITH_AMDGPU
105 #define InitializeAMDGPUTarget() InitializeTarget(AMDGPU)
106 #define InitializeAMDGPUAsmParser() InitializeAsmParser(AMDGPU)
107 #define InitializeAMDGPUAsmPrinter() InitializeAsmParser(AMDGPU)
108 #endif
109 
110 #ifdef WITH_AARCH64
111 #define InitializeAArch64Target() InitializeTarget(AArch64)
112 #define InitializeAArch64AsmParser() InitializeAsmParser(AArch64)
113 #define InitializeAArch64AsmPrinter() InitializeAsmPrinter(AArch64)
114 #endif
115 
116 #ifdef WITH_HEXAGON
117 #define InitializeHexagonTarget() InitializeTarget(Hexagon)
118 #define InitializeHexagonAsmParser() InitializeAsmParser(Hexagon)
119 #define InitializeHexagonAsmPrinter() InitializeAsmPrinter(Hexagon)
120 #endif
121 
122 #ifdef WITH_MIPS
123 #define InitializeMipsTarget() InitializeTarget(Mips)
124 #define InitializeMipsAsmParser() InitializeAsmParser(Mips)
125 #define InitializeMipsAsmPrinter() InitializeAsmPrinter(Mips)
126 #endif
127 
128 #ifdef WITH_POWERPC
129 #define InitializePowerPCTarget() InitializeTarget(PowerPC)
130 #define InitializePowerPCAsmParser() InitializeAsmParser(PowerPC)
131 #define InitializePowerPCAsmPrinter() InitializeAsmPrinter(PowerPC)
132 #endif
133 
134 #ifdef WITH_RISCV
135 #define InitializeRISCVTarget() InitializeTarget(RISCV)
136 #define InitializeRISCVAsmParser() InitializeAsmParser(RISCV)
137 #define InitializeRISCVAsmPrinter() InitializeAsmPrinter(RISCV)
138 #endif
139 
140 #ifdef WITH_X86
141 #define InitializeX86Target() InitializeTarget(X86)
142 #define InitializeX86AsmParser() InitializeAsmParser(X86)
143 #define InitializeX86AsmPrinter() InitializeAsmPrinter(X86)
144 #endif
145 
146 #ifdef WITH_WEBASSEMBLY
147 #define InitializeWebAssemblyTarget() InitializeTarget(WebAssembly)
148 #define InitializeWebAssemblyAsmParser() InitializeAsmParser(WebAssembly)
149 #define InitializeWebAssemblyAsmPrinter() InitializeAsmPrinter(WebAssembly)
150 #endif
151 
152 namespace {
153 
154 // Get the LLVM linkage corresponding to a Halide linkage type.
llvm_linkage(LinkageType t)155 llvm::GlobalValue::LinkageTypes llvm_linkage(LinkageType t) {
156     // TODO(dsharlet): For some reason, marking internal functions as
157     // private linkage on OSX is causing some of the static tests to
158     // fail. Figure out why so we can remove this.
159     return llvm::GlobalValue::ExternalLinkage;
160 
161     // switch (t) {
162     // case LinkageType::ExternalPlusMetadata:
163     // case LinkageType::External:
164     //     return llvm::GlobalValue::ExternalLinkage;
165     // default:
166     //     return llvm::GlobalValue::PrivateLinkage;
167     // }
168 }
169 
170 // A local helper to make an llvm value type representing
171 // alignment. Can't be declared in a header without introducing a
172 // dependence on the LLVM headers.
173 #if LLVM_VERSION >= 100
make_alignment(int a)174 llvm::Align make_alignment(int a) {
175     return llvm::Align(a);
176 }
177 #else
make_alignment(int a)178 int make_alignment(int a) {
179     return a;
180 }
181 #endif
182 
183 }  // namespace
184 
CodeGen_LLVM(Target t)185 CodeGen_LLVM::CodeGen_LLVM(Target t)
186     : function(nullptr), context(nullptr),
187       builder(nullptr),
188       value(nullptr),
189       very_likely_branch(nullptr),
190       default_fp_math_md(nullptr),
191       strict_fp_math_md(nullptr),
192       target(t),
193       void_t(nullptr), i1_t(nullptr), i8_t(nullptr),
194       i16_t(nullptr), i32_t(nullptr), i64_t(nullptr),
195       f16_t(nullptr), f32_t(nullptr), f64_t(nullptr),
196       halide_buffer_t_type(nullptr),
197       metadata_t_type(nullptr),
198       argument_t_type(nullptr),
199       scalar_value_t_type(nullptr),
200       device_interface_t_type(nullptr),
201       pseudostack_slot_t_type(nullptr),
202 
203       // Vector types. These need an LLVMContext before they can be initialized.
204       i8x8(nullptr),
205       i8x16(nullptr),
206       i8x32(nullptr),
207       i16x4(nullptr),
208       i16x8(nullptr),
209       i16x16(nullptr),
210       i32x2(nullptr),
211       i32x4(nullptr),
212       i32x8(nullptr),
213       i64x2(nullptr),
214       i64x4(nullptr),
215       f32x2(nullptr),
216       f32x4(nullptr),
217       f32x8(nullptr),
218       f64x2(nullptr),
219       f64x4(nullptr),
220 
221       // Wildcards for pattern matching
222       wild_i8x8(Variable::make(Int(8, 8), "*")),
223       wild_i16x4(Variable::make(Int(16, 4), "*")),
224       wild_i32x2(Variable::make(Int(32, 2), "*")),
225 
226       wild_u8x8(Variable::make(UInt(8, 8), "*")),
227       wild_u16x4(Variable::make(UInt(16, 4), "*")),
228       wild_u32x2(Variable::make(UInt(32, 2), "*")),
229 
230       wild_i8x16(Variable::make(Int(8, 16), "*")),
231       wild_i16x8(Variable::make(Int(16, 8), "*")),
232       wild_i32x4(Variable::make(Int(32, 4), "*")),
233       wild_i64x2(Variable::make(Int(64, 2), "*")),
234 
235       wild_u8x16(Variable::make(UInt(8, 16), "*")),
236       wild_u16x8(Variable::make(UInt(16, 8), "*")),
237       wild_u32x4(Variable::make(UInt(32, 4), "*")),
238       wild_u64x2(Variable::make(UInt(64, 2), "*")),
239 
240       wild_i8x32(Variable::make(Int(8, 32), "*")),
241       wild_i16x16(Variable::make(Int(16, 16), "*")),
242       wild_i32x8(Variable::make(Int(32, 8), "*")),
243       wild_i64x4(Variable::make(Int(64, 4), "*")),
244 
245       wild_u8x32(Variable::make(UInt(8, 32), "*")),
246       wild_u16x16(Variable::make(UInt(16, 16), "*")),
247       wild_u32x8(Variable::make(UInt(32, 8), "*")),
248       wild_u64x4(Variable::make(UInt(64, 4), "*")),
249 
250       wild_f32x2(Variable::make(Float(32, 2), "*")),
251 
252       wild_f32x4(Variable::make(Float(32, 4), "*")),
253       wild_f64x2(Variable::make(Float(64, 2), "*")),
254 
255       wild_f32x8(Variable::make(Float(32, 8), "*")),
256       wild_f64x4(Variable::make(Float(64, 4), "*")),
257 
258       wild_u1x_(Variable::make(UInt(1, 0), "*")),
259       wild_i8x_(Variable::make(Int(8, 0), "*")),
260       wild_u8x_(Variable::make(UInt(8, 0), "*")),
261       wild_i16x_(Variable::make(Int(16, 0), "*")),
262       wild_u16x_(Variable::make(UInt(16, 0), "*")),
263       wild_i32x_(Variable::make(Int(32, 0), "*")),
264       wild_u32x_(Variable::make(UInt(32, 0), "*")),
265       wild_i64x_(Variable::make(Int(64, 0), "*")),
266       wild_u64x_(Variable::make(UInt(64, 0), "*")),
267       wild_f32x_(Variable::make(Float(32, 0), "*")),
268       wild_f64x_(Variable::make(Float(64, 0), "*")),
269 
270       // Bounds of types
271       min_i8(Int(8).min()),
272       max_i8(Int(8).max()),
273       max_u8(UInt(8).max()),
274 
275       min_i16(Int(16).min()),
276       max_i16(Int(16).max()),
277       max_u16(UInt(16).max()),
278 
279       min_i32(Int(32).min()),
280       max_i32(Int(32).max()),
281       max_u32(UInt(32).max()),
282 
283       min_i64(Int(64).min()),
284       max_i64(Int(64).max()),
285       max_u64(UInt(64).max()),
286 
287       min_f32(Float(32).min()),
288       max_f32(Float(32).max()),
289 
290       min_f64(Float(64).min()),
291       max_f64(Float(64).max()),
292 
293       inside_atomic_mutex_node(false),
294       emit_atomic_stores(false),
295 
296       destructor_block(nullptr),
297       strict_float(t.has_feature(Target::StrictFloat)) {
298     initialize_llvm();
299 }
300 
301 namespace {
302 
303 template<typename T>
make_codegen(const Target & target,llvm::LLVMContext & context)304 CodeGen_LLVM *make_codegen(const Target &target,
305                            llvm::LLVMContext &context) {
306     CodeGen_LLVM *ret = new T(target);
307     ret->set_context(context);
308     return ret;
309 }
310 
311 }  // namespace
312 
set_context(llvm::LLVMContext & context)313 void CodeGen_LLVM::set_context(llvm::LLVMContext &context) {
314     this->context = &context;
315 }
316 
new_for_target(const Target & target,llvm::LLVMContext & context)317 CodeGen_LLVM *CodeGen_LLVM::new_for_target(const Target &target,
318                                            llvm::LLVMContext &context) {
319     // The awkward mapping from targets to code generators
320     if (target.features_any_of({Target::CUDA,
321                                 Target::OpenCL,
322                                 Target::OpenGL,
323                                 Target::OpenGLCompute,
324                                 Target::Metal,
325                                 Target::D3D12Compute})) {
326 #ifdef WITH_X86
327         if (target.arch == Target::X86) {
328             return make_codegen<CodeGen_GPU_Host<CodeGen_X86>>(target, context);
329         }
330 #endif
331 #if defined(WITH_ARM) || defined(WITH_AARCH64)
332         if (target.arch == Target::ARM) {
333             return make_codegen<CodeGen_GPU_Host<CodeGen_ARM>>(target, context);
334         }
335 #endif
336 #ifdef WITH_MIPS
337         if (target.arch == Target::MIPS) {
338             return make_codegen<CodeGen_GPU_Host<CodeGen_MIPS>>(target, context);
339         }
340 #endif
341 #ifdef WITH_POWERPC
342         if (target.arch == Target::POWERPC) {
343             return make_codegen<CodeGen_GPU_Host<CodeGen_PowerPC>>(target, context);
344         }
345 #endif
346 #ifdef WITH_WEBASSEMBLY
347         if (target.arch == Target::WebAssembly) {
348             return make_codegen<CodeGen_GPU_Host<CodeGen_WebAssembly>>(target, context);
349         }
350 #endif
351 #ifdef WITH_RISCV
352         if (target.arch == Target::RISCV) {
353             return make_codegen<CodeGen_GPU_Host<CodeGen_RISCV>>(target, context);
354         }
355 #endif
356         user_error << "Invalid target architecture for GPU backend: "
357                    << target.to_string() << "\n";
358         return nullptr;
359 
360     } else if (target.arch == Target::X86) {
361         return make_codegen<CodeGen_X86>(target, context);
362     } else if (target.arch == Target::ARM) {
363         return make_codegen<CodeGen_ARM>(target, context);
364     } else if (target.arch == Target::MIPS) {
365         return make_codegen<CodeGen_MIPS>(target, context);
366     } else if (target.arch == Target::POWERPC) {
367         return make_codegen<CodeGen_PowerPC>(target, context);
368     } else if (target.arch == Target::Hexagon) {
369         return make_codegen<CodeGen_Hexagon>(target, context);
370     } else if (target.arch == Target::WebAssembly) {
371         return make_codegen<CodeGen_WebAssembly>(target, context);
372     } else if (target.arch == Target::RISCV) {
373         return make_codegen<CodeGen_RISCV>(target, context);
374     }
375 
376     user_error << "Unknown target architecture: "
377                << target.to_string() << "\n";
378     return nullptr;
379 }
380 
initialize_llvm()381 void CodeGen_LLVM::initialize_llvm() {
382     static std::once_flag init_llvm_once;
383     std::call_once(init_llvm_once, []() {
384         // You can hack in command-line args to llvm with the
385         // environment variable HL_LLVM_ARGS, e.g. HL_LLVM_ARGS="-print-after-all"
386         std::string args = get_env_variable("HL_LLVM_ARGS");
387         if (!args.empty()) {
388             vector<std::string> arg_vec = split_string(args, " ");
389             vector<const char *> c_arg_vec;
390             c_arg_vec.push_back("llc");
391             for (const std::string &s : arg_vec) {
392                 c_arg_vec.push_back(s.c_str());
393             }
394             cl::ParseCommandLineOptions((int)(c_arg_vec.size()), &c_arg_vec[0], "Halide compiler\n");
395         }
396 
397         InitializeNativeTarget();
398         InitializeNativeTargetAsmPrinter();
399         InitializeNativeTargetAsmParser();
400 
401 #define LLVM_TARGET(target) \
402     Initialize##target##Target();
403 #include <llvm/Config/Targets.def>
404 #undef LLVM_TARGET
405 
406 #define LLVM_ASM_PARSER(target) \
407     Initialize##target##AsmParser();
408 #include <llvm/Config/AsmParsers.def>
409 #undef LLVM_ASM_PARSER
410 
411 #define LLVM_ASM_PRINTER(target) \
412     Initialize##target##AsmPrinter();
413 #include <llvm/Config/AsmPrinters.def>
414 #include <utility>
415 #undef LLVM_ASM_PRINTER
416     });
417 }
418 
init_context()419 void CodeGen_LLVM::init_context() {
420     // Ensure our IRBuilder is using the current context.
421     delete builder;
422     builder = new IRBuilder<>(*context);
423 
424     // Branch weights for very likely branches
425     llvm::MDBuilder md_builder(*context);
426     very_likely_branch = md_builder.createBranchWeights(1 << 30, 0);
427     default_fp_math_md = md_builder.createFPMath(0.0);
428     strict_fp_math_md = md_builder.createFPMath(0.0);
429     builder->setDefaultFPMathTag(default_fp_math_md);
430     llvm::FastMathFlags fast_flags;
431     fast_flags.setNoNaNs();
432     fast_flags.setNoInfs();
433     fast_flags.setNoSignedZeros();
434     // Don't use approximate reciprocals for division. It's too inaccurate even for Halide.
435     // fast_flags.setAllowReciprocal();
436     // Theoretically, setAllowReassoc could be setUnsafeAlgebra for earlier versions, but that
437     // turns on all the flags.
438     fast_flags.setAllowReassoc();
439     fast_flags.setAllowContract(true);
440     fast_flags.setApproxFunc();
441     builder->setFastMathFlags(fast_flags);
442 
443     // Define some types
444     void_t = llvm::Type::getVoidTy(*context);
445     i1_t = llvm::Type::getInt1Ty(*context);
446     i8_t = llvm::Type::getInt8Ty(*context);
447     i16_t = llvm::Type::getInt16Ty(*context);
448     i32_t = llvm::Type::getInt32Ty(*context);
449     i64_t = llvm::Type::getInt64Ty(*context);
450     f16_t = llvm::Type::getHalfTy(*context);
451     f32_t = llvm::Type::getFloatTy(*context);
452     f64_t = llvm::Type::getDoubleTy(*context);
453 
454     i8x8 = get_vector_type(i8_t, 8);
455     i8x16 = get_vector_type(i8_t, 16);
456     i8x32 = get_vector_type(i8_t, 32);
457     i16x4 = get_vector_type(i16_t, 4);
458     i16x8 = get_vector_type(i16_t, 8);
459     i16x16 = get_vector_type(i16_t, 16);
460     i32x2 = get_vector_type(i32_t, 2);
461     i32x4 = get_vector_type(i32_t, 4);
462     i32x8 = get_vector_type(i32_t, 8);
463     i64x2 = get_vector_type(i64_t, 2);
464     i64x4 = get_vector_type(i64_t, 4);
465     f32x2 = get_vector_type(f32_t, 2);
466     f32x4 = get_vector_type(f32_t, 4);
467     f32x8 = get_vector_type(f32_t, 8);
468     f64x2 = get_vector_type(f64_t, 2);
469     f64x4 = get_vector_type(f64_t, 4);
470 }
471 
init_module()472 void CodeGen_LLVM::init_module() {
473     init_context();
474 
475     // Start with a module containing the initial module for this target.
476     module = get_initial_module_for_target(target, context);
477 }
478 
add_external_code(const Module & halide_module)479 void CodeGen_LLVM::add_external_code(const Module &halide_module) {
480     for (const ExternalCode &code_blob : halide_module.external_code()) {
481         if (code_blob.is_for_cpu_target(get_target())) {
482             add_bitcode_to_module(context, *module, code_blob.contents(), code_blob.name());
483         }
484     }
485 }
486 
~CodeGen_LLVM()487 CodeGen_LLVM::~CodeGen_LLVM() {
488     delete builder;
489 }
490 
491 bool CodeGen_LLVM::llvm_X86_enabled = false;
492 bool CodeGen_LLVM::llvm_ARM_enabled = false;
493 bool CodeGen_LLVM::llvm_Hexagon_enabled = false;
494 bool CodeGen_LLVM::llvm_AArch64_enabled = false;
495 bool CodeGen_LLVM::llvm_NVPTX_enabled = false;
496 bool CodeGen_LLVM::llvm_Mips_enabled = false;
497 bool CodeGen_LLVM::llvm_PowerPC_enabled = false;
498 bool CodeGen_LLVM::llvm_AMDGPU_enabled = false;
499 bool CodeGen_LLVM::llvm_WebAssembly_enabled = false;
500 bool CodeGen_LLVM::llvm_RISCV_enabled = false;
501 
502 namespace {
503 
504 struct MangledNames {
505     string simple_name;
506     string extern_name;
507     string argv_name;
508     string metadata_name;
509 };
510 
get_mangled_names(const std::string & name,LinkageType linkage,NameMangling mangling,const std::vector<LoweredArgument> & args,const Target & target)511 MangledNames get_mangled_names(const std::string &name,
512                                LinkageType linkage,
513                                NameMangling mangling,
514                                const std::vector<LoweredArgument> &args,
515                                const Target &target) {
516     std::vector<std::string> namespaces;
517     MangledNames names;
518     names.simple_name = extract_namespaces(name, namespaces);
519     names.extern_name = names.simple_name;
520     names.argv_name = names.simple_name + "_argv";
521     names.metadata_name = names.simple_name + "_metadata";
522 
523     if (linkage != LinkageType::Internal &&
524         ((mangling == NameMangling::Default &&
525           target.has_feature(Target::CPlusPlusMangling)) ||
526          mangling == NameMangling::CPlusPlus)) {
527         std::vector<ExternFuncArgument> mangle_args;
528         for (const auto &arg : args) {
529             if (arg.kind == Argument::InputScalar) {
530                 mangle_args.emplace_back(make_zero(arg.type));
531             } else if (arg.kind == Argument::InputBuffer ||
532                        arg.kind == Argument::OutputBuffer) {
533                 mangle_args.emplace_back(Buffer<>());
534             }
535         }
536         names.extern_name = cplusplus_function_mangled_name(names.simple_name, namespaces, type_of<int>(), mangle_args, target);
537         halide_handle_cplusplus_type inner_type(halide_cplusplus_type_name(halide_cplusplus_type_name::Simple, "void"), {}, {},
538                                                 {halide_handle_cplusplus_type::Pointer, halide_handle_cplusplus_type::Pointer});
539         Type void_star_star(Handle(1, &inner_type));
540         names.argv_name = cplusplus_function_mangled_name(names.argv_name, namespaces, type_of<int>(), {ExternFuncArgument(make_zero(void_star_star))}, target);
541         names.metadata_name = cplusplus_function_mangled_name(names.metadata_name, namespaces, type_of<const struct halide_filter_metadata_t *>(), {}, target);
542     }
543     return names;
544 }
545 
get_mangled_names(const LoweredFunc & f,const Target & target)546 MangledNames get_mangled_names(const LoweredFunc &f, const Target &target) {
547     return get_mangled_names(f.name, f.linkage, f.name_mangling, f.args, target);
548 }
549 
550 }  // namespace
551 
signature_to_type(const ExternSignature & signature)552 llvm::FunctionType *CodeGen_LLVM::signature_to_type(const ExternSignature &signature) {
553     internal_assert(void_t != nullptr && halide_buffer_t_type != nullptr);
554     llvm::Type *ret_type =
555         signature.is_void_return() ? void_t : llvm_type_of(upgrade_type_for_argument_passing(signature.ret_type()));
556     std::vector<llvm::Type *> llvm_arg_types;
557     for (const Type &t : signature.arg_types()) {
558         if (t == type_of<struct halide_buffer_t *>()) {
559             llvm_arg_types.push_back(halide_buffer_t_type->getPointerTo());
560         } else {
561             llvm_arg_types.push_back(llvm_type_of(upgrade_type_for_argument_passing(t)));
562         }
563     }
564 
565     return llvm::FunctionType::get(ret_type, llvm_arg_types, false);
566 }
567 
568 /*static*/
compile_trampolines(const Target & target,llvm::LLVMContext & context,const std::string & suffix,const std::vector<std::pair<std::string,ExternSignature>> & externs)569 std::unique_ptr<llvm::Module> CodeGen_LLVM::compile_trampolines(
570     const Target &target,
571     llvm::LLVMContext &context,
572     const std::string &suffix,
573     const std::vector<std::pair<std::string, ExternSignature>> &externs) {
574     std::unique_ptr<CodeGen_LLVM> codegen(new_for_target(target, context));
575     codegen->init_codegen("trampolines" + suffix);
576     for (const std::pair<std::string, ExternSignature> &e : externs) {
577         const std::string &callee_name = e.first;
578         const std::string wrapper_name = callee_name + suffix;
579         llvm::FunctionType *fn_type = codegen->signature_to_type(e.second);
580         // callee might already be present for builtins, e.g. halide_print
581         llvm::Function *callee = codegen->module->getFunction(callee_name);
582         if (!callee) {
583             callee = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage, callee_name, codegen->module.get());
584         }
585         codegen->add_argv_wrapper(callee, wrapper_name, /*result_in_argv*/ true);
586     }
587     return codegen->finish_codegen();
588 }
589 
init_codegen(const std::string & name,bool any_strict_float)590 void CodeGen_LLVM::init_codegen(const std::string &name, bool any_strict_float) {
591     init_module();
592 
593     internal_assert(module && context);
594 
595     debug(1) << "Target triple of initial module: " << module->getTargetTriple() << "\n";
596 
597     module->setModuleIdentifier(name);
598 
599     // Add some target specific info to the module as metadata.
600     module->addModuleFlag(llvm::Module::Warning, "halide_use_soft_float_abi", use_soft_float_abi() ? 1 : 0);
601     module->addModuleFlag(llvm::Module::Warning, "halide_mcpu", MDString::get(*context, mcpu()));
602     module->addModuleFlag(llvm::Module::Warning, "halide_mattrs", MDString::get(*context, mattrs()));
603     module->addModuleFlag(llvm::Module::Warning, "halide_use_pic", use_pic() ? 1 : 0);
604     module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", any_strict_float);
605 
606     // Ensure some types we need are defined
607     halide_buffer_t_type = module->getTypeByName("struct.halide_buffer_t");
608     internal_assert(halide_buffer_t_type) << "Did not find halide_buffer_t in initial module";
609 
610     type_t_type = module->getTypeByName("struct.halide_type_t");
611     internal_assert(type_t_type) << "Did not find halide_type_t in initial module";
612 
613     dimension_t_type = module->getTypeByName("struct.halide_dimension_t");
614     internal_assert(dimension_t_type) << "Did not find halide_dimension_t in initial module";
615 
616     metadata_t_type = module->getTypeByName("struct.halide_filter_metadata_t");
617     internal_assert(metadata_t_type) << "Did not find halide_filter_metadata_t in initial module";
618 
619     argument_t_type = module->getTypeByName("struct.halide_filter_argument_t");
620     internal_assert(argument_t_type) << "Did not find halide_filter_argument_t in initial module";
621 
622     scalar_value_t_type = module->getTypeByName("struct.halide_scalar_value_t");
623     internal_assert(scalar_value_t_type) << "Did not find halide_scalar_value_t in initial module";
624 
625     device_interface_t_type = module->getTypeByName("struct.halide_device_interface_t");
626     internal_assert(device_interface_t_type) << "Did not find halide_device_interface_t in initial module";
627 
628     pseudostack_slot_t_type = module->getTypeByName("struct.halide_pseudostack_slot_t");
629     internal_assert(pseudostack_slot_t_type) << "Did not find halide_pseudostack_slot_t in initial module";
630 
631     semaphore_t_type = module->getTypeByName("struct.halide_semaphore_t");
632     internal_assert(semaphore_t_type) << "Did not find halide_semaphore_t in initial module";
633 
634     semaphore_acquire_t_type = module->getTypeByName("struct.halide_semaphore_acquire_t");
635     internal_assert(semaphore_acquire_t_type) << "Did not find halide_semaphore_acquire_t in initial module";
636 
637     parallel_task_t_type = module->getTypeByName("struct.halide_parallel_task_t");
638     internal_assert(parallel_task_t_type) << "Did not find halide_parallel_task_t in initial module";
639 }
640 
compile(const Module & input)641 std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
642     init_codegen(input.name(), input.any_strict_float());
643 
644     internal_assert(module && context && builder)
645         << "The CodeGen_LLVM subclass should have made an initial module before calling CodeGen_LLVM::compile\n";
646 
647     add_external_code(input);
648 
649     // Generate the code for this module.
650     debug(1) << "Generating llvm bitcode...\n";
651     for (const auto &b : input.buffers()) {
652         compile_buffer(b);
653     }
654     for (const auto &f : input.functions()) {
655         const auto names = get_mangled_names(f, get_target());
656 
657         compile_func(f, names.simple_name, names.extern_name);
658 
659         // If the Func is externally visible, also create the argv wrapper and metadata.
660         // (useful for calling from JIT and other machine interfaces).
661         if (f.linkage == LinkageType::ExternalPlusMetadata) {
662             llvm::Function *wrapper = add_argv_wrapper(function, names.argv_name);
663             llvm::Function *metadata_getter = embed_metadata_getter(names.metadata_name,
664                                                                     names.simple_name, f.args, input.get_metadata_name_map());
665 
666             if (target.has_feature(Target::Matlab)) {
667                 define_matlab_wrapper(module.get(), wrapper, metadata_getter);
668             }
669         }
670     }
671 
672     debug(2) << module.get() << "\n";
673 
674     return finish_codegen();
675 }
676 
finish_codegen()677 std::unique_ptr<llvm::Module> CodeGen_LLVM::finish_codegen() {
678     // Verify the module is ok
679     internal_assert(!verifyModule(*module, &llvm::errs()));
680     debug(2) << "Done generating llvm bitcode\n";
681 
682     // Optimize
683     CodeGen_LLVM::optimize_module();
684 
685     if (target.has_feature(Target::EmbedBitcode)) {
686         std::string halide_command = "halide target=" + target.to_string();
687         embed_bitcode(module.get(), halide_command);
688     }
689 
690     // Disown the module and return it.
691     return std::move(module);
692 }
693 
begin_func(LinkageType linkage,const std::string & name,const std::string & extern_name,const std::vector<LoweredArgument> & args)694 void CodeGen_LLVM::begin_func(LinkageType linkage, const std::string &name,
695                               const std::string &extern_name, const std::vector<LoweredArgument> &args) {
696     current_function_args = args;
697 
698     // Deduce the types of the arguments to our function
699     vector<llvm::Type *> arg_types(args.size());
700     for (size_t i = 0; i < args.size(); i++) {
701         if (args[i].is_buffer()) {
702             arg_types[i] = halide_buffer_t_type->getPointerTo();
703         } else {
704             arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(args[i].type));
705         }
706     }
707     FunctionType *func_t = FunctionType::get(i32_t, arg_types, false);
708 
709     // Make our function. There may already be a declaration of it.
710     function = module->getFunction(extern_name);
711     if (!function) {
712         function = llvm::Function::Create(func_t, llvm_linkage(linkage), extern_name, module.get());
713     } else {
714         user_assert(function->isDeclaration())
715             << "Another function with the name " << extern_name
716             << " already exists in the same module\n";
717         if (func_t != function->getFunctionType()) {
718             std::cerr << "Desired function type for " << extern_name << ":\n";
719             func_t->print(dbgs(), true);
720             std::cerr << "Declared function type of " << extern_name << ":\n";
721             function->getFunctionType()->print(dbgs(), true);
722             user_error << "Cannot create a function with a declaration of mismatched type.\n";
723         }
724     }
725     set_function_attributes_for_target(function, target);
726 
727     // Mark the buffer args as no alias
728     for (size_t i = 0; i < args.size(); i++) {
729         if (args[i].is_buffer()) {
730             function->addParamAttr(i, Attribute::NoAlias);
731         }
732     }
733 
734     debug(1) << "Generating llvm bitcode prolog for function " << name << "...\n";
735 
736     // Null out the destructor block.
737     destructor_block = nullptr;
738 
739     // Make the initial basic block
740     BasicBlock *block = BasicBlock::Create(*context, "entry", function);
741     builder->SetInsertPoint(block);
742 
743     // Put the arguments in the symbol table
744     {
745         size_t i = 0;
746         for (auto &arg : function->args()) {
747             if (args[i].is_buffer()) {
748                 // Track this buffer name so that loads and stores from it
749                 // don't try to be too aligned.
750                 external_buffer.insert(args[i].name);
751                 sym_push(args[i].name + ".buffer", &arg);
752             } else {
753                 Type passed_type = upgrade_type_for_argument_passing(args[i].type);
754                 if (args[i].type != passed_type) {
755                     llvm::Value *a = builder->CreateBitCast(&arg, llvm_type_of(args[i].type));
756                     sym_push(args[i].name, a);
757                 } else {
758                     sym_push(args[i].name, &arg);
759                 }
760             }
761 
762             i++;
763         }
764     }
765 }
766 
end_func(const std::vector<LoweredArgument> & args)767 void CodeGen_LLVM::end_func(const std::vector<LoweredArgument> &args) {
768     return_with_error_code(ConstantInt::get(i32_t, 0));
769 
770     // Remove the arguments from the symbol table
771     for (size_t i = 0; i < args.size(); i++) {
772         if (args[i].is_buffer()) {
773             sym_pop(args[i].name + ".buffer");
774         } else {
775             sym_pop(args[i].name);
776         }
777     }
778 
779     internal_assert(!verifyFunction(*function, &llvm::errs()));
780 
781     current_function_args.clear();
782 }
783 
compile_func(const LoweredFunc & f,const std::string & simple_name,const std::string & extern_name)784 void CodeGen_LLVM::compile_func(const LoweredFunc &f, const std::string &simple_name,
785                                 const std::string &extern_name) {
786     // Generate the function declaration and argument unpacking code.
787     begin_func(f.linkage, simple_name, extern_name, f.args);
788 
789     // If building with MSAN, ensure that calls to halide_msan_annotate_buffer_is_initialized()
790     // happen for every output buffer if the function succeeds.
791     if (f.linkage != LinkageType::Internal &&
792         target.has_feature(Target::MSAN)) {
793         llvm::Function *annotate_buffer_fn =
794             module->getFunction("halide_msan_annotate_buffer_is_initialized_as_destructor");
795         internal_assert(annotate_buffer_fn)
796             << "Could not find halide_msan_annotate_buffer_is_initialized_as_destructor in module\n";
797         annotate_buffer_fn->addParamAttr(0, Attribute::NoAlias);
798         for (const auto &arg : f.args) {
799             if (arg.kind == Argument::OutputBuffer) {
800                 register_destructor(annotate_buffer_fn, sym_get(arg.name + ".buffer"), OnSuccess);
801             }
802         }
803     }
804 
805     // Generate the function body.
806     debug(1) << "Generating llvm bitcode for function " << f.name << "...\n";
807     f.body.accept(this);
808 
809     // Clean up and return.
810     end_func(f.args);
811 }
812 
813 // Given a range of iterators of constant ints, get a corresponding vector of llvm::Constant.
814 template<typename It>
get_constants(llvm::Type * t,It begin,It end)815 std::vector<llvm::Constant *> get_constants(llvm::Type *t, It begin, It end) {
816     std::vector<llvm::Constant *> ret;
817     for (It i = begin; i != end; i++) {
818         ret.push_back(ConstantInt::get(t, *i));
819     }
820     return ret;
821 }
822 
get_destructor_block()823 BasicBlock *CodeGen_LLVM::get_destructor_block() {
824     if (!destructor_block) {
825         // Create it if it doesn't exist.
826         IRBuilderBase::InsertPoint here = builder->saveIP();
827         destructor_block = BasicBlock::Create(*context, "destructor_block", function);
828         builder->SetInsertPoint(destructor_block);
829         // The first instruction in the destructor block is a phi node
830         // that collects the error code.
831         PHINode *error_code = builder->CreatePHI(i32_t, 0);
832 
833         // Calls to destructors will get inserted here.
834 
835         // The last instruction is the return op that returns it.
836         builder->CreateRet(error_code);
837 
838         // Jump back to where we were.
839         builder->restoreIP(here);
840     }
841     internal_assert(destructor_block->getParent() == function);
842     return destructor_block;
843 }
844 
register_destructor(llvm::Function * destructor_fn,Value * obj,DestructorType when)845 Value *CodeGen_LLVM::register_destructor(llvm::Function *destructor_fn, Value *obj, DestructorType when) {
846 
847     // Create a null-initialized stack slot to track this object
848     llvm::Type *void_ptr = i8_t->getPointerTo();
849     llvm::Value *stack_slot = create_alloca_at_entry(void_ptr, 1, true);
850 
851     // Cast the object to llvm's representation of void *
852     obj = builder->CreatePointerCast(obj, void_ptr);
853 
854     // Put it in the stack slot
855     builder->CreateStore(obj, stack_slot);
856 
857     // Passing the constant null as the object means the destructor
858     // will never get called.
859     {
860         llvm::Constant *c = dyn_cast<llvm::Constant>(obj);
861         if (c && c->isNullValue()) {
862             internal_error << "Destructors must take a non-null object\n";
863         }
864     }
865 
866     // Switch to the destructor block, and add code that cleans up
867     // this object if the contents of the stack slot is not nullptr.
868     IRBuilderBase::InsertPoint here = builder->saveIP();
869     BasicBlock *dtors = get_destructor_block();
870 
871     builder->SetInsertPoint(dtors->getFirstNonPHI());
872 
873     PHINode *error_code = dyn_cast<PHINode>(dtors->begin());
874     internal_assert(error_code) << "The destructor block is supposed to start with a phi node\n";
875 
876     llvm::Value *should_call = nullptr;
877     switch (when) {
878     case Always:
879         should_call = ConstantInt::get(i1_t, 1);
880         break;
881     case OnError:
882         should_call = builder->CreateIsNotNull(error_code);
883         break;
884     case OnSuccess:
885         should_call = builder->CreateIsNull(error_code);
886         break;
887     }
888     llvm::Function *call_destructor = module->getFunction("call_destructor");
889     internal_assert(call_destructor);
890     internal_assert(destructor_fn);
891     internal_assert(should_call);
892     Value *args[] = {get_user_context(), destructor_fn, stack_slot, should_call};
893     builder->CreateCall(call_destructor, args);
894 
895     // Switch back to the original location
896     builder->restoreIP(here);
897 
898     // Return the stack slot so that it's possible to cleanup the object early.
899     return stack_slot;
900 }
901 
trigger_destructor(llvm::Function * destructor_fn,Value * stack_slot)902 void CodeGen_LLVM::trigger_destructor(llvm::Function *destructor_fn, Value *stack_slot) {
903     llvm::Function *call_destructor = module->getFunction("call_destructor");
904     internal_assert(call_destructor);
905     internal_assert(destructor_fn);
906     stack_slot = builder->CreatePointerCast(stack_slot, i8_t->getPointerTo()->getPointerTo());
907     Value *should_call = ConstantInt::get(i1_t, 1);
908     Value *args[] = {get_user_context(), destructor_fn, stack_slot, should_call};
909     builder->CreateCall(call_destructor, args);
910 }
911 
compile_buffer(const Buffer<> & buf)912 void CodeGen_LLVM::compile_buffer(const Buffer<> &buf) {
913     // Embed the buffer declaration as a global.
914     internal_assert(buf.defined());
915 
916     user_assert(buf.data())
917         << "Can't embed buffer " << buf.name() << " because it has a null host pointer.\n";
918     user_assert(!buf.device_dirty())
919         << "Can't embed Image \"" << buf.name() << "\""
920         << " because it has a dirty device pointer\n";
921 
922     Constant *type_fields[] = {
923         ConstantInt::get(i8_t, buf.type().code()),
924         ConstantInt::get(i8_t, buf.type().bits()),
925         ConstantInt::get(i16_t, buf.type().lanes())};
926 
927     Constant *shape = nullptr;
928     if (buf.dimensions()) {
929         size_t shape_size = buf.dimensions() * sizeof(halide_dimension_t);
930         vector<char> shape_blob((char *)buf.raw_buffer()->dim, (char *)buf.raw_buffer()->dim + shape_size);
931         shape = create_binary_blob(shape_blob, buf.name() + ".shape");
932         shape = ConstantExpr::getPointerCast(shape, dimension_t_type->getPointerTo());
933     } else {
934         shape = ConstantPointerNull::get(dimension_t_type->getPointerTo());
935     }
936 
937     // For now, we assume buffers that aren't scalar are constant,
938     // while scalars can be mutated. This accommodates all our existing
939     // use cases, which is that all buffers are constant, except those
940     // used to store stateful module information in offloading runtimes.
941     bool constant = buf.dimensions() != 0;
942 
943     vector<char> data_blob((const char *)buf.data(), (const char *)buf.data() + buf.size_in_bytes());
944 
945     Constant *fields[] = {
946         ConstantInt::get(i64_t, 0),                                         // device
947         ConstantPointerNull::get(device_interface_t_type->getPointerTo()),  // device_interface
948         create_binary_blob(data_blob, buf.name() + ".data", constant),      // host
949         ConstantInt::get(i64_t, halide_buffer_flag_host_dirty),             // flags
950         ConstantStruct::get(type_t_type, type_fields),                      // type
951         ConstantInt::get(i32_t, buf.dimensions()),                          // dimensions
952         shape,                                                              // dim
953         ConstantPointerNull::get(i8_t->getPointerTo()),                     // padding
954     };
955     Constant *buffer_struct = ConstantStruct::get(halide_buffer_t_type, fields);
956 
957     // Embed the halide_buffer_t and make it point to the data array.
958     GlobalVariable *global = new GlobalVariable(*module, halide_buffer_t_type,
959                                                 false, GlobalValue::PrivateLinkage,
960                                                 0, buf.name() + ".buffer");
961     global->setInitializer(buffer_struct);
962 
963     // Finally, dump it in the symbol table
964     Constant *zero[] = {ConstantInt::get(i32_t, 0)};
965     Constant *global_ptr = ConstantExpr::getInBoundsGetElementPtr(halide_buffer_t_type, global, zero);
966     sym_push(buf.name() + ".buffer", global_ptr);
967 }
968 
embed_constant_scalar_value_t(const Expr & e)969 Constant *CodeGen_LLVM::embed_constant_scalar_value_t(const Expr &e) {
970     if (!e.defined()) {
971         return Constant::getNullValue(scalar_value_t_type->getPointerTo());
972     }
973 
974     internal_assert(!e.type().is_handle()) << "Should never see Handle types here.";
975 
976     llvm::Value *val = codegen(e);
977     llvm::Constant *constant = dyn_cast<llvm::Constant>(val);
978     internal_assert(constant);
979 
980     // Verify that the size of the LLVM value is the size we expected.
981     internal_assert((uint64_t)constant->getType()->getPrimitiveSizeInBits() == (uint64_t)e.type().bits());
982 
983     // It's important that we allocate a full scalar_value_t_type here,
984     // even if the type of the value is smaller; downstream consumers should
985     // be able to correctly load an entire scalar_value_t_type regardless of its
986     // type, and if we emit just (say) a uint8 value here, the pointer may be
987     // misaligned and/or the storage after may be unmapped. LLVM doesn't support
988     // unions directly, so we'll fake it by making a constant array of the elements
989     // we need, setting the first to the constant we want, and setting the rest
990     // to all-zeros. (This happens to work because sizeof(halide_scalar_value_t) is evenly
991     // divisible by sizeof(any-union-field.)
992 
993     const size_t value_size = e.type().bytes();
994     internal_assert(value_size > 0 && value_size <= sizeof(halide_scalar_value_t));
995 
996     const size_t array_size = sizeof(halide_scalar_value_t) / value_size;
997     internal_assert(array_size * value_size == sizeof(halide_scalar_value_t));
998 
999     vector<Constant *> array_entries(array_size, Constant::getNullValue(constant->getType()));
1000     array_entries[0] = constant;
1001 
1002     llvm::ArrayType *array_type = ArrayType::get(constant->getType(), array_size);
1003     GlobalVariable *storage = new GlobalVariable(
1004         *module,
1005         array_type,
1006         /*isConstant*/ true,
1007         GlobalValue::PrivateLinkage,
1008         ConstantArray::get(array_type, array_entries));
1009 
1010     // Ensure that the storage is aligned for halide_scalar_value_t
1011     storage->setAlignment(make_alignment((int)sizeof(halide_scalar_value_t)));
1012 
1013     Constant *zero[] = {ConstantInt::get(i32_t, 0)};
1014     return ConstantExpr::getBitCast(
1015         ConstantExpr::getInBoundsGetElementPtr(array_type, storage, zero),
1016         scalar_value_t_type->getPointerTo());
1017 }
1018 
embed_constant_expr(Expr e,llvm::Type * t)1019 Constant *CodeGen_LLVM::embed_constant_expr(Expr e, llvm::Type *t) {
1020     internal_assert(t != scalar_value_t_type);
1021 
1022     if (!e.defined()) {
1023         return Constant::getNullValue(t->getPointerTo());
1024     }
1025 
1026     internal_assert(!e.type().is_handle()) << "Should never see Handle types here.";
1027     if (!is_const(e)) {
1028         e = simplify(e);
1029         internal_assert(is_const(e)) << "Should only see constant values for estimates.";
1030     }
1031 
1032     llvm::Value *val = codegen(e);
1033     llvm::Constant *constant = dyn_cast<llvm::Constant>(val);
1034     internal_assert(constant);
1035 
1036     GlobalVariable *storage = new GlobalVariable(
1037         *module,
1038         constant->getType(),
1039         /*isConstant*/ true,
1040         GlobalValue::PrivateLinkage,
1041         constant);
1042 
1043     Constant *zero[] = {ConstantInt::get(i32_t, 0)};
1044     return ConstantExpr::getBitCast(
1045         ConstantExpr::getInBoundsGetElementPtr(constant->getType(), storage, zero),
1046         t->getPointerTo());
1047 }
1048 
1049 // Make a wrapper to call the function with an array of pointer
1050 // args. This is easier for the JIT to call than a function with an
1051 // unknown (at compile time) argument list. If result_in_argv is false,
1052 // the internal function result is returned as the wrapper function
1053 // result; if result_in_argv is true, the internal function result
1054 // is stored as the last item in the argv list (which must be one
1055 // longer than the number of arguments), and the wrapper's actual
1056 // return type is always 'void'.
add_argv_wrapper(llvm::Function * fn,const std::string & name,bool result_in_argv)1057 llvm::Function *CodeGen_LLVM::add_argv_wrapper(llvm::Function *fn,
1058                                                const std::string &name,
1059                                                bool result_in_argv) {
1060     llvm::Type *wrapper_result_type = result_in_argv ? void_t : i32_t;
1061     llvm::Type *wrapper_args_t[] = {i8_t->getPointerTo()->getPointerTo()};
1062     llvm::FunctionType *wrapper_func_t = llvm::FunctionType::get(wrapper_result_type, wrapper_args_t, false);
1063     llvm::Function *wrapper_func = llvm::Function::Create(wrapper_func_t, llvm::GlobalValue::ExternalLinkage, name, module.get());
1064     llvm::BasicBlock *wrapper_block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper_func);
1065     builder->SetInsertPoint(wrapper_block);
1066 
1067     llvm::Value *arg_array = iterator_to_pointer(wrapper_func->arg_begin());
1068     std::vector<llvm::Value *> wrapper_args;
1069     for (llvm::Function::arg_iterator i = fn->arg_begin(); i != fn->arg_end(); i++) {
1070         // Get the address of the nth argument
1071         llvm::Value *ptr = builder->CreateConstGEP1_32(arg_array, wrapper_args.size());
1072         ptr = builder->CreateLoad(ptr);
1073         if (i->getType() == halide_buffer_t_type->getPointerTo()) {
1074             // Cast the argument to a halide_buffer_t *
1075             wrapper_args.push_back(builder->CreatePointerCast(ptr, halide_buffer_t_type->getPointerTo()));
1076         } else {
1077             // Cast to the appropriate type and load
1078             ptr = builder->CreatePointerCast(ptr, i->getType()->getPointerTo());
1079             wrapper_args.push_back(builder->CreateLoad(ptr));
1080         }
1081     }
1082     debug(4) << "Creating call from wrapper to actual function\n";
1083     llvm::CallInst *result = builder->CreateCall(fn, wrapper_args);
1084     // This call should never inline
1085     result->setIsNoInline();
1086 
1087     if (result_in_argv) {
1088         llvm::Value *result_in_argv_ptr = builder->CreateConstGEP1_32(arg_array, wrapper_args.size());
1089         if (fn->getReturnType() != void_t) {
1090             result_in_argv_ptr = builder->CreateLoad(result_in_argv_ptr);
1091             // Cast to the appropriate type and store
1092             result_in_argv_ptr = builder->CreatePointerCast(result_in_argv_ptr, fn->getReturnType()->getPointerTo());
1093             builder->CreateStore(result, result_in_argv_ptr);
1094         }
1095         builder->CreateRetVoid();
1096     } else {
1097         // We could probably support other types as return values,
1098         // but int32 results are all that have actually been tested.
1099         internal_assert(fn->getReturnType() == i32_t);
1100         builder->CreateRet(result);
1101     }
1102     internal_assert(!verifyFunction(*wrapper_func, &llvm::errs()));
1103     return wrapper_func;
1104 }
1105 
embed_metadata_getter(const std::string & metadata_name,const std::string & function_name,const std::vector<LoweredArgument> & args,const std::map<std::string,std::string> & metadata_name_map)1106 llvm::Function *CodeGen_LLVM::embed_metadata_getter(const std::string &metadata_name,
1107                                                     const std::string &function_name, const std::vector<LoweredArgument> &args,
1108                                                     const std::map<std::string, std::string> &metadata_name_map) {
1109     Constant *zero = ConstantInt::get(i32_t, 0);
1110 
1111     const int num_args = (int)args.size();
1112 
1113     auto map_string = [&metadata_name_map](const std::string &from) -> std::string {
1114         auto it = metadata_name_map.find(from);
1115         return it == metadata_name_map.end() ? from : it->second;
1116     };
1117 
1118     vector<Constant *> arguments_array_entries;
1119     for (int arg = 0; arg < num_args; ++arg) {
1120 
1121         StructType *type_t_type = module->getTypeByName("struct.halide_type_t");
1122         internal_assert(type_t_type) << "Did not find halide_type_t in module.\n";
1123 
1124         Constant *type_fields[] = {
1125             ConstantInt::get(i8_t, args[arg].type.code()),
1126             ConstantInt::get(i8_t, args[arg].type.bits()),
1127             ConstantInt::get(i16_t, 1)};
1128         Constant *type = ConstantStruct::get(type_t_type, type_fields);
1129 
1130         auto argument_estimates = args[arg].argument_estimates;
1131         if (args[arg].type.is_handle()) {
1132             // Handle values are always emitted into metadata as "undefined", regardless of
1133             // what sort of Expr is provided.
1134             argument_estimates = ArgumentEstimates{};
1135         }
1136 
1137         Constant *buffer_estimates_array_ptr;
1138         if (args[arg].is_buffer() && !argument_estimates.buffer_estimates.empty()) {
1139             internal_assert((int)argument_estimates.buffer_estimates.size() == args[arg].dimensions);
1140             vector<Constant *> buffer_estimates_array_entries;
1141             for (const auto &be : argument_estimates.buffer_estimates) {
1142                 Expr min = be.min;
1143                 if (min.defined()) min = cast<int64_t>(min);
1144                 Expr extent = be.extent;
1145                 if (extent.defined()) extent = cast<int64_t>(extent);
1146                 buffer_estimates_array_entries.push_back(embed_constant_expr(min, i64_t));
1147                 buffer_estimates_array_entries.push_back(embed_constant_expr(extent, i64_t));
1148             }
1149 
1150             llvm::ArrayType *buffer_estimates_array = ArrayType::get(i64_t->getPointerTo(), buffer_estimates_array_entries.size());
1151             GlobalVariable *buffer_estimates_array_storage = new GlobalVariable(
1152                 *module,
1153                 buffer_estimates_array,
1154                 /*isConstant*/ true,
1155                 GlobalValue::PrivateLinkage,
1156                 ConstantArray::get(buffer_estimates_array, buffer_estimates_array_entries));
1157 
1158             Value *zeros[] = {zero, zero};
1159             buffer_estimates_array_ptr = ConstantExpr::getInBoundsGetElementPtr(buffer_estimates_array, buffer_estimates_array_storage, zeros);
1160         } else {
1161             buffer_estimates_array_ptr = Constant::getNullValue(i64_t->getPointerTo()->getPointerTo());
1162         }
1163 
1164         Constant *argument_fields[] = {
1165             create_string_constant(map_string(args[arg].name)),
1166             ConstantInt::get(i32_t, args[arg].kind),
1167             ConstantInt::get(i32_t, args[arg].dimensions),
1168             type,
1169             embed_constant_scalar_value_t(argument_estimates.scalar_def),
1170             embed_constant_scalar_value_t(argument_estimates.scalar_min),
1171             embed_constant_scalar_value_t(argument_estimates.scalar_max),
1172             embed_constant_scalar_value_t(argument_estimates.scalar_estimate),
1173             buffer_estimates_array_ptr};
1174         arguments_array_entries.push_back(ConstantStruct::get(argument_t_type, argument_fields));
1175     }
1176     llvm::ArrayType *arguments_array = ArrayType::get(argument_t_type, num_args);
1177     GlobalVariable *arguments_array_storage = new GlobalVariable(
1178         *module,
1179         arguments_array,
1180         /*isConstant*/ true,
1181         GlobalValue::PrivateLinkage,
1182         ConstantArray::get(arguments_array, arguments_array_entries));
1183 
1184     Constant *version = ConstantInt::get(i32_t, halide_filter_metadata_t::VERSION);
1185 
1186     Value *zeros[] = {zero, zero};
1187     Constant *metadata_fields[] = {
1188         /* version */ version,
1189         /* num_arguments */ ConstantInt::get(i32_t, num_args),
1190         /* arguments */ ConstantExpr::getInBoundsGetElementPtr(arguments_array, arguments_array_storage, zeros),
1191         /* target */ create_string_constant(map_string(target.to_string())),
1192         /* name */ create_string_constant(map_string(function_name))};
1193 
1194     GlobalVariable *metadata_storage = new GlobalVariable(
1195         *module,
1196         metadata_t_type,
1197         /*isConstant*/ true,
1198         GlobalValue::PrivateLinkage,
1199         ConstantStruct::get(metadata_t_type, metadata_fields),
1200         metadata_name + "_storage");
1201 
1202     llvm::FunctionType *func_t = llvm::FunctionType::get(metadata_t_type->getPointerTo(), false);
1203     llvm::Function *metadata_getter = llvm::Function::Create(func_t, llvm::GlobalValue::ExternalLinkage, metadata_name, module.get());
1204     llvm::BasicBlock *block = llvm::BasicBlock::Create(module->getContext(), "entry", metadata_getter);
1205     builder->SetInsertPoint(block);
1206     builder->CreateRet(metadata_storage);
1207     internal_assert(!verifyFunction(*metadata_getter, &llvm::errs()));
1208 
1209     return metadata_getter;
1210 }
1211 
llvm_type_of(const Type & t) const1212 llvm::Type *CodeGen_LLVM::llvm_type_of(const Type &t) const {
1213     return Internal::llvm_type_of(context, t);
1214 }
1215 
optimize_module()1216 void CodeGen_LLVM::optimize_module() {
1217     debug(3) << "Optimizing module\n";
1218 
1219     auto time_start = std::chrono::high_resolution_clock::now();
1220 
1221     if (debug::debug_level() >= 3) {
1222         module->print(dbgs(), nullptr, false, true);
1223     }
1224 
1225     std::unique_ptr<TargetMachine> tm = make_target_machine(*module);
1226 
1227     // At present, we default to *enabling* LLVM loop optimization,
1228     // unless DisableLLVMLoopOpt is set; we're going to flip this to defaulting
1229     // to *not* enabling these optimizations (and removing the DisableLLVMLoopOpt feature).
1230     // See https://github.com/halide/Halide/issues/4113 for more info.
1231     // (Note that setting EnableLLVMLoopOpt always enables loop opt, regardless
1232     // of the setting of DisableLLVMLoopOpt.)
1233     const bool do_loop_opt = !get_target().has_feature(Target::DisableLLVMLoopOpt) ||
1234                              get_target().has_feature(Target::EnableLLVMLoopOpt);
1235 
1236     PipelineTuningOptions pto;
1237     pto.LoopInterleaving = do_loop_opt;
1238     pto.LoopVectorization = do_loop_opt;
1239     pto.SLPVectorization = true;  // Note: SLP vectorization has no analogue in the Halide scheduling model
1240     pto.LoopUnrolling = do_loop_opt;
1241     // Clear SCEV info for all loops. Certain Halide applications spend a very
1242     // long time compiling in forgetLoop, so we prefer to forget everything
1243     // and rebuild SCEV (aka "Scalar Evolution") from scratch.
1244     // A sample compile-time improvement at the time of this change was
1245     // 21.04 -> 14.78 using a current ToT release build. (See also https://reviews.llvm.org/rL358304)
1246     pto.ForgetAllSCEVInLoopUnroll = true;
1247 
1248     llvm::PassBuilder pb(tm.get(), pto);
1249 
1250     bool debug_pass_manager = false;
1251     // These analysis managers have to be declared in this order.
1252     llvm::LoopAnalysisManager lam(debug_pass_manager);
1253     llvm::FunctionAnalysisManager fam(debug_pass_manager);
1254     llvm::CGSCCAnalysisManager cgam(debug_pass_manager);
1255     llvm::ModuleAnalysisManager mam(debug_pass_manager);
1256 
1257     llvm::AAManager aa = pb.buildDefaultAAPipeline();
1258     fam.registerPass([&] { return std::move(aa); });
1259 
1260     // Register all the basic analyses with the managers.
1261     pb.registerModuleAnalyses(mam);
1262     pb.registerCGSCCAnalyses(cgam);
1263     pb.registerFunctionAnalyses(fam);
1264     pb.registerLoopAnalyses(lam);
1265     pb.crossRegisterProxies(lam, fam, cgam, mam);
1266     ModulePassManager mpm(debug_pass_manager);
1267 
1268     PassBuilder::OptimizationLevel level = PassBuilder::OptimizationLevel::O3;
1269 
1270     if (get_target().has_feature(Target::ASAN)) {
1271         pb.registerPipelineStartEPCallback([&](ModulePassManager &mpm) {
1272             mpm.addPass(
1273                 RequireAnalysisPass<ASanGlobalsMetadataAnalysis, llvm::Module>());
1274         });
1275 #if LLVM_VERSION >= 110
1276         pb.registerOptimizerLastEPCallback(
1277             [](ModulePassManager &mpm, PassBuilder::OptimizationLevel level) {
1278                 constexpr bool compile_kernel = false;
1279                 constexpr bool recover = false;
1280                 constexpr bool use_after_scope = true;
1281                 mpm.addPass(createModuleToFunctionPassAdaptor(AddressSanitizerPass(
1282                     compile_kernel, recover, use_after_scope)));
1283             });
1284 #else
1285         pb.registerOptimizerLastEPCallback(
1286             [](FunctionPassManager &fpm, PassBuilder::OptimizationLevel level) {
1287                 constexpr bool compile_kernel = false;
1288                 constexpr bool recover = false;
1289                 constexpr bool use_after_scope = true;
1290                 fpm.addPass(AddressSanitizerPass(
1291                     compile_kernel, recover, use_after_scope));
1292             });
1293 #endif
1294         pb.registerPipelineStartEPCallback(
1295             [](ModulePassManager &mpm) {
1296                 constexpr bool compile_kernel = false;
1297                 constexpr bool recover = false;
1298                 constexpr bool module_use_after_scope = false;
1299                 constexpr bool use_odr_indicator = true;
1300                 mpm.addPass(ModuleAddressSanitizerPass(
1301                     compile_kernel, recover, module_use_after_scope,
1302                     use_odr_indicator));
1303             });
1304     }
1305 
1306     if (get_target().has_feature(Target::TSAN)) {
1307 #if LLVM_VERSION >= 110
1308         pb.registerOptimizerLastEPCallback(
1309             [](ModulePassManager &mpm, PassBuilder::OptimizationLevel level) {
1310                 mpm.addPass(
1311                     createModuleToFunctionPassAdaptor(ThreadSanitizerPass()));
1312             });
1313 #else
1314         pb.registerOptimizerLastEPCallback(
1315             [](FunctionPassManager &fpm, PassBuilder::OptimizationLevel level) {
1316                 fpm.addPass(ThreadSanitizerPass());
1317             });
1318 #endif
1319     }
1320 
1321     for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
1322         if (get_target().has_feature(Target::ASAN)) {
1323             i->addFnAttr(Attribute::SanitizeAddress);
1324         }
1325         if (get_target().has_feature(Target::TSAN)) {
1326             // Do not annotate any of Halide's low-level synchronization code as it has
1327             // tsan interface calls to mark its behavior and is much faster if
1328             // it is not analyzed instruction by instruction.
1329             if (!(i->getName().startswith("_ZN6Halide7Runtime8Internal15Synchronization") ||
1330                   // TODO: this is a benign data race that re-initializes the detected features;
1331                   // we should really fix it properly inside the implementation, rather than disabling
1332                   // it here as a band-aid.
1333                   i->getName().startswith("halide_default_can_use_target_features") ||
1334                   i->getName().startswith("halide_mutex_") ||
1335                   i->getName().startswith("halide_cond_"))) {
1336                 i->addFnAttr(Attribute::SanitizeThread);
1337             }
1338         }
1339     }
1340 
1341     mpm = pb.buildPerModuleDefaultPipeline(level, debug_pass_manager);
1342     mpm.run(*module, mam);
1343 
1344     if (llvm::verifyModule(*module, &errs()))
1345         report_fatal_error("Transformation resulted in an invalid module\n");
1346 
1347     debug(3) << "After LLVM optimizations:\n";
1348     if (debug::debug_level() >= 2) {
1349         module->print(dbgs(), nullptr, false, true);
1350     }
1351 
1352     auto *logger = get_compiler_logger();
1353     if (logger) {
1354         auto time_end = std::chrono::high_resolution_clock::now();
1355         std::chrono::duration<double> diff = time_end - time_start;
1356         logger->record_compilation_time(CompilerLogger::Phase::LLVM, diff.count());
1357     }
1358 }
1359 
1360 void CodeGen_LLVM::sym_push(const string &name, llvm::Value *value) {
1361     if (!value->getType()->isVoidTy()) {
1362         value->setName(name);
1363     }
1364     symbol_table.push(name, value);
1365 }
1366 
1367 void CodeGen_LLVM::sym_pop(const string &name) {
1368     symbol_table.pop(name);
1369 }
1370 
1371 llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const {
1372     // look in the symbol table
1373     if (!symbol_table.contains(name)) {
1374         if (must_succeed) {
1375             std::ostringstream err;
1376             err << "Symbol not found: " << name << "\n";
1377 
1378             if (debug::debug_level() > 0) {
1379                 err << "The following names are in scope:\n"
1380                     << symbol_table << "\n";
1381             }
1382 
1383             internal_error << err.str();
1384         } else {
1385             return nullptr;
1386         }
1387     }
1388     return symbol_table.get(name);
1389 }
1390 
1391 bool CodeGen_LLVM::sym_exists(const string &name) const {
1392     return symbol_table.contains(name);
1393 }
1394 
1395 Value *CodeGen_LLVM::codegen(const Expr &e) {
1396     internal_assert(e.defined());
1397     debug(4) << "Codegen: " << e.type() << ", " << e << "\n";
1398     value = nullptr;
1399     e.accept(this);
1400     internal_assert(value) << "Codegen of an expr did not produce an llvm value\n";
1401 
1402     // Halide's type system doesn't distinguish between scalars and
1403     // vectors of size 1, so if a codegen method returned a vector of
1404     // size one, just extract it out as a scalar.
1405     if (e.type().is_scalar() &&
1406         value->getType()->isVectorTy()) {
1407         internal_assert(get_vector_num_elements(value->getType()) == 1);
1408         value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0));
1409     }
1410 
1411     // TODO: skip this correctness check for bool vectors,
1412     // as eliminate_bool_vectors() will cause a discrepancy for some backends
1413     // (eg OpenCL, HVX, WASM); for now we're just ignoring the assert, but
1414     // in the long run we should improve the smarts. See https://github.com/halide/Halide/issues/4194.
1415     const bool is_bool_vector = e.type().is_bool() && e.type().lanes() > 1;
1416     // TODO: skip this correctness check for prefetch, because the return type
1417     // of prefetch indicates the type being prefetched, which does not match the
1418     // implementation of prefetch.
1419     // See https://github.com/halide/Halide/issues/4211.
1420     const bool is_prefetch = e.as<Call>() && e.as<Call>()->is_intrinsic(Call::prefetch);
1421     internal_assert(is_bool_vector || is_prefetch ||
1422                     e.type().is_handle() ||
1423                     value->getType()->isVoidTy() ||
1424                     value->getType() == llvm_type_of(e.type()))
1425         << "Codegen of Expr " << e
1426         << " of type " << e.type()
1427         << " did not produce llvm IR of the corresponding llvm type.\n";
1428     return value;
1429 }
1430 
1431 void CodeGen_LLVM::codegen(const Stmt &s) {
1432     internal_assert(s.defined());
1433     debug(3) << "Codegen: " << s << "\n";
1434     value = nullptr;
1435     s.accept(this);
1436 }
1437 
1438 Type CodeGen_LLVM::upgrade_type_for_arithmetic(const Type &t) const {
1439     if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1440         return Float(32, t.lanes());
1441     } else {
1442         return t;
1443     }
1444 }
1445 
1446 Type CodeGen_LLVM::upgrade_type_for_argument_passing(const Type &t) const {
1447     if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1448         return t.with_code(halide_type_uint);
1449     } else {
1450         return t;
1451     }
1452 }
1453 
1454 Type CodeGen_LLVM::upgrade_type_for_storage(const Type &t) const {
1455     if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1456         return t.with_code(halide_type_uint);
1457     } else if (t.is_bool()) {
1458         return t.with_bits(8);
1459     } else if (t.is_handle()) {
1460         return UInt(64, t.lanes());
1461     } else {
1462         return t;
1463     }
1464 }
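// Illustrative examples of the storage upgrades above (not exhaustive):
//   Bool()         -> UInt(8)       (bools are stored one per byte)
//   BFloat(16, 8)  -> UInt(16, 8)   (stored as raw bits)
//   Float(16)      -> UInt(16)
//   Handle()       -> UInt(64)      (pointers stored as 64-bit integers)
//   Int(32, 4)     -> Int(32, 4)    (already natively supported, unchanged)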
1465 
1466 void CodeGen_LLVM::visit(const IntImm *op) {
1467     value = ConstantInt::getSigned(llvm_type_of(op->type), op->value);
1468 }
1469 
1470 void CodeGen_LLVM::visit(const UIntImm *op) {
1471     value = ConstantInt::get(llvm_type_of(op->type), op->value);
1472 }
1473 
1474 void CodeGen_LLVM::visit(const FloatImm *op) {
1475     if (op->type.is_bfloat()) {
1476         codegen(reinterpret(BFloat(16), make_const(UInt(16), bfloat16_t(op->value).to_bits())));
1477     } else if (op->type.bits() == 16) {
1478         codegen(reinterpret(Float(16), make_const(UInt(16), float16_t(op->value).to_bits())));
1479     } else {
1480         value = ConstantFP::get(llvm_type_of(op->type), op->value);
1481     }
1482 }
1483 
1484 void CodeGen_LLVM::visit(const StringImm *op) {
1485     value = create_string_constant(op->value);
1486 }
1487 
1488 void CodeGen_LLVM::visit(const Cast *op) {
1489     Halide::Type src = op->value.type();
1490     Halide::Type dst = op->type;
1491 
1492     if (upgrade_type_for_arithmetic(src) != src ||
1493         upgrade_type_for_arithmetic(dst) != dst) {
1494         // Handle casts to and from types for which we don't have native support.
1495         debug(4) << "Emulating cast from " << src << " to " << dst << "\n";
1496         if ((src.is_float() && src.bits() < 32) ||
1497             (dst.is_float() && dst.bits() < 32)) {
1498             Expr equiv = lower_float16_cast(op);
1499             internal_assert(equiv.type() == op->type);
1500             codegen(equiv);
1501         } else {
1502             internal_error << "Cast from type: " << src
1503                            << " to " << dst
1504                            << " unimplemented\n";
1505         }
1506         return;
1507     }
1508 
1509     value = codegen(op->value);
1510     llvm::Type *llvm_dst = llvm_type_of(dst);
1511 
1512     if (dst.is_handle() && src.is_handle()) {
1513         value = builder->CreateBitCast(value, llvm_dst);
1514     } else if (dst.is_handle() || src.is_handle()) {
1515         internal_error << "Can't cast from " << src << " to " << dst << "\n";
1516     } else if (!src.is_float() && !dst.is_float()) {
1517         // Widening integer casts either zero extend or sign extend,
1518         // depending on the source type. Narrowing integer casts
1519         // always truncate.
1520         value = builder->CreateIntCast(value, llvm_dst, src.is_int());
1521     } else if (src.is_float() && dst.is_int()) {
1522         value = builder->CreateFPToSI(value, llvm_dst);
1523     } else if (src.is_float() && dst.is_uint()) {
1524         // fptoui has undefined behavior on overflow. Seems reasonable
1525         // to get an unspecified uint on overflow, but because uint1s
1526         // are stored in uint8s for float->uint1 casts this undefined
1527         // behavior manifests itself as uint1 values greater than 1,
1528         // which could in turn break our bounds inference
1529         // guarantees. So go via uint8 in this case.
1530         if (dst.bits() < 8) {
1531             value = builder->CreateFPToUI(value, llvm_type_of(dst.with_bits(8)));
1532             value = builder->CreateIntCast(value, llvm_dst, false);
1533         } else {
1534             value = builder->CreateFPToUI(value, llvm_dst);
1535         }
1536     } else if (src.is_int() && dst.is_float()) {
1537         value = builder->CreateSIToFP(value, llvm_dst);
1538     } else if (src.is_uint() && dst.is_float()) {
1539         value = builder->CreateUIToFP(value, llvm_dst);
1540     } else {
1541         internal_assert(src.is_float() && dst.is_float());
1542         // Float widening or narrowing
1543         value = builder->CreateFPCast(value, llvm_dst);
1544     }
1545 }
1546 
1547 void CodeGen_LLVM::visit(const Variable *op) {
1548     value = sym_get(op->name);
1549 }
1550 
1551 template<typename Op>
1552 bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) {
1553     const VectorReduce *red = op->a.template as<VectorReduce>();
1554     Expr b = op->b;
1555     if (!red) {
1556         red = op->b.template as<VectorReduce>();
1557         b = op->a;
1558     }
1559     if (red &&
1560         ((std::is_same<Op, Add>::value && red->op == VectorReduce::Add) ||
1561          (std::is_same<Op, Min>::value && red->op == VectorReduce::Min) ||
1562          (std::is_same<Op, Max>::value && red->op == VectorReduce::Max) ||
1563          (std::is_same<Op, Mul>::value && red->op == VectorReduce::Mul) ||
1564          (std::is_same<Op, And>::value && red->op == VectorReduce::And) ||
1565          (std::is_same<Op, Or>::value && red->op == VectorReduce::Or))) {
1566         codegen_vector_reduce(red, b);
1567         return true;
1568     }
1569     return false;
1570 }
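// For example, an expression of the form
//   Add(VectorReduce::Add(v), x)
// (a horizontal sum of v accumulated into x) is forwarded to
// codegen_vector_reduce(red, x), letting backends emit a fused
// horizontal-reduce-accumulate instead of a reduction followed by a separate
// add. The same applies with the operands swapped, and for Min/Max/Mul/And/Or.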
1571 
1572 void CodeGen_LLVM::visit(const Add *op) {
1573     Type t = upgrade_type_for_arithmetic(op->type);
1574     if (t != op->type) {
1575         codegen(cast(op->type, Add::make(cast(t, op->a), cast(t, op->b))));
1576         return;
1577     }
1578 
1579     // Some backends can fold the add into a vector reduce
1580     if (try_to_fold_vector_reduce(op)) {
1581         return;
1582     }
1583 
1584     Value *a = codegen(op->a);
1585     Value *b = codegen(op->b);
1586     if (op->type.is_float()) {
1587         value = builder->CreateFAdd(a, b);
1588     } else if (op->type.is_int() && op->type.bits() >= 32) {
1589         // We tell llvm integers don't wrap, so that it generates good
1590         // code for loop indices.
1591         value = builder->CreateNSWAdd(a, b);
1592     } else {
1593         value = builder->CreateAdd(a, b);
1594     }
1595 }
1596 
1597 void CodeGen_LLVM::visit(const Sub *op) {
1598     Type t = upgrade_type_for_arithmetic(op->type);
1599     if (t != op->type) {
1600         codegen(cast(op->type, Sub::make(cast(t, op->a), cast(t, op->b))));
1601         return;
1602     }
1603 
1604     Value *a = codegen(op->a);
1605     Value *b = codegen(op->b);
1606     if (op->type.is_float()) {
1607         value = builder->CreateFSub(a, b);
1608     } else if (op->type.is_int() && op->type.bits() >= 32) {
1609         // We tell llvm integers don't wrap, so that it generates good
1610         // code for loop indices.
1611         value = builder->CreateNSWSub(a, b);
1612     } else {
1613         value = builder->CreateSub(a, b);
1614     }
1615 }
1616 
1617 void CodeGen_LLVM::visit(const Mul *op) {
1618     Type t = upgrade_type_for_arithmetic(op->type);
1619     if (t != op->type) {
1620         codegen(cast(op->type, Mul::make(cast(t, op->a), cast(t, op->b))));
1621         return;
1622     }
1623 
1624     if (try_to_fold_vector_reduce(op)) {
1625         return;
1626     }
1627 
1628     Value *a = codegen(op->a);
1629     Value *b = codegen(op->b);
1630     if (op->type.is_float()) {
1631         value = builder->CreateFMul(a, b);
1632     } else if (op->type.is_int() && op->type.bits() >= 32) {
1633         // We tell llvm integers don't wrap, so that it generates good
1634         // code for loop indices.
1635         value = builder->CreateNSWMul(a, b);
1636     } else {
1637         value = builder->CreateMul(a, b);
1638     }
1639 }
1640 
1641 void CodeGen_LLVM::visit(const Div *op) {
1642     user_assert(!is_zero(op->b)) << "Division by constant zero in expression: " << Expr(op) << "\n";
1643 
1644     Type t = upgrade_type_for_arithmetic(op->type);
1645     if (t != op->type) {
1646         codegen(cast(op->type, Div::make(cast(t, op->a), cast(t, op->b))));
1647         return;
1648     }
1649 
1650     if (op->type.is_float()) {
1651         // Don't call codegen() multiple times within an argument list:
1652         // order-of-evaluation isn't guaranteed and can vary by compiler,
1653         // leading to different LLVM IR ordering, which makes comparing
1654         // output hard.
1655         Value *a = codegen(op->a);
1656         Value *b = codegen(op->b);
1657         value = builder->CreateFDiv(a, b);
1658     } else {
1659         value = codegen(lower_int_uint_div(op->a, op->b));
1660     }
1661 }
1662 
1663 void CodeGen_LLVM::visit(const Mod *op) {
1664     Type t = upgrade_type_for_arithmetic(op->type);
1665     if (t != op->type) {
1666         codegen(cast(op->type, Mod::make(cast(t, op->a), cast(t, op->b))));
1667         return;
1668     }
1669 
1670     if (op->type.is_float()) {
1671         value = codegen(simplify(op->a - op->b * floor(op->a / op->b)));
1672     } else {
1673         value = codegen(lower_int_uint_mod(op->a, op->b));
1674     }
1675 }
1676 
1677 void CodeGen_LLVM::visit(const Min *op) {
1678     Type t = upgrade_type_for_arithmetic(op->type);
1679     if (t != op->type) {
1680         codegen(cast(op->type, Min::make(cast(t, op->a), cast(t, op->b))));
1681         return;
1682     }
1683 
1684     if (try_to_fold_vector_reduce(op)) {
1685         return;
1686     }
1687 
1688     string a_name = unique_name('a');
1689     string b_name = unique_name('b');
1690     Expr a = Variable::make(op->a.type(), a_name);
1691     Expr b = Variable::make(op->b.type(), b_name);
1692     value = codegen(Let::make(a_name, op->a,
1693                               Let::make(b_name, op->b,
1694                                         select(a < b, a, b))));
1695 }
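// In other words, min(a, b) is lowered to the equivalent of
//   let a0 = a in let b0 = b in select(a0 < b0, a0, b0)
// (a0/b0 standing in for the unique names generated above), so each operand
// is evaluated exactly once. Max below is handled symmetrically with a > b.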
1696 
1697 void CodeGen_LLVM::visit(const Max *op) {
1698     Type t = upgrade_type_for_arithmetic(op->type);
1699     if (t != op->type) {
1700         codegen(cast(op->type, Max::make(cast(t, op->a), cast(t, op->b))));
1701         return;
1702     }
1703 
1704     if (try_to_fold_vector_reduce(op)) {
1705         return;
1706     }
1707 
1708     string a_name = unique_name('a');
1709     string b_name = unique_name('b');
1710     Expr a = Variable::make(op->a.type(), a_name);
1711     Expr b = Variable::make(op->b.type(), b_name);
1712     value = codegen(Let::make(a_name, op->a,
1713                               Let::make(b_name, op->b,
1714                                         select(a > b, a, b))));
1715 }
1716 
1717 void CodeGen_LLVM::visit(const EQ *op) {
1718     Type t = upgrade_type_for_arithmetic(op->a.type());
1719     if (t != op->a.type()) {
1720         codegen(EQ::make(cast(t, op->a), cast(t, op->b)));
1721         return;
1722     }
1723 
1724     Value *a = codegen(op->a);
1725     Value *b = codegen(op->b);
1726     if (t.is_float()) {
1727         value = builder->CreateFCmpOEQ(a, b);
1728     } else {
1729         value = builder->CreateICmpEQ(a, b);
1730     }
1731 }
1732 
1733 void CodeGen_LLVM::visit(const NE *op) {
1734     Type t = upgrade_type_for_arithmetic(op->a.type());
1735     if (t != op->a.type()) {
1736         codegen(NE::make(cast(t, op->a), cast(t, op->b)));
1737         return;
1738     }
1739 
1740     Value *a = codegen(op->a);
1741     Value *b = codegen(op->b);
1742     if (t.is_float()) {
1743         value = builder->CreateFCmpONE(a, b);
1744     } else {
1745         value = builder->CreateICmpNE(a, b);
1746     }
1747 }
1748 
1749 void CodeGen_LLVM::visit(const LT *op) {
1750     Type t = upgrade_type_for_arithmetic(op->a.type());
1751     if (t != op->a.type()) {
1752         codegen(LT::make(cast(t, op->a), cast(t, op->b)));
1753         return;
1754     }
1755 
1756     Value *a = codegen(op->a);
1757     Value *b = codegen(op->b);
1758     if (t.is_float()) {
1759         value = builder->CreateFCmpOLT(a, b);
1760     } else if (t.is_int()) {
1761         value = builder->CreateICmpSLT(a, b);
1762     } else {
1763         value = builder->CreateICmpULT(a, b);
1764     }
1765 }
1766 
1767 void CodeGen_LLVM::visit(const LE *op) {
1768     Type t = upgrade_type_for_arithmetic(op->a.type());
1769     if (t != op->a.type()) {
1770         codegen(LE::make(cast(t, op->a), cast(t, op->b)));
1771         return;
1772     }
1773 
1774     Value *a = codegen(op->a);
1775     Value *b = codegen(op->b);
1776     if (t.is_float()) {
1777         value = builder->CreateFCmpOLE(a, b);
1778     } else if (t.is_int()) {
1779         value = builder->CreateICmpSLE(a, b);
1780     } else {
1781         value = builder->CreateICmpULE(a, b);
1782     }
1783 }
1784 
1785 void CodeGen_LLVM::visit(const GT *op) {
1786     Type t = upgrade_type_for_arithmetic(op->a.type());
1787     if (t != op->a.type()) {
1788         codegen(GT::make(cast(t, op->a), cast(t, op->b)));
1789         return;
1790     }
1791 
1792     Value *a = codegen(op->a);
1793     Value *b = codegen(op->b);
1794 
1795     if (t.is_float()) {
1796         value = builder->CreateFCmpOGT(a, b);
1797     } else if (t.is_int()) {
1798         value = builder->CreateICmpSGT(a, b);
1799     } else {
1800         value = builder->CreateICmpUGT(a, b);
1801     }
1802 }
1803 
1804 void CodeGen_LLVM::visit(const GE *op) {
1805     Type t = upgrade_type_for_arithmetic(op->a.type());
1806     if (t != op->a.type()) {
1807         codegen(GE::make(cast(t, op->a), cast(t, op->b)));
1808         return;
1809     }
1810 
1811     Value *a = codegen(op->a);
1812     Value *b = codegen(op->b);
1813     if (t.is_float()) {
1814         value = builder->CreateFCmpOGE(a, b);
1815     } else if (t.is_int()) {
1816         value = builder->CreateICmpSGE(a, b);
1817     } else {
1818         value = builder->CreateICmpUGE(a, b);
1819     }
1820 }
1821 
1822 void CodeGen_LLVM::visit(const And *op) {
1823     if (try_to_fold_vector_reduce(op)) {
1824         return;
1825     }
1826 
1827     Value *a = codegen(op->a);
1828     Value *b = codegen(op->b);
1829     value = builder->CreateAnd(a, b);
1830 }
1831 
1832 void CodeGen_LLVM::visit(const Or *op) {
1833     if (try_to_fold_vector_reduce(op)) {
1834         return;
1835     }
1836 
1837     Value *a = codegen(op->a);
1838     Value *b = codegen(op->b);
1839     value = builder->CreateOr(a, b);
1840 }
1841 
1842 void CodeGen_LLVM::visit(const Not *op) {
1843     Value *a = codegen(op->a);
1844     value = builder->CreateNot(a);
1845 }
1846 
1847 void CodeGen_LLVM::visit(const Select *op) {
1848     Value *cmp = codegen(op->condition);
1849     Value *a = codegen(op->true_value);
1850     Value *b = codegen(op->false_value);
1851     value = builder->CreateSelect(cmp, a, b);
1852 }
1853 
1854 namespace {
1855 Expr promote_64(const Expr &e) {
1856     if (const Add *a = e.as<Add>()) {
1857         return Add::make(promote_64(a->a), promote_64(a->b));
1858     } else if (const Sub *s = e.as<Sub>()) {
1859         return Sub::make(promote_64(s->a), promote_64(s->b));
1860     } else if (const Mul *m = e.as<Mul>()) {
1861         return Mul::make(promote_64(m->a), promote_64(m->b));
1862     } else if (const Min *m = e.as<Min>()) {
1863         return Min::make(promote_64(m->a), promote_64(m->b));
1864     } else if (const Max *m = e.as<Max>()) {
1865         return Max::make(promote_64(m->a), promote_64(m->b));
1866     } else {
1867         return cast(Int(64), e);
1868     }
1869 }
1870 }  // namespace
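// For example (hypothetical index expression), promote_64 rewrites
//   min(x * 4 + y, 100)
// into
//   min(int64(x) * int64(4) + int64(y), int64(100))
// pushing the widening down through Add/Sub/Mul/Min/Max so the arithmetic
// itself is performed in 64 bits, rather than widening only the final result.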
1871 
1872 Value *CodeGen_LLVM::codegen_buffer_pointer(const string &buffer, Halide::Type type, Expr index) {
1873     // Find the base address from the symbol table
1874     Value *base_address = symbol_table.get(buffer);
1875     return codegen_buffer_pointer(base_address, type, std::move(index));
1876 }
1877 
1878 Value *CodeGen_LLVM::codegen_buffer_pointer(Value *base_address, Halide::Type type, Expr index) {
1879     // Promote index to 64-bit on targets that use 64-bit pointers.
1880     llvm::DataLayout d(module.get());
1881     if (promote_indices() && d.getPointerSize() == 8) {
1882         index = promote_64(index);
1883     }
1884 
1885     // Peel off a constant offset as a second GEP. This helps LLVM's
1886     // aliasing analysis, especially for backends that do address
1887     // computation in 32 bits but use 64-bit pointers.
1888     if (const Add *add = index.as<Add>()) {
1889         if (const int64_t *offset = as_const_int(add->b)) {
1890             Value *base = codegen_buffer_pointer(base_address, type, add->a);
1891             Value *off = codegen(make_const(Int(8 * d.getPointerSize()), *offset));
1892             return builder->CreateInBoundsGEP(base, off);
1893         }
1894     }
1895 
1896     return codegen_buffer_pointer(base_address, type, codegen(index));
1897 }
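// As a sketch of the transformation above: an index like (i * 4 + 3)
// becomes two GEPs, roughly
//   %p = getelementptr inbounds T, T* %base, i64 %i_times_4
//   %q = getelementptr inbounds T, T* %p, i64 3
// (T and %i_times_4 are placeholders), so the constant offset stays visible
// to LLVM's alias analysis even when the variable part is computed in 32 bits
// and then widened.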
1898 
1899 Value *CodeGen_LLVM::codegen_buffer_pointer(const string &buffer, Halide::Type type, Value *index) {
1900     // Find the base address from the symbol table
1901     Value *base_address = symbol_table.get(buffer);
1902     return codegen_buffer_pointer(base_address, type, index);
1903 }
1904 
1905 Value *CodeGen_LLVM::codegen_buffer_pointer(Value *base_address, Halide::Type type, Value *index) {
1906     llvm::Type *base_address_type = base_address->getType();
1907     unsigned address_space = base_address_type->getPointerAddressSpace();
1908 
1909     type = upgrade_type_for_storage(type);
1910 
1911     llvm::Type *load_type = llvm_type_of(type)->getPointerTo(address_space);
1912 
1913     // If the type doesn't match the expected type, we need to pointer cast
1914     if (load_type != base_address_type) {
1915         base_address = builder->CreatePointerCast(base_address, load_type);
1916     }
1917 
1918     llvm::Constant *constant_index = dyn_cast<llvm::Constant>(index);
1919     if (constant_index && constant_index->isZeroValue()) {
1920         return base_address;
1921     }
1922 
1923     // Promote index to 64-bit on targets that use 64-bit pointers.
1924     llvm::DataLayout d(module.get());
1925     if (d.getPointerSize() == 8) {
1926         index = builder->CreateIntCast(index, i64_t, true);
1927     }
1928 
1929     return builder->CreateInBoundsGEP(base_address, index);
1930 }
1931 
1932 namespace {
1933 int next_power_of_two(int x) {
1934     for (int p2 = 1;; p2 *= 2) {
1935         if (p2 >= x) {
1936             return p2;
1937         }
1938     }
1939     // unreachable.
1940 }
1941 }  // namespace
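// e.g. next_power_of_two(1) == 1, next_power_of_two(5) == 8,
// next_power_of_two(8) == 8. Used below to round a ramp's extent up to a
// power-of-two width when building the TBAA hierarchy.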
1942 
1943 void CodeGen_LLVM::add_tbaa_metadata(llvm::Instruction *inst, string buffer, const Expr &index) {
1944 
1945     // Get the unique name for the block of memory this allocate node
1946     // is using.
1947     buffer = get_allocation_name(buffer);
1948 
1949     // If the index is constant, we generate some TBAA info that helps
1950     // LLVM understand our loads/stores aren't aliased.
1951     bool constant_index = false;
1952     int64_t base = 0;
1953     int64_t width = 1;
1954 
1955     if (index.defined()) {
1956         if (const Ramp *ramp = index.as<Ramp>()) {
1957             const int64_t *pstride = as_const_int(ramp->stride);
1958             const int64_t *pbase = as_const_int(ramp->base);
1959             if (pstride && pbase) {
1960                 // We want to find the smallest aligned width and offset
1961                 // that contains this ramp.
1962                 int64_t stride = *pstride;
1963                 base = *pbase;
1964                 internal_assert(base >= 0);
1965                 width = next_power_of_two(ramp->lanes * stride);
1966 
1967                 while (base % width) {
1968                     base -= base % width;
1969                     width *= 2;
1970                 }
1971                 constant_index = true;
1972             }
1973         } else {
1974             const int64_t *pbase = as_const_int(index);
1975             if (pbase) {
1976                 base = *pbase;
1977                 constant_index = true;
1978             }
1979         }
1980     }
1981 
1982     llvm::MDBuilder builder(*context);
1983 
1984     // Add type-based-alias-analysis metadata to the pointer, so that
1985     // loads and stores to different buffers can get reordered.
1986     MDNode *tbaa = builder.createTBAARoot("Halide buffer");
1987 
1988     tbaa = builder.createTBAAScalarTypeNode(buffer, tbaa);
1989 
1990     // We also add metadata for constant indices to allow loads and
1991     // stores to the same buffer to get reordered.
1992     if (constant_index) {
1993         for (int w = 1024; w >= width; w /= 2) {
1994             int64_t b = (base / w) * w;
1995 
1996             std::stringstream level;
1997             level << buffer << ".width" << w << ".base" << b;
1998             tbaa = builder.createTBAAScalarTypeNode(level.str(), tbaa);
1999         }
2000     }
2001 
2002     tbaa = builder.createTBAAStructTagNode(tbaa, tbaa, 0);
2003 
2004     inst->setMetadata("tbaa", tbaa);
2005 }
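// As an illustration (the buffer name and indices are hypothetical): a scalar
// load of f[10] produces a TBAA chain like
//   "Halide buffer" -> "f" -> "f.width1024.base0" -> "f.width512.base0"
//     -> ... -> "f.width2.base10" -> "f.width1.base10"
// while a load of f[11] ends in "f.width1.base11". The two leaf nodes are
// disjoint, so LLVM is free to reorder the two accesses relative to each other.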
2006 
2007 void CodeGen_LLVM::visit(const Load *op) {
2008     // If the type should be stored as some other type, insert a reinterpret cast.
2009     Type storage_type = upgrade_type_for_storage(op->type);
2010     if (op->type != storage_type) {
2011         codegen(reinterpret(op->type, Load::make(storage_type, op->name,
2012                                                  op->index, op->image,
2013                                                  op->param, op->predicate, op->alignment)));
2014         return;
2015     }
2016 
2017     // Predicated load
2018     if (!is_one(op->predicate)) {
2019         codegen_predicated_vector_load(op);
2020         return;
2021     }
2022 
2023     // There are several cases. Different architectures may wish to override some.
2024     if (op->type.is_scalar()) {
2025         // Scalar loads
2026         Value *ptr = codegen_buffer_pointer(op->name, op->type, op->index);
2027         LoadInst *load = builder->CreateAlignedLoad(ptr, make_alignment(op->type.bytes()));
2028         add_tbaa_metadata(load, op->name, op->index);
2029         value = load;
2030     } else {
2031         const Ramp *ramp = op->index.as<Ramp>();
2032         const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;
2033 
2034         if (ramp && stride && stride->value == 1) {
2035             value = codegen_dense_vector_load(op);
2036         } else if (ramp && stride && stride->value == 2) {
2037             // Load two vectors worth and then shuffle
2038             Expr base_a = ramp->base, base_b = ramp->base + ramp->lanes;
2039             Expr stride_a = make_one(base_a.type());
2040             Expr stride_b = make_one(base_b.type());
2041 
2042             ModulusRemainder align_a = op->alignment;
2043             ModulusRemainder align_b = align_a + ramp->lanes;
2044 
2045             // False indicates we should take the even-numbered lanes
2046             // from the load, true indicates we should take the
2047             // odd-numbered lanes.
2048             bool shifted_a = false, shifted_b = false;
2049 
2050             bool external = op->param.defined() || op->image.defined();
2051 
2052             // Don't read beyond the end of an external buffer.
2053             // (In ASAN mode, don't read beyond the end of internal buffers either,
2054             // as ASAN will complain even about harmless stack overreads.)
2055             if (external || target.has_feature(Target::ASAN)) {
2056                 base_b -= 1;
2057                 align_b = align_b - 1;
2058                 shifted_b = true;
2059             } else {
2060                 // If the base ends in an odd constant, then subtract one
2061                 // and do a different shuffle. This helps expressions like
2062                 // (f(2*x) + f(2*x+1)) share loads.
2063                 const Add *add = ramp->base.as<Add>();
2064                 const IntImm *offset = add ? add->b.as<IntImm>() : nullptr;
2065                 if (offset && offset->value & 1) {
2066                     base_a -= 1;
2067                     align_a = align_a - 1;
2068                     shifted_a = true;
2069                     base_b -= 1;
2070                     align_b = align_b - 1;
2071                     shifted_b = true;
2072                 }
2073             }
2074 
2075             // Do each load.
2076             Expr ramp_a = Ramp::make(base_a, stride_a, ramp->lanes);
2077             Expr ramp_b = Ramp::make(base_b, stride_b, ramp->lanes);
2078             Expr load_a = Load::make(op->type, op->name, ramp_a, op->image, op->param, op->predicate, align_a);
2079             Expr load_b = Load::make(op->type, op->name, ramp_b, op->image, op->param, op->predicate, align_b);
2080             Value *vec_a = codegen(load_a);
2081             Value *vec_b = codegen(load_b);
2082 
2083             // Shuffle together the results.
2084             vector<int> indices(ramp->lanes);
2085             for (int i = 0; i < (ramp->lanes + 1) / 2; i++) {
2086                 indices[i] = i * 2 + (shifted_a ? 1 : 0);
2087             }
2088             for (int i = (ramp->lanes + 1) / 2; i < ramp->lanes; i++) {
2089                 indices[i] = i * 2 + (shifted_b ? 1 : 0);
2090             }
2091 
2092             value = shuffle_vectors(vec_a, vec_b, indices);
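            // For instance, with lanes == 4 and neither load shifted, the
            // indices above are {0, 2, 4, 6}: the even lanes of the
            // concatenation of vec_a and vec_b, i.e. elements
            // base, base+2, base+4, base+6 of the original buffer.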
2093         } else if (ramp && stride && stride->value == -1) {
2094             // Load the vector and then flip it in-place
2095             Expr flipped_base = ramp->base - ramp->lanes + 1;
2096             Expr flipped_stride = make_one(flipped_base.type());
2097             Expr flipped_index = Ramp::make(flipped_base, flipped_stride, ramp->lanes);
2098             ModulusRemainder align = op->alignment;
2099             // Switch to the alignment of the last lane
2100             align = align - (ramp->lanes - 1);
2101             Expr flipped_load = Load::make(op->type, op->name, flipped_index, op->image, op->param, op->predicate, align);
2102 
2103             Value *flipped = codegen(flipped_load);
2104 
2105             vector<int> indices(ramp->lanes);
2106             for (int i = 0; i < ramp->lanes; i++) {
2107                 indices[i] = ramp->lanes - 1 - i;
2108             }
2109 
2110             value = shuffle_vectors(flipped, indices);
2111         } else if (ramp) {
2112             // Gather without generating the indices as a vector
2113             Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), ramp->base);
2114             Value *stride = codegen(ramp->stride);
2115             value = UndefValue::get(llvm_type_of(op->type));
2116             for (int i = 0; i < ramp->lanes; i++) {
2117                 Value *lane = ConstantInt::get(i32_t, i);
2118                 LoadInst *val = builder->CreateLoad(ptr);
2119                 add_tbaa_metadata(val, op->name, op->index);
2120                 value = builder->CreateInsertElement(value, val, lane);
2121                 ptr = builder->CreateInBoundsGEP(ptr, stride);
2122             }
2123         } else if ((false)) { /* should_scalarize(op->index) */
2124             // TODO: put something sensible in for
2125             // should_scalarize. Probably a good idea if there are no
2126             // loads in it, and it's all int32.
2127 
2128             // Compute the index as scalars, and then do a gather
2129             Value *vec = UndefValue::get(llvm_type_of(op->type));
2130             for (int i = 0; i < op->type.lanes(); i++) {
2131                 Expr idx = extract_lane(op->index, i);
2132                 Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx);
2133                 LoadInst *val = builder->CreateLoad(ptr);
2134                 add_tbaa_metadata(val, op->name, op->index);
2135                 vec = builder->CreateInsertElement(vec, val, ConstantInt::get(i32_t, i));
2136             }
2137             value = vec;
2138         } else {
2139             // General gathers
2140             Value *index = codegen(op->index);
2141             Value *vec = UndefValue::get(llvm_type_of(op->type));
2142             for (int i = 0; i < op->type.lanes(); i++) {
2143                 Value *idx = builder->CreateExtractElement(index, ConstantInt::get(i32_t, i));
2144                 Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx);
2145                 LoadInst *val = builder->CreateLoad(ptr);
2146                 add_tbaa_metadata(val, op->name, op->index);
2147                 vec = builder->CreateInsertElement(vec, val, ConstantInt::get(i32_t, i));
2148             }
2149             value = vec;
2150         }
2151     }
2152 }
2153 
2154 void CodeGen_LLVM::visit(const Ramp *op) {
2155     if (is_const(op->stride) && !is_const(op->base)) {
2156         // If the stride is const and the base is not (e.g. ramp(x, 1,
2157         // 4)), we can lift out the stride and broadcast the base so
2158         // we can do a single vector broadcast and add instead of
2159         // repeated insertion
2160         Expr broadcast = Broadcast::make(op->base, op->lanes);
2161         Expr ramp = Ramp::make(make_zero(op->base.type()), op->stride, op->lanes);
2162         value = codegen(broadcast + ramp);
2163     } else {
2164         // Otherwise we generate element by element by adding the stride to the base repeatedly
2165 
2166         Value *base = codegen(op->base);
2167         Value *stride = codegen(op->stride);
2168 
2169         value = UndefValue::get(llvm_type_of(op->type));
2170         for (int i = 0; i < op->type.lanes(); i++) {
2171             if (i > 0) {
2172                 if (op->type.is_float()) {
2173                     base = builder->CreateFAdd(base, stride);
2174                 } else if (op->type.is_int() && op->type.bits() >= 32) {
2175                     base = builder->CreateNSWAdd(base, stride);
2176                 } else {
2177                     base = builder->CreateAdd(base, stride);
2178                 }
2179             }
2180             value = builder->CreateInsertElement(value, base, ConstantInt::get(i32_t, i));
2181         }
2182     }
2183 }
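// For example, ramp(x, 1, 4) with non-constant x is emitted as
//   broadcast(x, 4) + ramp(0, 1, 4)
// i.e. one splat plus the constant vector <0, 1, 2, 3>, rather than four
// scalar adds and insertelements.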
2184 
2185 llvm::Value *CodeGen_LLVM::create_broadcast(llvm::Value *v, int lanes) {
2186     Constant *undef = UndefValue::get(get_vector_type(v->getType(), lanes));
2187     Constant *zero = ConstantInt::get(i32_t, 0);
2188     v = builder->CreateInsertElement(undef, v, zero);
2189     Constant *zeros = ConstantVector::getSplat(element_count(lanes), zero);
2190     return builder->CreateShuffleVector(v, undef, zeros);
2191 }
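// This is roughly the standard LLVM splat idiom; for a 4-lane broadcast of %v:
//   %t = insertelement <4 x T> undef, T %v, i32 0
//   %s = shufflevector <4 x T> %t, <4 x T> undef, <4 x i32> zeroinitializer
// (T stands in for whatever element type %v has.)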
2192 
2193 void CodeGen_LLVM::visit(const Broadcast *op) {
2194     Value *v = codegen(op->value);
2195     value = create_broadcast(v, op->lanes);
2196 }
2197 
2198 Value *CodeGen_LLVM::interleave_vectors(const std::vector<Value *> &vecs) {
2199     internal_assert(!vecs.empty());
2200     for (size_t i = 1; i < vecs.size(); i++) {
2201         internal_assert(vecs[0]->getType() == vecs[i]->getType());
2202     }
2203     int vec_elements = get_vector_num_elements(vecs[0]->getType());
2204 
2205     if (vecs.size() == 1) {
2206         return vecs[0];
2207     } else if (vecs.size() == 2) {
2208         Value *a = vecs[0];
2209         Value *b = vecs[1];
2210         vector<int> indices(vec_elements * 2);
2211         for (int i = 0; i < vec_elements * 2; i++) {
2212             indices[i] = i % 2 == 0 ? i / 2 : i / 2 + vec_elements;
2213         }
2214         return shuffle_vectors(a, b, indices);
2215     } else {
2216         // Grab the even and odd elements of vecs.
2217         vector<Value *> even_vecs;
2218         vector<Value *> odd_vecs;
2219         for (size_t i = 0; i < vecs.size(); i++) {
2220             if (i % 2 == 0) {
2221                 even_vecs.push_back(vecs[i]);
2222             } else {
2223                 odd_vecs.push_back(vecs[i]);
2224             }
2225         }
2226 
2227         // If the number of vecs is odd, save the last one for later.
2228         Value *last = nullptr;
2229         if (even_vecs.size() > odd_vecs.size()) {
2230             last = even_vecs.back();
2231             even_vecs.pop_back();
2232         }
2233         internal_assert(even_vecs.size() == odd_vecs.size());
2234 
2235         // Interleave the even and odd parts.
2236         Value *even = interleave_vectors(even_vecs);
2237         Value *odd = interleave_vectors(odd_vecs);
2238 
2239         if (last) {
2240             int result_elements = vec_elements * vecs.size();
2241 
2242             // Interleave even and odd, leaving a space for the last element.
2243             vector<int> indices(result_elements, -1);
2244             for (int i = 0, idx = 0; i < result_elements; i++) {
2245                 if (i % vecs.size() < vecs.size() - 1) {
2246                     indices[i] = idx % 2 == 0 ? idx / 2 : idx / 2 + vec_elements * even_vecs.size();
2247                     idx++;
2248                 }
2249             }
2250             Value *even_odd = shuffle_vectors(even, odd, indices);
2251 
2252             // Interleave the last vector into the result.
2253             last = slice_vector(last, 0, result_elements);
2254             for (int i = 0; i < result_elements; i++) {
2255                 if (i % vecs.size() < vecs.size() - 1) {
2256                     indices[i] = i;
2257                 } else {
2258                     indices[i] = i / vecs.size() + result_elements;
2259                 }
2260             }
2261 
2262             return shuffle_vectors(even_odd, last, indices);
2263         } else {
2264             return interleave_vectors({even, odd});
2265         }
2266     }
2267 }
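// Worked example with three 4-wide vectors a, b, c: even_vecs = {a, c} and
// odd_vecs = {b}; c is peeled off as "last", a and b are interleaved leaving
// every third result lane undefined, and the final shuffle drops the lanes of
// c into those holes, yielding
//   a0 b0 c0 a1 b1 c1 a2 b2 c2 a3 b3 c3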
2268 
2269 void CodeGen_LLVM::scalarize(const Expr &e) {
2270     llvm::Type *result_type = llvm_type_of(e.type());
2271 
2272     Value *result = UndefValue::get(result_type);
2273 
2274     for (int i = 0; i < e.type().lanes(); i++) {
2275         Value *v = codegen(extract_lane(e, i));
2276         result = builder->CreateInsertElement(result, v, ConstantInt::get(i32_t, i));
2277     }
2278     value = result;
2279 }
2280 
2281 void CodeGen_LLVM::codegen_predicated_vector_store(const Store *op) {
2282     const Ramp *ramp = op->index.as<Ramp>();
2283     if (ramp && is_one(ramp->stride)) {  // Dense vector store
2284         debug(4) << "Predicated dense vector store\n\t" << Stmt(op) << "\n";
2285         Value *vpred = codegen(op->predicate);
2286         Halide::Type value_type = op->value.type();
2287         Value *val = codegen(op->value);
2288         bool is_external = (external_buffer.find(op->name) != external_buffer.end());
2289         int alignment = value_type.bytes();
2290         int native_bits = native_vector_bits();
2291         int native_bytes = native_bits / 8;
2292 
2293         // Boost the alignment if possible, up to the native vector width.
2294         ModulusRemainder mod_rem = op->alignment;
2295         while ((mod_rem.remainder & 1) == 0 &&
2296                (mod_rem.modulus & 1) == 0 &&
2297                alignment < native_bytes) {
2298             mod_rem.modulus /= 2;
2299             mod_rem.remainder /= 2;
2300             alignment *= 2;
2301         }
2302 
2303         // If it is an external buffer, then we cannot assume that the host pointer
2304         // is aligned to at least the native vector width. However, we may be able to do
2305         // better than just assuming that it is unaligned.
2306         if (is_external && op->param.defined()) {
2307             int host_alignment = op->param.host_alignment();
2308             alignment = gcd(alignment, host_alignment);
2309         }
2310 
2311         // For dense vector stores wider than the native vector
2312         // width, bust them up into native vectors.
2313         int store_lanes = value_type.lanes();
2314         int native_lanes = native_bits / value_type.bits();
2315 
2316         for (int i = 0; i < store_lanes; i += native_lanes) {
2317             int slice_lanes = std::min(native_lanes, store_lanes - i);
2318             Expr slice_base = simplify(ramp->base + i);
2319             Expr slice_stride = make_one(slice_base.type());
2320             Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
2321             Value *slice_val = slice_vector(val, i, slice_lanes);
2322             Value *elt_ptr = codegen_buffer_pointer(op->name, value_type.element_of(), slice_base);
2323             Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_val->getType()->getPointerTo());
2324 
2325             Value *slice_mask = slice_vector(vpred, i, slice_lanes);
2326 #if LLVM_VERSION >= 110
2327             Instruction *store_inst =
2328                 builder->CreateMaskedStore(slice_val, vec_ptr, make_alignment(alignment), slice_mask);
2329 #else
2330             Instruction *store_inst =
2331                 builder->CreateMaskedStore(slice_val, vec_ptr, alignment, slice_mask);
2332 #endif
2333             add_tbaa_metadata(store_inst, op->name, slice_index);
2334         }
2335     } else {  // It's not a dense vector store, so we need to scalarize it
2336         debug(4) << "Scalarize predicated vector store\n";
2337         Type value_type = op->value.type().element_of();
2338         Value *vpred = codegen(op->predicate);
2339         Value *vval = codegen(op->value);
2340         Value *vindex = codegen(op->index);
2341         for (int i = 0; i < op->index.type().lanes(); i++) {
2342             Constant *lane = ConstantInt::get(i32_t, i);
2343             Value *p = builder->CreateExtractElement(vpred, lane);
2344             if (p->getType() != i1_t) {
2345                 p = builder->CreateIsNotNull(p);
2346             }
2347 
2348             Value *v = builder->CreateExtractElement(vval, lane);
2349             Value *idx = builder->CreateExtractElement(vindex, lane);
2350             internal_assert(p && v && idx);
2351 
2352             BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
2353             BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
2354             builder->CreateCondBr(p, true_bb, after_bb);
2355 
2356             builder->SetInsertPoint(true_bb);
2357 
2358             // Scalar
2359             Value *ptr = codegen_buffer_pointer(op->name, value_type, idx);
2360             builder->CreateAlignedStore(v, ptr, make_alignment(value_type.bytes()));
2361 
2362             builder->CreateBr(after_bb);
2363             builder->SetInsertPoint(after_bb);
2364         }
2365     }
2366 }
2367 
2368 Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) {
2369     debug(4) << "Vectorize predicated dense vector load:\n\t" << Expr(load) << "\n";
2370 
2371     const Ramp *ramp = load->index.as<Ramp>();
2372     internal_assert(ramp && is_one(ramp->stride)) << "Should be dense vector load\n";
2373 
2374     bool is_external = (external_buffer.find(load->name) != external_buffer.end());
2375     int alignment = load->type.bytes();  // The size of a single element
2376 
2377     int native_bits = native_vector_bits();
2378     int native_bytes = native_bits / 8;
2379 
2380     // We assume halide_malloc for the platform returns buffers
2381     // aligned to at least the native vector width. So this is the
2382     // maximum alignment we can infer based on the index alone.
2383 
2384     // Boost the alignment if possible, up to the native vector width.
2385     ModulusRemainder mod_rem = load->alignment;
2386     while ((mod_rem.remainder & 1) == 0 &&
2387            (mod_rem.modulus & 1) == 0 &&
2388            alignment < native_bytes) {
2389         mod_rem.modulus /= 2;
2390         mod_rem.remainder /= 2;
2391         alignment *= 2;
2392     }
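    // For example (hypothetical values, assuming a 128-bit native vector, so
    // native_bytes == 16): a float32 load with alignment == 4 bytes and
    // mod_rem == (modulus 16, remainder 0) is boosted to alignment 16 with
    // mod_rem (4, 0) after two iterations, stopping once alignment reaches
    // native_bytes or an odd modulus/remainder is encountered.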
2393 
2394     // If it is an external buffer, then we cannot assume that the host pointer
2395     // is aligned to at least native vector width. However, we may be able to do
2396     // better than just assuming that it is unaligned.
2397     if (is_external) {
2398         if (load->param.defined()) {
2399             int host_alignment = load->param.host_alignment();
2400             alignment = gcd(alignment, host_alignment);
2401         } else if (get_target().has_feature(Target::JIT) && load->image.defined()) {
2402             // If we're JITting, use the actual pointer value to determine alignment for embedded buffers.
2403             alignment = gcd(alignment, (int)(((uintptr_t)load->image.data()) & std::numeric_limits<int>::max()));
2404         }
2405     }
2406 
2407     // For dense vector loads wider than the native vector
2408     // width, bust them up into native vectors
2409     int load_lanes = load->type.lanes();
2410     int native_lanes = std::max(1, native_bits / load->type.bits());
2411     vector<Value *> slices;
2412     for (int i = 0; i < load_lanes; i += native_lanes) {
2413         int slice_lanes = std::min(native_lanes, load_lanes - i);
2414         Expr slice_base = simplify(ramp->base + i);
2415         Expr slice_stride = make_one(slice_base.type());
2416         Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
2417         llvm::Type *slice_type = get_vector_type(llvm_type_of(load->type.element_of()), slice_lanes);
2418         Value *elt_ptr = codegen_buffer_pointer(load->name, load->type.element_of(), slice_base);
2419         Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_type->getPointerTo());
2420 
2421         Instruction *load_inst;
2422         if (vpred != nullptr) {
2423             Value *slice_mask = slice_vector(vpred, i, slice_lanes);
2424 #if LLVM_VERSION >= 110
2425             load_inst = builder->CreateMaskedLoad(vec_ptr, make_alignment(alignment), slice_mask);
2426 #else
2427             load_inst = builder->CreateMaskedLoad(vec_ptr, alignment, slice_mask);
2428 #endif
2429         } else {
2430             load_inst = builder->CreateAlignedLoad(vec_ptr, make_alignment(alignment));
2431         }
2432         add_tbaa_metadata(load_inst, load->name, slice_index);
2433         slices.push_back(load_inst);
2434     }
2435     value = concat_vectors(slices);
2436     return value;
2437 }
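// For instance, a 32-lane float32 load on a target with 128-bit native
// vectors (native_lanes == 4) is emitted as eight 4-wide loads whose results
// are concatenated back into a single 32-lane value.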
2438 
2439 void CodeGen_LLVM::codegen_predicated_vector_load(const Load *op) {
2440     const Ramp *ramp = op->index.as<Ramp>();
2441     const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;
2442 
2443     if (ramp && is_one(ramp->stride)) {  // Dense vector load
2444         Value *vpred = codegen(op->predicate);
2445         value = codegen_dense_vector_load(op, vpred);
2446     } else if (ramp && stride && stride->value == -1) {
2447         debug(4) << "Predicated dense vector load with stride -1\n\t" << Expr(op) << "\n";
2448         vector<int> indices(ramp->lanes);
2449         for (int i = 0; i < ramp->lanes; i++) {
2450             indices[i] = ramp->lanes - 1 - i;
2451         }
2452 
2453         // Flip the predicate
2454         Value *vpred = codegen(op->predicate);
2455         vpred = shuffle_vectors(vpred, indices);
2456 
2457         // Load the vector and then flip it in-place
2458         Expr flipped_base = ramp->base - ramp->lanes + 1;
2459         Expr flipped_stride = make_one(flipped_base.type());
2460         Expr flipped_index = Ramp::make(flipped_base, flipped_stride, ramp->lanes);
2461         ModulusRemainder align = op->alignment;
2462         align = align - (ramp->lanes - 1);
2463 
2464         Expr flipped_load = Load::make(op->type, op->name, flipped_index, op->image,
2465                                        op->param, const_true(op->type.lanes()), align);
2466 
2467         Value *flipped = codegen_dense_vector_load(flipped_load.as<Load>(), vpred);
2468         value = shuffle_vectors(flipped, indices);
2469     } else {  // It's not a dense vector load, so we need to scalarize it
2470         Expr load_expr = Load::make(op->type, op->name, op->index, op->image,
2471                                     op->param, const_true(op->type.lanes()), op->alignment);
2472         debug(4) << "Scalarize predicated vector load\n\t" << load_expr << "\n";
2473         Expr pred_load = Call::make(load_expr.type(),
2474                                     Call::if_then_else,
2475                                     {op->predicate, load_expr, make_zero(load_expr.type())},
2476                                     Internal::Call::Intrinsic);
2477         value = codegen(pred_load);
2478     }
2479 }
2480 
2481 void CodeGen_LLVM::codegen_atomic_store(const Store *op) {
2482     // TODO: predicated store (see https://github.com/halide/Halide/issues/4298).
2483     user_assert(is_one(op->predicate)) << "Atomic predicated store is not supported.\n";
2484 
2485     // Detect whether we can describe this as an atomic-read-modify-write,
2486     // otherwise fall back to a compare-and-swap loop.
2487     // Currently we only test for atomicAdd.
2488     Expr val_expr = op->value;
2489     Halide::Type value_type = op->value.type();
2490     // For atomicAdd, we check if op->value - store[index] is independent of store.
2491     // For llvm version < 9, the atomicRMW operations only support integers so we also check that.
2492     Expr equiv_load = Load::make(value_type, op->name,
2493                                  op->index,
2494                                  Buffer<>(),
2495                                  op->param,
2496                                  op->predicate,
2497                                  op->alignment);
2498     Expr delta = simplify(common_subexpression_elimination(op->value - equiv_load));
2499     bool is_atomic_add = supports_atomic_add(value_type) && !expr_uses_var(delta, op->name);
2500     if (is_atomic_add) {
2501         Value *val = codegen(delta);
2502         if (value_type.is_scalar()) {
2503             Value *ptr = codegen_buffer_pointer(op->name,
2504                                                 op->value.type(),
2505                                                 op->index);
2506             // llvm 9 has FAdd which can be used for atomic floats.
2507             if (value_type.is_float()) {
2508                 builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, val, AtomicOrdering::Monotonic);
2509             } else {
2510                 builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, val, AtomicOrdering::Monotonic);
2511             }
2512         } else {
2513             Value *index = codegen(op->index);
2514             // Scalarize vector store.
2515             for (int i = 0; i < value_type.lanes(); i++) {
2516                 Value *lane = ConstantInt::get(i32_t, i);
2517                 Value *idx = builder->CreateExtractElement(index, lane);
2518                 Value *v = builder->CreateExtractElement(val, lane);
2519                 Value *ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
2520                 if (value_type.is_float()) {
2521                     builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, v, AtomicOrdering::Monotonic);
2522                 } else {
2523                     builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, v, AtomicOrdering::Monotonic);
2524                 }
2525             }
2526         }
2527     } else {
2528         // We want to create the following CAS loop:
2529         // entry:
2530         //   %orig = load atomic op->name[op->index]
2531         //   br label %casloop.start
2532         // casloop.start:
2533         //   %cmp = phi [%orig, %entry], [%val_loaded, %casloop.start]
2534         //   %val = ...
2535         //   %val_success = cmpxchg %ptr, %cmp, %val, monotonic
2536         //   %val_loaded = extractvalue %val_success, 0
2537         //   %success = extractvalue %val_success, 1
2538         //   br %success, label %casloop.end, label %casloop.start
2539         // casloop.end:
2540         Value *vec_index = nullptr;
2541         if (!value_type.is_scalar()) {
2542             // Precompute index for vector store.
2543             vec_index = codegen(op->index);
2544         }
2545         // Scalarize vector store.
2546         for (int lane_id = 0; lane_id < value_type.lanes(); lane_id++) {
2547             LLVMContext &ctx = builder->getContext();
2548             BasicBlock *bb = builder->GetInsertBlock();
2549             llvm::Function *f = bb->getParent();
2550             BasicBlock *loop_bb =
2551                 BasicBlock::Create(ctx, "casloop.start", f);
2552             // Load the old value for compare and swap test.
2553             Value *ptr = nullptr;
2554             if (value_type.is_scalar()) {
2555                 ptr = codegen_buffer_pointer(op->name, value_type, op->index);
2556             } else {
2557                 Value *idx = builder->CreateExtractElement(vec_index, ConstantInt::get(i32_t, lane_id));
2558                 ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
2559             }
2560             LoadInst *orig = builder->CreateAlignedLoad(ptr, make_alignment(value_type.bytes()));
2561             orig->setOrdering(AtomicOrdering::Monotonic);
2562             add_tbaa_metadata(orig, op->name, op->index);
2563             // Explicit fall through from the current block to the cas loop body.
2564             builder->CreateBr(loop_bb);
2565 
2566             // CAS loop body:
2567             builder->SetInsertPoint(loop_bb);
2568             llvm::Type *ptr_type = ptr->getType();
2569             PHINode *cmp = builder->CreatePHI(ptr_type->getPointerElementType(), 2, "loaded");
2570             Value *cmp_val = cmp;
2571             cmp->addIncoming(orig, bb);
2572             Value *val = nullptr;
2573             if (value_type.is_scalar()) {
2574                 val = codegen(op->value);
2575             } else {
2576                 val = codegen(extract_lane(op->value, lane_id));
2577             }
2578             llvm::Type *val_type = val->getType();
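                 // cmpxchg only operates on integer (and pointer) types, so floating-point
                 // values are bit-cast to same-width integers for the compare-and-swap and
                 // cast back to the original type afterwards.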
2579             bool need_bit_cast = val_type->isFloatingPointTy();
2580             if (need_bit_cast) {
2581                 IntegerType *int_type = builder->getIntNTy(val_type->getPrimitiveSizeInBits());
2582                 unsigned int addr_space = ptr_type->getPointerAddressSpace();
2583                 ptr = builder->CreateBitCast(ptr, int_type->getPointerTo(addr_space));
2584                 val = builder->CreateBitCast(val, int_type);
2585                 cmp_val = builder->CreateBitCast(cmp_val, int_type);
2586             }
2587             Value *cmpxchg_pair = builder->CreateAtomicCmpXchg(
2588                 ptr, cmp_val, val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
2589             Value *val_loaded = builder->CreateExtractValue(cmpxchg_pair, 0, "val_loaded");
2590             Value *success = builder->CreateExtractValue(cmpxchg_pair, 1, "success");
2591             if (need_bit_cast) {
2592                 val_loaded = builder->CreateBitCast(val_loaded, val_type);
2593             }
2594             cmp->addIncoming(val_loaded, loop_bb);
2595             BasicBlock *exit_bb =
2596                 BasicBlock::Create(ctx, "casloop.end", f);
2597             builder->CreateCondBr(success, exit_bb, loop_bb);
2598             builder->SetInsertPoint(exit_bb);
2599         }
2600     }
2601 }
2602 
2603 void CodeGen_LLVM::visit(const Call *op) {
2604     internal_assert(op->is_extern() || op->is_intrinsic())
2605         << "Can only codegen extern calls and intrinsics\n";
2606 
2607     // Some call nodes are actually injected at various stages as a
2608     // cue for llvm to generate particular ops. In general these are
2609     // handled in the standard library, but ones with e.g. varying
2610     // types are handled here.
2611     if (op->is_intrinsic(Call::debug_to_file)) {
2612         internal_assert(op->args.size() == 3);
2613         const StringImm *filename = op->args[0].as<StringImm>();
2614         internal_assert(filename) << "Malformed debug_to_file node\n";
2615         // Grab the function from the initial module
2616         llvm::Function *debug_to_file = module->getFunction("halide_debug_to_file");
2617         internal_assert(debug_to_file) << "Could not find halide_debug_to_file function in initial module\n";
2618 
2619         // Make the filename a global string constant
2620         Value *user_context = get_user_context();
2621         Value *char_ptr = codegen(Expr(filename));
2622         vector<Value *> args = {user_context, char_ptr, codegen(op->args[1])};
2623 
2624         Value *buffer = codegen(op->args[2]);
2625         buffer = builder->CreatePointerCast(buffer, debug_to_file->getFunctionType()->getParamType(3));
2626         args.push_back(buffer);
2627 
2628         value = builder->CreateCall(debug_to_file, args);
2629 
2630     } else if (op->is_intrinsic(Call::bitwise_and)) {
2631         internal_assert(op->args.size() == 2);
2632         Value *a = codegen(op->args[0]);
2633         Value *b = codegen(op->args[1]);
2634         value = builder->CreateAnd(a, b);
2635     } else if (op->is_intrinsic(Call::bitwise_xor)) {
2636         internal_assert(op->args.size() == 2);
2637         Value *a = codegen(op->args[0]);
2638         Value *b = codegen(op->args[1]);
2639         value = builder->CreateXor(a, b);
2640     } else if (op->is_intrinsic(Call::bitwise_or)) {
2641         internal_assert(op->args.size() == 2);
2642         Value *a = codegen(op->args[0]);
2643         Value *b = codegen(op->args[1]);
2644         value = builder->CreateOr(a, b);
2645     } else if (op->is_intrinsic(Call::bitwise_not)) {
2646         internal_assert(op->args.size() == 1);
2647         Value *a = codegen(op->args[0]);
2648         value = builder->CreateNot(a);
2649     } else if (op->is_intrinsic(Call::reinterpret)) {
2650         internal_assert(op->args.size() == 1);
2651         Type dst = op->type;
2652         Type src = op->args[0].type();
2653         llvm::Type *llvm_dst = llvm_type_of(dst);
2654         value = codegen(op->args[0]);
2655         if (src.is_handle() && !dst.is_handle()) {
2656             internal_assert(dst.is_uint() && dst.bits() == 64);
2657 
2658             // Handle -> UInt64
2659             llvm::DataLayout d(module.get());
2660             if (d.getPointerSize() == 4) {
2661                 llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes()));
2662                 value = builder->CreatePtrToInt(value, intermediate);
2663                 value = builder->CreateZExt(value, llvm_dst);
2664             } else if (d.getPointerSize() == 8) {
2665                 value = builder->CreatePtrToInt(value, llvm_dst);
2666             } else {
2667                 internal_error << "Pointer size is neither 4 nor 8 bytes\n";
2668             }
2669 
2670         } else if (dst.is_handle() && !src.is_handle()) {
2671             internal_assert(src.is_uint() && src.bits() == 64);
2672 
2673             // UInt64 -> Handle
2674             llvm::DataLayout d(module.get());
2675             if (d.getPointerSize() == 4) {
2676                 llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes()));
2677                 value = builder->CreateTrunc(value, intermediate);
2678                 value = builder->CreateIntToPtr(value, llvm_dst);
2679             } else if (d.getPointerSize() == 8) {
2680                 value = builder->CreateIntToPtr(value, llvm_dst);
2681             } else {
2682                 internal_error << "Pointer size is neither 4 nor 8 bytes\n";
2683             }
2684 
2685         } else {
2686             value = builder->CreateBitCast(value, llvm_dst);
2687         }
2688     } else if (op->is_intrinsic(Call::shift_left)) {
2689         internal_assert(op->args.size() == 2);
2690         Value *a = codegen(op->args[0]);
2691         Value *b = codegen(op->args[1]);
2692         if (op->args[1].type().is_uint()) {
2693             value = builder->CreateShl(a, b);
2694         } else {
2695             value = codegen(lower_signed_shift_left(op->args[0], op->args[1]));
2696         }
2697     } else if (op->is_intrinsic(Call::shift_right)) {
2698         internal_assert(op->args.size() == 2);
2699         Value *a = codegen(op->args[0]);
2700         Value *b = codegen(op->args[1]);
2701         if (op->args[1].type().is_uint()) {
2702             if (op->type.is_int()) {
2703                 value = builder->CreateAShr(a, b);
2704             } else {
2705                 value = builder->CreateLShr(a, b);
2706             }
2707         } else {
2708             value = codegen(lower_signed_shift_right(op->args[0], op->args[1]));
2709         }
2710     } else if (op->is_intrinsic(Call::abs)) {
2711 
2712         internal_assert(op->args.size() == 1);
2713 
2714         // Check if an appropriate vector abs for this type exists in the initial module
2715         Type t = op->args[0].type();
2716         string name = (t.is_float() ? "abs_f" : "abs_i") + std::to_string(t.bits());
2717         llvm::Function *builtin_abs =
2718             find_vector_runtime_function(name, op->type.lanes()).first;
2719 
2720         if (t.is_vector() && builtin_abs) {
2721             codegen(Call::make(op->type, name, op->args, Call::Extern));
2722         } else {
2723             // Generate select(x >= 0, x, -x) instead
2724             string x_name = unique_name('x');
2725             Expr x = Variable::make(op->args[0].type(), x_name);
2726             value = codegen(Let::make(x_name, op->args[0], select(x >= 0, x, -x)));
2727         }
2728     } else if (op->is_intrinsic(Call::absd)) {
2729 
2730         internal_assert(op->args.size() == 2);
2731 
2732         Expr a = op->args[0];
2733         Expr b = op->args[1];
2734 
2735         // Check if an appropriate vector absd for this type exists in the initial module
2736         Type t = a.type();
2737         string name;
2738         if (t.is_float()) {
2739             codegen(abs(a - b));
2740             return;
2741         } else if (t.is_int()) {
2742             name = "absd_i" + std::to_string(t.bits());
2743         } else {
2744             name = "absd_u" + std::to_string(t.bits());
2745         }
2746 
2747         llvm::Function *builtin_absd =
2748             find_vector_runtime_function(name, op->type.lanes()).first;
2749 
2750         if (t.is_vector() && builtin_absd) {
2751             codegen(Call::make(op->type, name, op->args, Call::Extern));
2752         } else {
2753             // Use a select instead
2754             string a_name = unique_name('a');
2755             string b_name = unique_name('b');
2756             Expr a_var = Variable::make(op->args[0].type(), a_name);
2757             Expr b_var = Variable::make(op->args[1].type(), b_name);
2758             codegen(Let::make(a_name, op->args[0],
2759                               Let::make(b_name, op->args[1],
2760                                         Select::make(a_var < b_var, b_var - a_var, a_var - b_var))));
2761         }
2762     } else if (op->is_intrinsic(Call::div_round_to_zero)) {
2763         internal_assert(op->args.size() == 2);
2764         Value *a = codegen(op->args[0]);
2765         Value *b = codegen(op->args[1]);
2766         if (op->type.is_int()) {
2767             value = builder->CreateSDiv(a, b);
2768         } else if (op->type.is_uint()) {
2769             value = builder->CreateUDiv(a, b);
2770         } else {
2771             internal_error << "div_round_to_zero of non-integer type.\n";
2772         }
2773     } else if (op->is_intrinsic(Call::mod_round_to_zero)) {
2774         internal_assert(op->args.size() == 2);
2775         Value *a = codegen(op->args[0]);
2776         Value *b = codegen(op->args[1]);
2777         if (op->type.is_int()) {
2778             value = builder->CreateSRem(a, b);
2779         } else if (op->type.is_uint()) {
2780             value = builder->CreateURem(a, b);
2781         } else {
2782             internal_error << "mod_round_to_zero of non-integer type.\n";
2783         }
2784     } else if (op->is_intrinsic(Call::mulhi_shr)) {
2785         internal_assert(op->args.size() == 3);
2786 
2787         Type ty = op->type;
2788         Type wide_ty = ty.with_bits(ty.bits() * 2);
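             // mulhi_shr(a, b, shift) yields the high half of the widened product shifted
             // right by `shift`, i.e. cast(ty, (widen(a) * widen(b)) >> (ty.bits() + shift)),
             // which is what is emitted below.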
2789 
2790         Expr p_wide = cast(wide_ty, op->args[0]) * cast(wide_ty, op->args[1]);
2791         const UIntImm *shift = op->args[2].as<UIntImm>();
2792         internal_assert(shift != nullptr) << "Third argument to mulhi_shr intrinsic must be an unsigned integer immediate.\n";
2793         value = codegen(cast(ty, p_wide >> (shift->value + ty.bits())));
2794     } else if (op->is_intrinsic(Call::sorted_avg)) {
2795         internal_assert(op->args.size() == 2);
2796         // b >= a, so the following works without widening:
2797         // a + (b - a)/2
2798         value = codegen(op->args[0] + (op->args[1] - op->args[0]) / 2);
2799     } else if (op->is_intrinsic(Call::lerp)) {
2800         internal_assert(op->args.size() == 3);
2801         // If we need to upgrade the type, do the entire lerp in the
2802         // upgraded type for better precision.
2803         Type t = upgrade_type_for_arithmetic(op->type);
2804         Type wt = upgrade_type_for_arithmetic(op->args[2].type());
2805         Expr e = lower_lerp(cast(t, op->args[0]),
2806                             cast(t, op->args[1]),
2807                             cast(wt, op->args[2]));
2808         e = cast(op->type, e);
2809         codegen(e);
2810     } else if (op->is_intrinsic(Call::popcount)) {
2811         internal_assert(op->args.size() == 1);
2812         std::vector<llvm::Type *> arg_type(1);
2813         arg_type[0] = llvm_type_of(op->args[0].type());
2814         llvm::Function *fn = Intrinsic::getDeclaration(module.get(), Intrinsic::ctpop, arg_type);
2815         Value *a = codegen(op->args[0]);
2816         CallInst *call = builder->CreateCall(fn, a);
2817         value = call;
2818     } else if (op->is_intrinsic(Call::count_leading_zeros) ||
2819                op->is_intrinsic(Call::count_trailing_zeros)) {
2820         internal_assert(op->args.size() == 1);
2821         std::vector<llvm::Type *> arg_type(1);
2822         arg_type[0] = llvm_type_of(op->args[0].type());
2823         llvm::Function *fn = Intrinsic::getDeclaration(module.get(),
2824                                                        (op->is_intrinsic(Call::count_leading_zeros)) ? Intrinsic::ctlz : Intrinsic::cttz,
2825                                                        arg_type);
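             // Pass is_zero_undef = false so that ctlz/cttz of a zero input is well
             // defined (it returns the operand's bit width) rather than undefined.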
2826         llvm::Value *is_zero_undef = llvm::ConstantInt::getFalse(*context);
2827         llvm::Value *args[2] = {codegen(op->args[0]), is_zero_undef};
2828         CallInst *call = builder->CreateCall(fn, args);
2829         value = call;
2830     } else if (op->is_intrinsic(Call::return_second)) {
2831         internal_assert(op->args.size() == 2);
2832         codegen(op->args[0]);
2833         value = codegen(op->args[1]);
2834     } else if (op->is_intrinsic(Call::if_then_else)) {
2835         Expr cond = op->args[0];
2836         if (const Broadcast *b = cond.as<Broadcast>()) {
2837             cond = b->value;
2838         }
2839         if (cond.type().is_vector()) {
2840             scalarize(op);
2841         } else {
2842 
2843             internal_assert(op->args.size() == 3);
2844 
2845             BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
2846             BasicBlock *false_bb = BasicBlock::Create(*context, "false_bb", function);
2847             BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
2848             Value *c = codegen(cond);
2849             if (c->getType() != i1_t) {
2850                 c = builder->CreateIsNotNull(c);
2851             }
2852             builder->CreateCondBr(c, true_bb, false_bb);
2853             builder->SetInsertPoint(true_bb);
2854             Value *true_value = codegen(op->args[1]);
2855             builder->CreateBr(after_bb);
2856             BasicBlock *true_pred = builder->GetInsertBlock();
2857 
2858             builder->SetInsertPoint(false_bb);
2859             Value *false_value = codegen(op->args[2]);
2860             builder->CreateBr(after_bb);
2861             BasicBlock *false_pred = builder->GetInsertBlock();
2862 
2863             builder->SetInsertPoint(after_bb);
2864             PHINode *phi = builder->CreatePHI(true_value->getType(), 2);
2865             phi->addIncoming(true_value, true_pred);
2866             phi->addIncoming(false_value, false_pred);
2867 
2868             value = phi;
2869         }
2870     } else if (op->is_intrinsic(Call::require)) {
2871         internal_assert(op->args.size() == 3);
2872         Expr cond = op->args[0];
2873         if (cond.type().is_vector()) {
2874             scalarize(op);
2875         } else {
2876             Value *c = codegen(cond);
2877             create_assertion(c, op->args[2]);
2878             value = codegen(op->args[1]);
2879         }
2880     } else if (op->is_intrinsic(Call::make_struct)) {
2881         if (op->type.is_vector()) {
2882             // Make a vector of pointers to distinct structs
2883             scalarize(op);
2884         } else if (op->args.empty()) {
2885             // Empty structs can be emitted for arrays of size zero
2886             // (e.g. the shape of a zero-dimensional buffer). We
2887             // generate a null in this situation.
2888             value = ConstantPointerNull::get(dyn_cast<PointerType>(llvm_type_of(op->type)));
2889         } else {
2890             // Codegen each element.
2891             bool all_same_type = true;
2892             vector<llvm::Value *> args(op->args.size());
2893             vector<llvm::Type *> types(op->args.size());
2894             for (size_t i = 0; i < op->args.size(); i++) {
2895                 args[i] = codegen(op->args[i]);
2896                 types[i] = args[i]->getType();
2897                 all_same_type &= (types[0] == types[i]);
2898             }
2899 
2900             // Use either a single scalar, a fixed-size array, or a
2901             // struct. The struct type would always be correct, but
2902             // the array or scalar types produce slightly simpler IR.
2903             if (args.size() == 1) {
2904                 value = create_alloca_at_entry(types[0], 1);
2905                 builder->CreateStore(args[0], value);
2906             } else {
2907                 llvm::Type *aggregate_t = (all_same_type ? (llvm::Type *)ArrayType::get(types[0], types.size()) : (llvm::Type *)StructType::get(*context, types));
2908 
2909                 value = create_alloca_at_entry(aggregate_t, 1);
2910                 for (size_t i = 0; i < args.size(); i++) {
2911                     Value *elem_ptr = builder->CreateConstInBoundsGEP2_32(aggregate_t, value, 0, i);
2912                     builder->CreateStore(args[i], elem_ptr);
2913                 }
2914             }
2915         }
2916 
2917     } else if (op->is_intrinsic(Call::stringify)) {
2918         internal_assert(!op->args.empty());
2919 
2920         if (op->type.is_vector()) {
2921             scalarize(op);
2922         } else {
2923 
2924             // Compute the maximum possible size of the message.
2925             int buf_size = 1;  // One for the terminating zero.
2926             for (size_t i = 0; i < op->args.size(); i++) {
2927                 Type t = op->args[i].type();
2928                 if (op->args[i].as<StringImm>()) {
2929                     buf_size += op->args[i].as<StringImm>()->value.size();
2930                 } else if (t.is_int() || t.is_uint()) {
2931                     buf_size += 19;  // 2^64 = 18446744073709551616
2932                 } else if (t.is_float()) {
2933                     if (t.bits() == 32) {
2934                         buf_size += 47;  // %f format of max negative float
2935                     } else {
2936                         buf_size += 14;  // Scientific notation with 6 decimal places.
2937                     }
2938                 } else if (t == type_of<halide_buffer_t *>()) {
2939                     // Not a strict upper bound (there isn't one), but ought to be enough for most buffers.
2940                     buf_size += 512;
2941                 } else {
2942                     internal_assert(t.is_handle());
2943                     buf_size += 18;  // 0x0123456789abcdef
2944                 }
2945             }
2946             // Round up to a multiple of 16 bytes.
2947             buf_size = ((buf_size + 15) / 16) * 16;
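                 // For example, stringify("x = ", some_int32) would reserve
                 // 1 + 4 + 19 = 24 bytes, rounded up to 32.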
2948 
2949             // Clamp to at most 8k.
2950             if (buf_size > 8 * 1024) buf_size = 8 * 1024;
2951 
2952             // Allocate a stack array to hold the message.
2953             llvm::Value *buf = create_alloca_at_entry(i8_t, buf_size);
2954 
2955             llvm::Value *dst = buf;
2956             llvm::Value *buf_end = builder->CreateConstGEP1_32(buf, buf_size);
2957 
2958             llvm::Function *append_string = module->getFunction("halide_string_to_string");
2959             llvm::Function *append_int64 = module->getFunction("halide_int64_to_string");
2960             llvm::Function *append_uint64 = module->getFunction("halide_uint64_to_string");
2961             llvm::Function *append_double = module->getFunction("halide_double_to_string");
2962             llvm::Function *append_pointer = module->getFunction("halide_pointer_to_string");
2963             llvm::Function *append_buffer = module->getFunction("halide_buffer_to_string");
2964 
2965             internal_assert(append_string);
2966             internal_assert(append_int64);
2967             internal_assert(append_uint64);
2968             internal_assert(append_double);
2969             internal_assert(append_pointer);
2970             internal_assert(append_buffer);
2971 
2972             for (size_t i = 0; i < op->args.size(); i++) {
2973                 const StringImm *s = op->args[i].as<StringImm>();
2974                 Type t = op->args[i].type();
2975                 internal_assert(t.lanes() == 1);
2976                 vector<Value *> call_args(2);
2977                 call_args[0] = dst;
2978                 call_args[1] = buf_end;
2979 
2980                 if (s) {
2981                     call_args.push_back(codegen(op->args[i]));
2982                     dst = builder->CreateCall(append_string, call_args);
2983                 } else if (t.is_bool()) {
2984                     Value *a = codegen(op->args[i]);
2985                     Value *t = codegen(StringImm::make("true"));
2986                     Value *f = codegen(StringImm::make("false"));
2987                     call_args.push_back(builder->CreateSelect(a, t, f));
2988                     dst = builder->CreateCall(append_string, call_args);
2989                 } else if (t.is_int()) {
2990                     call_args.push_back(codegen(Cast::make(Int(64), op->args[i])));
2991                     call_args.push_back(ConstantInt::get(i32_t, 1));
2992                     dst = builder->CreateCall(append_int64, call_args);
2993                 } else if (t.is_uint()) {
2994                     call_args.push_back(codegen(Cast::make(UInt(64), op->args[i])));
2995                     call_args.push_back(ConstantInt::get(i32_t, 1));
2996                     dst = builder->CreateCall(append_uint64, call_args);
2997                 } else if (t.is_float()) {
2998                     call_args.push_back(codegen(Cast::make(Float(64), op->args[i])));
2999                     // Use scientific notation for doubles
3000                     call_args.push_back(ConstantInt::get(i32_t, t.bits() == 64 ? 1 : 0));
3001                     dst = builder->CreateCall(append_double, call_args);
3002                 } else if (t == type_of<halide_buffer_t *>()) {
3003                     Value *buf = codegen(op->args[i]);
3004                     buf = builder->CreatePointerCast(buf, append_buffer->getFunctionType()->getParamType(2));
3005                     call_args.push_back(buf);
3006                     dst = builder->CreateCall(append_buffer, call_args);
3007                 } else {
3008                     internal_assert(t.is_handle());
3009                     call_args.push_back(codegen(op->args[i]));
3010                     dst = builder->CreateCall(append_pointer, call_args);
3011                 }
3012             }
3013             if (get_target().has_feature(Target::MSAN)) {
3014                 // Note that we mark the entire buffer as initialized;
3015                 // it would be more accurate to just mark (dst - buf)
3016                 llvm::Function *annotate = module->getFunction("halide_msan_annotate_memory_is_initialized");
3017                 vector<Value *> annotate_args(3);
3018                 annotate_args[0] = get_user_context();
3019                 annotate_args[1] = buf;
3020                 annotate_args[2] = codegen(Cast::make(Int(64), buf_size));
3021                 builder->CreateCall(annotate, annotate_args);
3022             }
3023             value = buf;
3024         }
3025     } else if (op->is_intrinsic(Call::memoize_expr)) {
3026         // Used as an annotation for caching, should be invisible to
3027         // codegen. Ignore arguments beyond the first as they are only
3028         // used in the cache key.
3029         internal_assert(!op->args.empty());
3030         value = codegen(op->args[0]);
3031     } else if (op->is_intrinsic(Call::alloca)) {
3032         // The argument is the number of bytes. For now it must be
3033         // const, or a call to size_of_halide_buffer_t.
3034         internal_assert(op->args.size() == 1);
3035 
3036         // We can generate slightly cleaner IR with fewer alignment
3037         // restrictions if we recognize the most common types we
3038         // expect to get alloca'd.
3039         const Call *call = op->args[0].as<Call>();
3040         if (op->type == type_of<struct halide_buffer_t *>() &&
3041             call && call->is_intrinsic(Call::size_of_halide_buffer_t)) {
3042             value = create_alloca_at_entry(halide_buffer_t_type, 1);
3043         } else {
3044             const int64_t *sz = as_const_int(op->args[0]);
3045             internal_assert(sz);
3046             if (op->type == type_of<struct halide_dimension_t *>()) {
3047                 value = create_alloca_at_entry(dimension_t_type, *sz / sizeof(halide_dimension_t));
3048             } else {
3049                 // Just use an i8* and make the users bitcast it.
3050                 value = create_alloca_at_entry(i8_t, *sz);
3051             }
3052         }
3053     } else if (op->is_intrinsic(Call::register_destructor)) {
3054         internal_assert(op->args.size() == 2);
3055         const StringImm *fn = op->args[0].as<StringImm>();
3056         internal_assert(fn);
3057         llvm::Function *f = module->getFunction(fn->value);
3058         if (!f) {
3059             llvm::Type *arg_types[] = {i8_t->getPointerTo(), i8_t->getPointerTo()};
3060             FunctionType *func_t = FunctionType::get(void_t, arg_types, false);
3061             f = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, fn->value, module.get());
3062             f->setCallingConv(CallingConv::C);
3063         }
3064         internal_assert(op->args[1].type().is_handle());
3065         Value *arg = codegen(op->args[1]);
3066         value = register_destructor(f, arg, Always);
3067     } else if (op->is_intrinsic(Call::call_cached_indirect_function)) {
3068         // Arguments to call_cached_indirect_function are of the form
3069         //
3070         //    cond_1, "sub_function_name_1",
3071         //    cond_2, "sub_function_name_2",
3072         //    ...
3073         //    cond_N, "sub_function_name_N"
3074         //
3075         // This will generate code that corresponds (roughly) to
3076         //
3077         //    static FunctionPtr f = []{
3078         //      if (cond_1) return sub_function_name_1;
3079         //      if (cond_2) return sub_function_name_2;
3080         //      ...
3081         //      if (cond_N) return sub_function_name_N;
3082         //    }();
3083         //    return f(args)
3084         //
3085         // i.e.: the conditions will be evaluated *in order*; the first one
3086         // evaluating to true will have its corresponding function cached,
3087         // which will be used to complete this (and all subsequent) calls.
3088         //
3089         // The final condition (cond_N) must evaluate to a constant TRUE
3090         // value (so that the final function will be selected if all others
3091         // fail); failure to do so will cause unpredictable results.
3092         //
3093         // There is currently no way to clear the cached function pointer.
3094         //
3095         // It is assumed/required that all of the conditions are "pure"; each
3096         // must evaluate to the same value (within a given runtime environment)
3097         // across multiple evaluations.
3098         //
3099         // It is assumed/required that all of the sub-functions have arguments
3100         // (and return values) that are identical to those of this->function.
3101         //
3102         // Note that we require >= 4 arguments: fewer would imply
3103         // only one condition+function pair, which is pointless to use
3104         // (the function should always be called directly).
3105         //
3106         internal_assert(op->args.size() >= 4);
3107         internal_assert(!(op->args.size() & 1));
3108 
3109         // Gather information we need about each function.
3110         struct SubFn {
3111             llvm::Function *fn;
3112             llvm::GlobalValue *fn_ptr;
3113             Expr cond;
3114         };
3115         vector<SubFn> sub_fns;
3116         for (size_t i = 0; i < op->args.size(); i += 2) {
3117             const string sub_fn_name = op->args[i + 1].as<StringImm>()->value;
3118             string extern_sub_fn_name = sub_fn_name;
3119             llvm::Function *sub_fn = module->getFunction(sub_fn_name);
3120             if (!sub_fn) {
3121                 extern_sub_fn_name = get_mangled_names(sub_fn_name,
3122                                                        LinkageType::External,
3123                                                        NameMangling::Default,
3124                                                        current_function_args,
3125                                                        get_target())
3126                                          .extern_name;
3127                 debug(1) << "Did not find function " << sub_fn_name
3128                          << ", assuming extern \"C\" " << extern_sub_fn_name << "\n";
3129                 vector<llvm::Type *> arg_types;
3130                 for (const auto &arg : function->args()) {
3131                     arg_types.push_back(arg.getType());
3132                 }
3133                 llvm::Type *result_type = llvm_type_of(upgrade_type_for_argument_passing(op->type));
3134                 FunctionType *func_t = FunctionType::get(result_type, arg_types, false);
3135                 sub_fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage,
3136                                                 extern_sub_fn_name, module.get());
3137                 sub_fn->setCallingConv(CallingConv::C);
3138             }
3139 
3140             llvm::GlobalValue *sub_fn_ptr = module->getNamedValue(extern_sub_fn_name);
3141             if (!sub_fn_ptr) {
3142                 debug(1) << "Did not find function ptr " << extern_sub_fn_name << ", assuming extern \"C\".\n";
3143                 sub_fn_ptr = new GlobalVariable(*module, sub_fn->getType(),
3144                                                 /*isConstant*/ true, GlobalValue::ExternalLinkage,
3145                                                 /*initializer*/ nullptr, extern_sub_fn_name);
3146             }
3147             auto cond = op->args[i];
3148             sub_fns.push_back({sub_fn, sub_fn_ptr, cond});
3149         }
3150 
3151         // Create a null-initialized global to track this object.
3152         const auto base_fn = sub_fns.back().fn;
3153         const string global_name = unique_name(base_fn->getName().str() + "_indirect_fn_ptr");
3154         GlobalVariable *global = new GlobalVariable(
3155             *module,
3156             base_fn->getType(),
3157             /*isConstant*/ false,
3158             GlobalValue::PrivateLinkage,
3159             ConstantPointerNull::get(base_fn->getType()),
3160             global_name);
3161         LoadInst *loaded_value = builder->CreateLoad(global);
3162 
3163         BasicBlock *global_inited_bb = BasicBlock::Create(*context, "global_inited_bb", function);
3164         BasicBlock *global_not_inited_bb = BasicBlock::Create(*context, "global_not_inited_bb", function);
3165         BasicBlock *call_fn_bb = BasicBlock::Create(*context, "call_fn_bb", function);
3166 
3167         // Only init the global if not already inited.
3168         //
3169         // Note that we deliberately do not attempt to make this threadsafe via (e.g.) mutexes;
3170         // the requirements of the conditions above mean that multiple writes *should* only
3171         // be able to re-write the same value, which is harmless for our purposes, and
3172         // avoiding such code simplifies and speeds the resulting code.
3173         //
3174         // (Note that if we ever need to add a way to clear the cached function pointer,
3175         // we may need to reconsider this, to avoid amusingly horrible race conditions.)
3176         builder->CreateCondBr(builder->CreateIsNotNull(loaded_value),
3177                               global_inited_bb, global_not_inited_bb, very_likely_branch);
3178 
3179         // Build the not-already-inited case
3180         builder->SetInsertPoint(global_not_inited_bb);
3181         llvm::Value *selected_value = nullptr;
3182         for (int i = sub_fns.size() - 1; i >= 0; i--) {
3183             const auto sub_fn = sub_fns[i];
3184             if (!selected_value) {
3185                 selected_value = sub_fn.fn_ptr;
3186             } else {
3187                 Value *c = codegen(sub_fn.cond);
3188                 selected_value = builder->CreateSelect(c, sub_fn.fn_ptr, selected_value);
3189             }
3190         }
3191         builder->CreateStore(selected_value, global);
3192         builder->CreateBr(call_fn_bb);
3193 
3194         // Just an incoming edge for the Phi node
3195         builder->SetInsertPoint(global_inited_bb);
3196         builder->CreateBr(call_fn_bb);
3197 
3198         builder->SetInsertPoint(call_fn_bb);
3199         PHINode *phi = builder->CreatePHI(selected_value->getType(), 2);
3200         phi->addIncoming(selected_value, global_not_inited_bb);
3201         phi->addIncoming(loaded_value, global_inited_bb);
3202 
3203         std::vector<llvm::Value *> call_args;
3204         for (auto &arg : function->args()) {
3205             call_args.push_back(&arg);
3206         }
3207 
3208         llvm::CallInst *call = builder->CreateCall(base_fn->getFunctionType(), phi, call_args);
3209         value = call;
3210     } else if (op->is_intrinsic(Call::prefetch)) {
3211         user_assert((op->args.size() == 4) && is_one(op->args[2]))
3212             << "Only prefetch of 1 cache line is supported.\n";
3213 
3214         llvm::Function *prefetch_fn = module->getFunction("_halide_prefetch");
3215         internal_assert(prefetch_fn);
3216 
3217         vector<llvm::Value *> args;
3218         args.push_back(codegen_buffer_pointer(codegen(op->args[0]), op->type, op->args[1]));
3219         // The first argument is a pointer, which has type i8*. We
3220         // need to cast the argument, which might be a pointer to a
3221         // different type.
3222         llvm::Type *ptr_type = prefetch_fn->getFunctionType()->params()[0];
3223         args[0] = builder->CreateBitCast(args[0], ptr_type);
3224 
3225         value = builder->CreateCall(prefetch_fn, args);
3226 
3227     } else if (op->is_intrinsic(Call::signed_integer_overflow)) {
3228         user_error << "Signed integer overflow occurred during constant-folding. Signed"
3229                       " integer overflow for int32 and int64 is undefined behavior in"
3230                       " Halide.\n";
3231     } else if (op->is_intrinsic(Call::undef)) {
3232         value = UndefValue::get(llvm_type_of(op->type));
3233     } else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) {
3234         llvm::DataLayout d(module.get());
3235         value = ConstantInt::get(i32_t, (int)d.getTypeAllocSize(halide_buffer_t_type));
3236     } else if (op->is_intrinsic(Call::strict_float)) {
3237         IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
3238         llvm::FastMathFlags safe_flags;
3239         safe_flags.clear();
3240         builder->setFastMathFlags(safe_flags);
3241         builder->setDefaultFPMathTag(strict_fp_math_md);
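             // FastMathFlagGuard is RAII: the saved fast-math flags (and default FP math
             // tag) are restored when `guard` goes out of scope, so strict evaluation is
             // limited to this one subexpression.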
3242         value = codegen(op->args[0]);
3243     } else if (is_float16_transcendental(op)) {
3244         value = codegen(lower_float16_transcendental_to_float32_equivalent(op));
3245     } else if (op->is_intrinsic()) {
3246         internal_error << "Unknown intrinsic: " << op->name << "\n";
3247     } else if (op->call_type == Call::PureExtern && op->name == "pow_f32") {
3248         internal_assert(op->args.size() == 2);
3249         Expr x = op->args[0];
3250         Expr y = op->args[1];
3251         Halide::Expr abs_x_pow_y = Internal::halide_exp(Internal::halide_log(abs(x)) * y);
3252         Halide::Expr nan_expr = Call::make(x.type(), "nan_f32", {}, Call::PureExtern);
3253         Expr iy = floor(y);
3254         Expr one = make_one(x.type());
3255         Expr zero = make_zero(x.type());
3256         Expr e = select(x > 0, abs_x_pow_y,        // Strictly positive x
3257                         y == 0.0f, one,            // x^0 == 1
3258                         x == 0.0f, zero,           // 0^y == 0
3259                         y != iy, nan_expr,         // negative x to a non-integer power
3260                         iy % 2 == 0, abs_x_pow_y,  // negative x to an even power
3261                         -abs_x_pow_y);             // negative x to an odd power
3262         e = common_subexpression_elimination(e);
3263         e.accept(this);
3264     } else if (op->call_type == Call::PureExtern && op->name == "log_f32") {
3265         internal_assert(op->args.size() == 1);
3266         Expr e = Internal::halide_log(op->args[0]);
3267         e.accept(this);
3268     } else if (op->call_type == Call::PureExtern && op->name == "exp_f32") {
3269         internal_assert(op->args.size() == 1);
3270         Expr e = Internal::halide_exp(op->args[0]);
3271         e.accept(this);
3272     } else if (op->call_type == Call::PureExtern &&
3273                (op->name == "is_nan_f32" || op->name == "is_nan_f64")) {
3274         internal_assert(op->args.size() == 1);
3275         Value *a = codegen(op->args[0]);
3276 
3277         /* NaNs are not supposed to exist in "no NaNs" compilation
3278          * mode, but it appears llvm special cases the unordered
3279          * compare instruction when the global NoNaNsFPMath option is
3280          * set and still checks for a NaN. However if the nnan flag is
3281          * set on the instruction itself, llvm treats the comparison
3282          * as always false. Thus we always turn off the per-instruction
3283          * fast-math flags for this instruction. I.e. it is always
3284          * treated as strict. Note that compilation may still be in
3285          * fast-math mode due to global options, but that's ok due to
3286          * the aforementioned special casing. */
3287         IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
3288         llvm::FastMathFlags safe_flags;
3289         safe_flags.clear();
3290         builder->setFastMathFlags(safe_flags);
3291         builder->setDefaultFPMathTag(strict_fp_math_md);
3292 
3293         value = builder->CreateFCmpUNO(a, a);
3294     } else if (op->call_type == Call::PureExtern &&
3295                (op->name == "is_inf_f32" || op->name == "is_inf_f64")) {
3296         internal_assert(op->args.size() == 1);
3297 
3298         IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
3299         llvm::FastMathFlags safe_flags;
3300         safe_flags.clear();
3301         builder->setFastMathFlags(safe_flags);
3302         builder->setDefaultFPMathTag(strict_fp_math_md);
3303 
3304         // isinf(e) -> (fabs(e) == infinity)
3305         Expr e = op->args[0];
3306         internal_assert(e.type().is_float());
3307         Expr inf = e.type().max();
3308         codegen(abs(e) == inf);
3309     } else if (op->call_type == Call::PureExtern &&
3310                (op->name == "is_finite_f32" || op->name == "is_finite_f64")) {
3311         internal_assert(op->args.size() == 1);
3312         internal_assert(op->args[0].type().is_float());
3313 
3314         IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
3315         llvm::FastMathFlags safe_flags;
3316         safe_flags.clear();
3317         builder->setFastMathFlags(safe_flags);
3318         builder->setDefaultFPMathTag(strict_fp_math_md);
3319 
3320         // isfinite(e) -> (fabs(e) != infinity && !isnan(e)) -> (fabs(e) != infinity && e == e)
3321         Expr e = op->args[0];
3322         internal_assert(e.type().is_float());
3323         Expr inf = e.type().max();
3324         codegen(abs(e) != inf && e == e);
3325     } else {
3326         // It's an extern call.
3327 
3328         std::string name;
3329         if (op->call_type == Call::ExternCPlusPlus) {
3330             user_assert(get_target().has_feature(Target::CPlusPlusMangling)) << "Target must specify C++ name mangling (\"c_plus_plus_name_mangling\") in order to call C++ externs. (" << op->name << ")\n";
3331 
3332             std::vector<std::string> namespaces;
3333             name = extract_namespaces(op->name, namespaces);
3334             std::vector<ExternFuncArgument> mangle_args;
3335             for (const auto &arg : op->args) {
3336                 mangle_args.emplace_back(arg);
3337             }
3338             name = cplusplus_function_mangled_name(name, namespaces, op->type, mangle_args, get_target());
3339         } else {
3340             name = op->name;
3341         }
3342 
3343         // Codegen the args
3344         vector<Value *> args(op->args.size());
3345         for (size_t i = 0; i < op->args.size(); i++) {
3346             args[i] = codegen(op->args[i]);
3347         }
3348 
3349         llvm::Function *fn = module->getFunction(name);
3350 
3351         llvm::Type *result_type = llvm_type_of(upgrade_type_for_argument_passing(op->type));
3352 
3353         // Add a user context arg as needed. It's never a vector.
3354         bool takes_user_context = function_takes_user_context(op->name);
3355         if (takes_user_context) {
3356             internal_assert(fn) << "External function " << op->name << " is marked as taking user_context, but is not in the runtime module. Check if runtime_api.cpp needs to be rebuilt.\n";
3357             debug(4) << "Adding user_context to " << op->name << " args\n";
3358             args.insert(args.begin(), get_user_context());
3359         }
3360 
3361         // If we can't find it, declare it extern "C"
3362         if (!fn) {
3363             vector<llvm::Type *> arg_types(args.size());
3364             for (size_t i = 0; i < args.size(); i++) {
3365                 arg_types[i] = args[i]->getType();
3366                 if (arg_types[i]->isVectorTy()) {
3367                     VectorType *vt = dyn_cast<VectorType>(arg_types[i]);
3368                     arg_types[i] = vt->getElementType();
3369                 }
3370             }
3371 
3372             llvm::Type *scalar_result_type = result_type;
3373             if (result_type->isVectorTy()) {
3374                 VectorType *vt = dyn_cast<VectorType>(result_type);
3375                 scalar_result_type = vt->getElementType();
3376             }
3377 
3378             FunctionType *func_t = FunctionType::get(scalar_result_type, arg_types, false);
3379 
3380             fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());
3381             fn->setCallingConv(CallingConv::C);
3382             debug(4) << "Did not find " << op->name << ". Declared it extern \"C\".\n";
3383         } else {
3384             debug(4) << "Found " << op->name << "\n";
3385 
3386             // TODO: Say something more accurate here as there is now
3387             // partial information in the handle_type field, but it is
3388             // not clear it can be matched to the LLVM types and it is
3389             // not always there.
3390             // Halide's type system doesn't preserve pointer types
3391             // correctly (they just get called "Handle()"), so we may
3392             // need to pointer cast to the appropriate type. Only look at
3393             // fixed params (not varargs) in the llvm function.
3394             FunctionType *func_t = fn->getFunctionType();
3395             for (size_t i = takes_user_context ? 1 : 0;
3396                  i < std::min(args.size(), (size_t)(func_t->getNumParams()));
3397                  i++) {
3398                 Expr halide_arg = takes_user_context ? op->args[i - 1] : op->args[i];
3399                 if (halide_arg.type().is_handle()) {
3400                     llvm::Type *t = func_t->getParamType(i);
3401 
3402                     // Widen to vector-width as needed. If the
3403                     // function doesn't actually take a vector,
3404                     // individual lanes will be extracted below.
3405                     if (halide_arg.type().is_vector() &&
3406                         !t->isVectorTy()) {
3407                         t = get_vector_type(t, halide_arg.type().lanes());
3408                     }
3409 
3410                     if (t != args[i]->getType()) {
3411                         debug(4) << "Pointer casting argument to extern call: "
3412                                  << halide_arg << "\n";
3413                         args[i] = builder->CreatePointerCast(args[i], t);
3414                     }
3415                 }
3416             }
3417         }
3418 
3419         if (op->type.is_scalar()) {
3420             CallInst *call = builder->CreateCall(fn, args);
3421             if (op->is_pure()) {
3422                 call->setDoesNotAccessMemory();
3423             }
3424             call->setDoesNotThrow();
3425             value = call;
3426         } else {
3427 
3428             // Check if a vector version of the function already
3429             // exists at some useful width.
3430             pair<llvm::Function *, int> vec =
3431                 find_vector_runtime_function(name, op->type.lanes());
3432             llvm::Function *vec_fn = vec.first;
3433             int w = vec.second;
3434 
3435             if (vec_fn) {
3436                 value = call_intrin(llvm_type_of(op->type), w,
3437                                     get_llvm_function_name(vec_fn), args);
3438             } else {
3439 
3440                 // No vector version found. Scalarize. Extract each simd
3441                 // lane in turn and do one scalar call to the function.
3442                 value = UndefValue::get(result_type);
3443                 for (int i = 0; i < op->type.lanes(); i++) {
3444                     Value *idx = ConstantInt::get(i32_t, i);
3445                     vector<Value *> arg_lane(args.size());
3446                     for (size_t j = 0; j < args.size(); j++) {
3447                         if (args[j]->getType()->isVectorTy()) {
3448                             arg_lane[j] = builder->CreateExtractElement(args[j], idx);
3449                         } else {
3450                             arg_lane[j] = args[j];
3451                         }
3452                     }
3453                     CallInst *call = builder->CreateCall(fn, arg_lane);
3454                     if (op->is_pure()) {
3455                         call->setDoesNotAccessMemory();
3456                     }
3457                     call->setDoesNotThrow();
3458                     if (!call->getType()->isVoidTy()) {
3459                         value = builder->CreateInsertElement(value, call, idx);
3460                     }  // otherwise leave it as undef.
3461                 }
3462             }
3463         }
3464     }
3465 }
3466 
3467 void CodeGen_LLVM::visit(const Prefetch *op) {
3468     internal_error << "Prefetch encountered during codegen\n";
3469 }
3470 
3471 void CodeGen_LLVM::visit(const Let *op) {
3472     sym_push(op->name, codegen(op->value));
3473     value = codegen(op->body);
3474     sym_pop(op->name);
3475 }
3476 
3477 void CodeGen_LLVM::visit(const LetStmt *op) {
3478     sym_push(op->name, codegen(op->value));
3479     codegen(op->body);
3480     sym_pop(op->name);
3481 }
3482 
3483 void CodeGen_LLVM::visit(const AssertStmt *op) {
3484     create_assertion(codegen(op->condition), op->message);
3485 }
3486 
3487 Constant *CodeGen_LLVM::create_string_constant(const string &s) {
3488     map<string, Constant *>::iterator iter = string_constants.find(s);
3489     if (iter == string_constants.end()) {
3490         vector<char> data;
3491         data.reserve(s.size() + 1);
3492         data.insert(data.end(), s.begin(), s.end());
3493         data.push_back(0);
3494         Constant *val = create_binary_blob(data, "str");
3495         string_constants[s] = val;
3496         return val;
3497     } else {
3498         return iter->second;
3499     }
3500 }
3501 
3502 Constant *CodeGen_LLVM::create_binary_blob(const vector<char> &data, const string &name, bool constant) {
3503     internal_assert(!data.empty());
3504     llvm::Type *type = ArrayType::get(i8_t, data.size());
3505     GlobalVariable *global = new GlobalVariable(*module, type,
3506                                                 constant, GlobalValue::PrivateLinkage,
3507                                                 0, name);
3508     ArrayRef<unsigned char> data_array((const unsigned char *)&data[0], data.size());
3509     global->setInitializer(ConstantDataArray::get(*context, data_array));
3510     size_t alignment = 32;
3511     size_t native_vector_bytes = (size_t)(native_vector_bits() / 8);
3512     if (data.size() > alignment && native_vector_bytes > alignment) {
3513         alignment = native_vector_bytes;
3514     }
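         // For example, with 512-bit native vectors (64 bytes), a blob larger than 32
         // bytes gets 64-byte alignment; otherwise the default 32-byte alignment is kept.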
3515     global->setAlignment(make_alignment(alignment));
3516 
3517     Constant *zero = ConstantInt::get(i32_t, 0);
3518     Constant *zeros[] = {zero, zero};
3519     Constant *ptr = ConstantExpr::getInBoundsGetElementPtr(type, global, zeros);
3520     return ptr;
3521 }
3522 
3523 void CodeGen_LLVM::create_assertion(Value *cond, const Expr &message, llvm::Value *error_code) {
3524 
3525     internal_assert(!message.defined() || message.type() == Int(32))
3526         << "Assertion result is not an int: " << message;
3527 
3528     if (target.has_feature(Target::NoAsserts)) return;
3529 
3530     // If the condition is a vector, fold it down to a scalar
3531     VectorType *vt = dyn_cast<VectorType>(cond->getType());
3532     if (vt) {
3533         Value *scalar_cond = builder->CreateExtractElement(cond, ConstantInt::get(i32_t, 0));
3534         for (int i = 1; i < get_vector_num_elements(vt); i++) {
3535             Value *lane = builder->CreateExtractElement(cond, ConstantInt::get(i32_t, i));
3536             scalar_cond = builder->CreateAnd(scalar_cond, lane);
3537         }
3538         cond = scalar_cond;
3539     }
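         // For example, a 4-lane condition <c0, c1, c2, c3> folds to c0 & c1 & c2 & c3,
         // so the assertion only passes when every lane passes.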
3540 
3541     // Make a new basic block for the assert
3542     BasicBlock *assert_fails_bb = BasicBlock::Create(*context, "assert failed", function);
3543     BasicBlock *assert_succeeds_bb = BasicBlock::Create(*context, "assert succeeded", function);
3544 
3545     // If the condition fails, enter the assert body; otherwise, enter the block after.
3546     builder->CreateCondBr(cond, assert_succeeds_bb, assert_fails_bb, very_likely_branch);
3547 
3548     // Build the failure case
3549     builder->SetInsertPoint(assert_fails_bb);
3550 
3551     // Call the error handler
3552     if (!error_code) error_code = codegen(message);
3553 
3554     return_with_error_code(error_code);
3555 
3556     // Continue on using the success case
3557     builder->SetInsertPoint(assert_succeeds_bb);
3558 }
3559 
3560 void CodeGen_LLVM::return_with_error_code(llvm::Value *error_code) {
3561     // Branch to the destructor block, which cleans up and then bails out.
3562     BasicBlock *dtors = get_destructor_block();
3563 
3564     // Hook up our error code to the phi node that the destructor block starts with.
3565     PHINode *phi = dyn_cast<PHINode>(dtors->begin());
3566     internal_assert(phi) << "The destructor block is supposed to start with a phi node\n";
3567     phi->addIncoming(error_code, builder->GetInsertBlock());
3568 
3569     builder->CreateBr(get_destructor_block());
3570 }
3571 
3572 void CodeGen_LLVM::visit(const ProducerConsumer *op) {
3573     string name;
3574     if (op->is_producer) {
3575         name = std::string("produce ") + op->name;
3576     } else {
3577         name = std::string("consume ") + op->name;
3578     }
3579     BasicBlock *produce = BasicBlock::Create(*context, name, function);
3580     builder->CreateBr(produce);
3581     builder->SetInsertPoint(produce);
3582     codegen(op->body);
3583 }
3584 
3585 void CodeGen_LLVM::visit(const For *op) {
3586     Value *min = codegen(op->min);
3587     Value *extent = codegen(op->extent);
3588     const Acquire *acquire = op->body.as<Acquire>();
3589 
3590     if (op->for_type == ForType::Parallel ||
3591         (op->for_type == ForType::Serial &&
3592          acquire &&
3593          !expr_uses_var(acquire->count, op->name))) {
3594         do_as_parallel_task(op);
3595     } else if (op->for_type == ForType::Serial) {
3596 
3597         Value *max = builder->CreateNSWAdd(min, extent);
3598 
3599         BasicBlock *preheader_bb = builder->GetInsertBlock();
3600 
3601         // Make a new basic block for the loop
3602         BasicBlock *loop_bb = BasicBlock::Create(*context, std::string("for ") + op->name, function);
3603         // Create the block that comes after the loop
3604         BasicBlock *after_bb = BasicBlock::Create(*context, std::string("end for ") + op->name, function);
3605 
3606         // If min < max, fall through to the loop bb
3607         Value *enter_condition = builder->CreateICmpSLT(min, max);
3608         builder->CreateCondBr(enter_condition, loop_bb, after_bb, very_likely_branch);
3609         builder->SetInsertPoint(loop_bb);
3610 
3611         // Make our phi node.
3612         PHINode *phi = builder->CreatePHI(i32_t, 2);
3613         phi->addIncoming(min, preheader_bb);
3614 
3615         // Within the loop, the variable is equal to the phi value
3616         sym_push(op->name, phi);
3617 
3618         // Emit the loop body
3619         codegen(op->body);
3620 
3621         // Update the counter
3622         Value *next_var = builder->CreateNSWAdd(phi, ConstantInt::get(i32_t, 1));
3623 
3624         // Add the back-edge to the phi node
3625         phi->addIncoming(next_var, builder->GetInsertBlock());
3626 
3627         // Maybe exit the loop
3628         Value *end_condition = builder->CreateICmpNE(next_var, max);
3629         builder->CreateCondBr(end_condition, loop_bb, after_bb);
3630 
3631         builder->SetInsertPoint(after_bb);
3632 
3633         // Pop the loop variable from the scope
3634         sym_pop(op->name);
3635     } else {
3636         internal_error << "Unknown type of For node. Only Serial and Parallel For nodes should survive down to codegen.\n";
3637     }
3638 }
3639 
3640 void CodeGen_LLVM::do_parallel_tasks(const vector<ParallelTask> &tasks) {
3641     Closure closure;
3642     for (const auto &t : tasks) {
3643         Stmt s = t.body;
3644         if (!t.loop_var.empty()) {
3645             s = LetStmt::make(t.loop_var, 0, s);
3646         }
3647         s.accept(&closure);
3648     }
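         // The closure captures every symbol referenced by any of the task bodies,
         // so each generated task function below can reload what it needs from it.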
3649 
3650     // Allocate a closure
3651     StructType *closure_t = build_closure_type(closure, halide_buffer_t_type, context);
3652     Value *closure_ptr = create_alloca_at_entry(closure_t, 1);
3653 
3654     // Fill in the closure
3655     pack_closure(closure_t, closure_ptr, closure, symbol_table, halide_buffer_t_type, builder);
3656 
3657     closure_ptr = builder->CreatePointerCast(closure_ptr, i8_t->getPointerTo());
3658 
3659     int num_tasks = (int)tasks.size();
3660 
3661     // Make space on the stack for the tasks
3662     llvm::Value *task_stack_ptr = create_alloca_at_entry(parallel_task_t_type, num_tasks);
3663 
3664     llvm::Type *args_t[] = {i8_t->getPointerTo(), i32_t, i8_t->getPointerTo()};
3665     FunctionType *task_t = FunctionType::get(i32_t, args_t, false);
3666     llvm::Type *loop_args_t[] = {i8_t->getPointerTo(), i32_t, i32_t, i8_t->getPointerTo(), i8_t->getPointerTo()};
3667     FunctionType *loop_task_t = FunctionType::get(i32_t, loop_args_t, false);
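         // These match the two runtime entry points used below: task_t is the
         // callback signature expected by halide_do_par_for (user context, loop
         // index, closure), and loop_task_t is the one expected by
         // halide_do_parallel_tasks (user context, min, extent, closure, task parent).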
3668 
3669     Value *result = nullptr;
3670 
3671     for (int i = 0; i < num_tasks; i++) {
3672         ParallelTask t = tasks[i];
3673 
3674         // Analyze the task body
3675         class MayBlock : public IRVisitor {
3676             using IRVisitor::visit;
3677             void visit(const Acquire *op) override {
3678                 result = true;
3679             }
3680 
3681         public:
3682             bool result = false;
3683         };
3684 
3685         // TODO(zvookin|abadams): This makes multiple passes over the
3686         // IR to cover each node. (One tree walk produces the min
3687         // thread count for all nodes, but we redo each subtree when
3688         // compiling a given node.) Ideally we'd move to a lowering pass
3689         // that converts our parallelism constructs to Call nodes, or
3690         // direct hardware operations in some cases.
3691         // Also, this code has to exactly mirror the logic in get_parallel_tasks.
3692         // It would be better to do one pass on the tree and centralize the task
3693         // deduction logic in one place.
3694         class MinThreads : public IRVisitor {
3695             using IRVisitor::visit;
3696 
3697             std::pair<Stmt, int> skip_acquires(Stmt first) {
3698                 int count = 0;
3699                 while (first.defined()) {
3700                     const Acquire *acq = first.as<Acquire>();
3701                     if (acq == nullptr) {
3702                         break;
3703                     }
3704                     count++;
3705                     first = acq->body;
3706                 }
3707                 return {first, count};
3708             }
3709 
3710             void visit(const Fork *op) override {
3711                 int total_threads = 0;
3712                 int direct_acquires = 0;
3713                 // Take the sum of min threads across all
3714                 // cascaded Fork nodes.
3715                 const Fork *node = op;
3716                 while (node != nullptr) {
3717                     result = 0;
3718                     auto after_acquires = skip_acquires(node->first);
3719                     direct_acquires += after_acquires.second;
3720 
3721                     after_acquires.first.accept(this);
3722                     total_threads += result;
3723 
3724                     const Fork *continued_branches = node->rest.as<Fork>();
3725                     if (continued_branches == nullptr) {
3726                         result = 0;
3727                         after_acquires = skip_acquires(node->rest);
3728                         direct_acquires += after_acquires.second;
3729                         after_acquires.first.accept(this);
3730                         total_threads += result;
3731                     }
3732                     node = continued_branches;
3733                 }
3734                 if (direct_acquires == 0 && total_threads == 0) {
3735                     result = 0;
3736                 } else {
3737                     result = total_threads + 1;
3738                 }
3739             }
3740 
3741             void visit(const For *op) override {
3742                 result = 0;
3743 
3744                 if (op->for_type == ForType::Parallel) {
3745                     IRVisitor::visit(op);
3746                     if (result > 0) {
3747                         result += 1;
3748                     }
3749                 } else if (op->for_type == ForType::Serial) {
3750                     auto after_acquires = skip_acquires(op->body);
3751                     if (after_acquires.second > 0 &&
3752                         !expr_uses_var(op->body.as<Acquire>()->count, op->name)) {
3753                         after_acquires.first.accept(this);
3754                         result++;
3755                     } else {
3756                         IRVisitor::visit(op);
3757                     }
3758                 } else {
3759                     IRVisitor::visit(op);
3760                 }
3761             }
3762 
3763             // This is a "standalone" Acquire and will result in its own task.
3764             // Treat it as requiring one more thread than its body.
3765             void visit(const Acquire *op) override {
3766                 result = 0;
3767                 auto after_inner_acquires = skip_acquires(op);
3768                 after_inner_acquires.first.accept(this);
3769                 result = result + 1;
3770             }
3771 
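                 // Sequential composition: the two halves of a Block never run
                 // concurrently, so take the max of their thread requirements
                 // rather than the sum (contrast with Fork above).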
3772             void visit(const Block *op) override {
3773                 result = 0;
3774                 op->first.accept(this);
3775                 int result_first = result;
3776                 result = 0;
3777                 op->rest.accept(this);
3778                 result = std::max(result, result_first);
3779             }
3780 
3781         public:
3782             int result = 0;
3783         };
3784         MinThreads min_threads;
3785         t.body.accept(&min_threads);
3786 
3787         // Decide if we're going to call do_par_for or
3788         // do_parallel_tasks. halide_do_par_for is simpler, but
3789         // assumes a bunch of things. Programs that don't use async
3790         // can also enter the task system via do_par_for.
3791         Value *task_parent = sym_get("__task_parent", false);
3792         bool use_do_par_for = (num_tasks == 1 &&
3793                                min_threads.result == 0 &&
3794                                t.semaphores.empty() &&
3795                                !task_parent);
3796 
3797         // Make the array of semaphore acquisitions this task needs to do before it runs.
3798         Value *semaphores;
3799         Value *num_semaphores = ConstantInt::get(i32_t, (int)t.semaphores.size());
3800         if (!t.semaphores.empty()) {
3801             semaphores = create_alloca_at_entry(semaphore_acquire_t_type, (int)t.semaphores.size());
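                 // Each slot is a {semaphore pointer, count} pair, matching the
                 // layout of the runtime's halide_semaphore_acquire_t.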
3802             for (int i = 0; i < (int)t.semaphores.size(); i++) {
3803                 Value *semaphore = codegen(t.semaphores[i].semaphore);
3804                 semaphore = builder->CreatePointerCast(semaphore, semaphore_t_type->getPointerTo());
3805                 Value *count = codegen(t.semaphores[i].count);
3806                 Value *slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 0);
3807                 builder->CreateStore(semaphore, slot_ptr);
3808                 slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 1);
3809                 builder->CreateStore(count, slot_ptr);
3810             }
3811         } else {
3812             semaphores = ConstantPointerNull::get(semaphore_acquire_t_type->getPointerTo());
3813         }
3814 
3815         FunctionType *fn_type = use_do_par_for ? task_t : loop_task_t;
3816         int closure_arg_idx = use_do_par_for ? 2 : 3;
3817 
3818         // Make a new function that does the body
3819         llvm::Function *containing_function = function;
3820         function = llvm::Function::Create(fn_type, llvm::Function::InternalLinkage,
3821                                           t.name, module.get());
3822 
3823         llvm::Value *task_ptr = builder->CreatePointerCast(function, fn_type->getPointerTo());
3824 
3825         function->addParamAttr(closure_arg_idx, Attribute::NoAlias);
3826 
3827         set_function_attributes_for_target(function, target);
3828 
3829         // Make the initial basic block and jump the builder into the new function
3830         IRBuilderBase::InsertPoint call_site = builder->saveIP();
3831         BasicBlock *block = BasicBlock::Create(*context, "entry", function);
3832         builder->SetInsertPoint(block);
3833 
3834         // Save the destructor block
3835         BasicBlock *parent_destructor_block = destructor_block;
3836         destructor_block = nullptr;
3837 
3838         // Make a new scope to use
3839         Scope<Value *> saved_symbol_table;
3840         symbol_table.swap(saved_symbol_table);
3841 
3842         // Get the function arguments
3843 
3844         // The user context is first argument of the function; it's
3845         // important that we override the name to be "__user_context",
3846         // since the LLVM function has a random auto-generated name for
3847         // this argument.
3848         llvm::Function::arg_iterator iter = function->arg_begin();
3849         sym_push("__user_context", iterator_to_pointer(iter));
3850 
3851         if (use_do_par_for) {
3852             // Next is the loop variable.
3853             ++iter;
3854             sym_push(t.loop_var, iterator_to_pointer(iter));
3855         } else if (!t.loop_var.empty()) {
3856             // We peeled off a loop. Wrap a new loop around the body
3857             // that just does the slice given by the arguments.
3858             string loop_min_name = unique_name('t');
3859             string loop_extent_name = unique_name('t');
3860             t.body = For::make(t.loop_var,
3861                                Variable::make(Int(32), loop_min_name),
3862                                Variable::make(Int(32), loop_extent_name),
3863                                ForType::Serial,
3864                                DeviceAPI::None,
3865                                t.body);
3866             ++iter;
3867             sym_push(loop_min_name, iterator_to_pointer(iter));
3868             ++iter;
3869             sym_push(loop_extent_name, iterator_to_pointer(iter));
3870         } else {
3871             // This task is not any kind of loop, so skip these args.
3872             ++iter;
3873             ++iter;
3874         }
3875 
3876         // The closure pointer is either the last (for halide_do_par_for) or
3877         // second to last argument (for halide_do_parallel_tasks).
3878         ++iter;
3879         iter->setName("closure");
3880         Value *closure_handle = builder->CreatePointerCast(iterator_to_pointer(iter),
3881                                                            closure_t->getPointerTo());
3882 
3883         // Load everything from the closure into the new scope
3884         unpack_closure(closure, symbol_table, closure_t, closure_handle, builder);
3885 
3886         if (!use_do_par_for) {
3887             // For halide_do_parallel_tasks the threading runtime task parent
3888             // is the last argument.
3889             ++iter;
3890             iter->setName("task_parent");
3891             sym_push("__task_parent", iterator_to_pointer(iter));
3892         }
3893 
3894         // Generate the new function body
3895         codegen(t.body);
3896 
3897         // Return success
3898         return_with_error_code(ConstantInt::get(i32_t, 0));
3899 
3900         // Move the builder back to the main function.
3901         builder->restoreIP(call_site);
3902 
3903         // Now restore the scope
3904         symbol_table.swap(saved_symbol_table);
3905         function = containing_function;
3906 
3907         // Restore the destructor block
3908         destructor_block = parent_destructor_block;
3909 
3910         Value *min = codegen(t.min);
3911         Value *extent = codegen(t.extent);
3912         Value *serial = codegen(cast(UInt(8), t.serial));
3913 
3914         if (use_do_par_for) {
3915             llvm::Function *do_par_for = module->getFunction("halide_do_par_for");
3916             internal_assert(do_par_for) << "Could not find halide_do_par_for in initial module\n";
3917             do_par_for->addParamAttr(4, Attribute::NoAlias);
3918             Value *args[] = {get_user_context(), task_ptr, min, extent, closure_ptr};
3919             debug(4) << "Creating call to do_par_for\n";
3920             result = builder->CreateCall(do_par_for, args);
3921         } else {
3922             // Populate the task struct
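                 // Fields 0..8 of the task struct are: fn, closure, name, semaphore
                 // list, number of semaphores, min, extent, min_threads, and the
                 // serial flag.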
3923             Value *slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 0);
3924             builder->CreateStore(task_ptr, slot_ptr);
3925             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 1);
3926             builder->CreateStore(closure_ptr, slot_ptr);
3927             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 2);
3928             builder->CreateStore(create_string_constant(t.name), slot_ptr);
3929             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 3);
3930             builder->CreateStore(semaphores, slot_ptr);
3931             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 4);
3932             builder->CreateStore(num_semaphores, slot_ptr);
3933             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 5);
3934             builder->CreateStore(min, slot_ptr);
3935             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 6);
3936             builder->CreateStore(extent, slot_ptr);
3937             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 7);
3938             builder->CreateStore(ConstantInt::get(i32_t, min_threads.result), slot_ptr);
3939             slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 8);
3940             builder->CreateStore(serial, slot_ptr);
3941         }
3942     }
3943 
3944     if (!result) {
3945         llvm::Function *do_parallel_tasks = module->getFunction("halide_do_parallel_tasks");
3946         internal_assert(do_parallel_tasks) << "Could not find halide_do_parallel_tasks in initial module\n";
3947         do_parallel_tasks->addParamAttr(2, Attribute::NoAlias);
3948         Value *task_parent = sym_get("__task_parent", false);
3949         if (!task_parent) {
3950             task_parent = ConstantPointerNull::get(i8_t->getPointerTo());  // void*
3951         }
3952         Value *args[] = {get_user_context(),
3953                          ConstantInt::get(i32_t, num_tasks),
3954                          task_stack_ptr,
3955                          task_parent};
3956         result = builder->CreateCall(do_parallel_tasks, args);
3957     }
3958 
3959     // Check for success
3960     Value *did_succeed = builder->CreateICmpEQ(result, ConstantInt::get(i32_t, 0));
3961     create_assertion(did_succeed, Expr(), result);
3962 }
3963 
3964 namespace {
3965 
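     // The (name, fork count) prefix pair builds debug names for tasks:
     // prefix.second counts Fork nodes seen since the last suffix was appended,
     // and is used to give sibling tasks distinct _N suffixes.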
3966 string task_debug_name(const std::pair<string, int> &prefix) {
3967     if (prefix.second <= 1) {
3968         return prefix.first;
3969     } else {
3970         return prefix.first + "_" + std::to_string(prefix.second - 1);
3971     }
3972 }
3973 
3974 void add_fork(std::pair<string, int> &prefix) {
3975     if (prefix.second == 0) {
3976         prefix.first += ".fork";
3977     }
3978     prefix.second++;
3979 }
3980 
3981 void add_suffix(std::pair<string, int> &prefix, const string &suffix) {
3982     if (prefix.second > 1) {
3983         prefix.first += "_" + std::to_string(prefix.second - 1);
3984         prefix.second = 0;
3985     }
3986     prefix.first += suffix;
3987 }
3988 
3989 }  // namespace
3990 
3991 void CodeGen_LLVM::get_parallel_tasks(const Stmt &s, vector<ParallelTask> &result, std::pair<string, int> prefix) {
3992     const For *loop = s.as<For>();
3993     const Acquire *acquire = loop ? loop->body.as<Acquire>() : s.as<Acquire>();
3994     if (const Fork *f = s.as<Fork>()) {
3995         add_fork(prefix);
3996         get_parallel_tasks(f->first, result, prefix);
3997         get_parallel_tasks(f->rest, result, prefix);
3998     } else if (!loop && acquire) {
3999         const Variable *v = acquire->semaphore.as<Variable>();
4000         internal_assert(v);
4001         add_suffix(prefix, "." + v->name);
4002         ParallelTask t{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)};
4003         while (acquire) {
4004             t.semaphores.push_back({acquire->semaphore, acquire->count});
4005             t.body = acquire->body;
4006             acquire = t.body.as<Acquire>();
4007         }
4008         result.push_back(t);
4009     } else if (loop && loop->for_type == ForType::Parallel) {
4010         add_suffix(prefix, ".par_for." + loop->name);
4011         result.push_back(ParallelTask{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix)});
4012     } else if (loop &&
4013                loop->for_type == ForType::Serial &&
4014                acquire &&
4015                !expr_uses_var(acquire->count, loop->name)) {
4016         const Variable *v = acquire->semaphore.as<Variable>();
4017         internal_assert(v);
4018         add_suffix(prefix, ".for." + v->name);
4019         ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix)};
4020         while (acquire) {
4021             t.semaphores.push_back({acquire->semaphore, acquire->count});
4022             t.body = acquire->body;
4023             acquire = t.body.as<Acquire>();
4024         }
4025         result.push_back(t);
4026     } else {
4027         add_suffix(prefix, "." + std::to_string(result.size()));
4028         result.push_back(ParallelTask{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)});
4029     }
4030 }
4031 
4032 void CodeGen_LLVM::do_as_parallel_task(const Stmt &s) {
4033     vector<ParallelTask> tasks;
4034     get_parallel_tasks(s, tasks, {function->getName().str(), 0});
4035     do_parallel_tasks(tasks);
4036 }
4037 
4038 void CodeGen_LLVM::visit(const Acquire *op) {
4039     do_as_parallel_task(op);
4040 }
4041 
4042 void CodeGen_LLVM::visit(const Fork *op) {
4043     do_as_parallel_task(op);
4044 }
4045 
4046 void CodeGen_LLVM::visit(const Store *op) {
4047     Halide::Type value_type = op->value.type();
4048     Halide::Type storage_type = upgrade_type_for_storage(value_type);
4049     if (value_type != storage_type) {
4050         Expr v = reinterpret(storage_type, op->value);
4051         codegen(Store::make(op->name, v, op->index, op->param, op->predicate, op->alignment));
4052         return;
4053     }
4054 
4055     if (inside_atomic_mutex_node) {
4056         user_assert(value_type.is_scalar())
4057             << "The vectorized atomic operation for the store " << op->name
4058             << " is lowered into a mutex lock, which does not support vectorization.\n";
4059     }
4060 
4061     // Issue atomic store if we are inside an atomic node.
4062     if (emit_atomic_stores) {
4063         codegen_atomic_store(op);
4064         return;
4065     }
4066 
4067     // Predicated store.
4068     if (!is_one(op->predicate)) {
4069         codegen_predicated_vector_store(op);
4070         return;
4071     }
4072 
4073     Value *val = codegen(op->value);
4074     bool is_external = (external_buffer.find(op->name) != external_buffer.end());
4075     // Scalar
4076     if (value_type.is_scalar()) {
4077         Value *ptr = codegen_buffer_pointer(op->name, value_type, op->index);
4078         StoreInst *store = builder->CreateAlignedStore(val, ptr, make_alignment(value_type.bytes()));
4079         add_tbaa_metadata(store, op->name, op->index);
4080     } else if (const Let *let = op->index.as<Let>()) {
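             // The index is wrapped in a Let: hoist it into a LetStmt around an
             // equivalent store and re-enter codegen with the simplified index.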
4081         Stmt s = Store::make(op->name, op->value, let->body, op->param, op->predicate, op->alignment);
4082         codegen(LetStmt::make(let->name, let->value, s));
4083     } else {
4084         int alignment = value_type.bytes();
4085         const Ramp *ramp = op->index.as<Ramp>();
4086         if (ramp && is_one(ramp->stride)) {
4087 
4088             int native_bits = native_vector_bits();
4089             int native_bytes = native_bits / 8;
4090 
4091             // Boost the alignment if possible, up to the native vector width.
4092             ModulusRemainder mod_rem = op->alignment;
4093             while ((mod_rem.remainder & 1) == 0 &&
4094                    (mod_rem.modulus & 1) == 0 &&
4095                    alignment < native_bytes) {
4096                 mod_rem.modulus /= 2;
4097                 mod_rem.remainder /= 2;
4098                 alignment *= 2;
4099             }
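                 // e.g. a dense float32 store whose ramp base is known to be a
                 // multiple of 8 elements (mod_rem = {8, 0}) has its alignment
                 // boosted from 4 bytes to 32 bytes when the native vectors are
                 // 256 bits wide.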
4100 
4101             // If it is an external buffer, then we cannot assume that the host pointer
4102             // is aligned to at least the native vector width. However, we may be able to do
4103             // better than just assuming that it is unaligned.
4104             if (is_external && op->param.defined()) {
4105                 int host_alignment = op->param.host_alignment();
4106                 alignment = gcd(alignment, host_alignment);
4107             }
4108 
4109             // For dense vector stores wider than the native vector
4110             // width, bust them up into native vectors.
4111             int store_lanes = value_type.lanes();
4112             int native_lanes = native_bits / value_type.bits();
4113 
4114             for (int i = 0; i < store_lanes; i += native_lanes) {
4115                 int slice_lanes = std::min(native_lanes, store_lanes - i);
4116                 Expr slice_base = simplify(ramp->base + i);
4117                 Expr slice_stride = make_one(slice_base.type());
4118                 Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
4119                 Value *slice_val = slice_vector(val, i, slice_lanes);
4120                 Value *elt_ptr = codegen_buffer_pointer(op->name, value_type.element_of(), slice_base);
4121                 Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_val->getType()->getPointerTo());
4122                 StoreInst *store = builder->CreateAlignedStore(slice_val, vec_ptr, make_alignment(alignment));
4123                 add_tbaa_metadata(store, op->name, slice_index);
4124             }
4125         } else if (ramp) {
4126             Type ptr_type = value_type.element_of();
4127             Value *ptr = codegen_buffer_pointer(op->name, ptr_type, ramp->base);
4128             const IntImm *const_stride = ramp->stride.as<IntImm>();
4129             Value *stride = codegen(ramp->stride);
4130             // Scatter without generating the indices as a vector
4131             for (int i = 0; i < ramp->lanes; i++) {
4132                 Constant *lane = ConstantInt::get(i32_t, i);
4133                 Value *v = builder->CreateExtractElement(val, lane);
4134                 if (const_stride) {
4135                     // Use a constant offset from the base pointer
4136                     Value *p =
4137                         builder->CreateConstInBoundsGEP1_32(
4138                             llvm_type_of(ptr_type),
4139                             ptr,
4140                             const_stride->value * i);
4141                     StoreInst *store = builder->CreateStore(v, p);
4142                     add_tbaa_metadata(store, op->name, op->index);
4143                 } else {
4144                     // Increment the pointer by the stride for each element
4145                     StoreInst *store = builder->CreateStore(v, ptr);
4146                     add_tbaa_metadata(store, op->name, op->index);
4147                     ptr = builder->CreateInBoundsGEP(ptr, stride);
4148                 }
4149             }
4150         } else {
4151             // Scatter
4152             Value *index = codegen(op->index);
4153             for (int i = 0; i < value_type.lanes(); i++) {
4154                 Value *lane = ConstantInt::get(i32_t, i);
4155                 Value *idx = builder->CreateExtractElement(index, lane);
4156                 Value *v = builder->CreateExtractElement(val, lane);
4157                 Value *ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
4158                 StoreInst *store = builder->CreateStore(v, ptr);
4159                 add_tbaa_metadata(store, op->name, op->index);
4160             }
4161         }
4162     }
4163 }
4164 
4165 void CodeGen_LLVM::codegen_asserts(const vector<const AssertStmt *> &asserts) {
4166     if (target.has_feature(Target::NoAsserts)) {
4167         return;
4168     }
4169 
4170     if (asserts.size() < 4) {
4171         for (const auto *a : asserts) {
4172             codegen(Stmt(a));
4173         }
4174         return;
4175     }
4176 
4177     internal_assert(asserts.size() <= 63);
4178 
4179     // Mix all the conditions together into a bitmask
4180 
4181     Expr bitmask = cast<uint64_t>(1) << 63;
4182     for (size_t i = 0; i < asserts.size(); i++) {
4183         bitmask = bitmask | (cast<uint64_t>(!asserts[i]->condition) << i);
4184     }
4185 
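         // Bit 63 is an always-set sentinel: if every condition holds,
         // count_trailing_zeros returns 63, which matches no switch case below and
         // falls through to the no-errors block; otherwise it returns the index of
         // the first failing assert.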
4186     Expr switch_case = count_trailing_zeros(bitmask);
4187 
4188     BasicBlock *no_errors_bb = BasicBlock::Create(*context, "no_errors_bb", function);
4189 
4190     // Now switch on the bitmask to the correct failure
4191     Expr case_idx = cast<int32_t>(count_trailing_zeros(bitmask));
4192     llvm::SmallVector<uint32_t, 64> weights;
4193     weights.push_back(1 << 30);
4194     for (int i = 0; i < (int)asserts.size(); i++) {
4195         weights.push_back(0);
4196     }
4197     llvm::MDBuilder md_builder(*context);
4198     llvm::MDNode *switch_very_likely_branch = md_builder.createBranchWeights(weights);
4199     auto *switch_inst = builder->CreateSwitch(codegen(case_idx), no_errors_bb, asserts.size(), switch_very_likely_branch);
4200     for (int i = 0; i < (int)asserts.size(); i++) {
4201         BasicBlock *fail_bb = BasicBlock::Create(*context, "assert_failed", function);
4202         switch_inst->addCase(ConstantInt::get(IntegerType::get(*context, 32), i), fail_bb);
4203         builder->SetInsertPoint(fail_bb);
4204         Value *v = codegen(asserts[i]->message);
4205         builder->CreateRet(v);
4206     }
4207     builder->SetInsertPoint(no_errors_bb);
4208 }
4209 
4210 void CodeGen_LLVM::visit(const Block *op) {
4211     // Peel blocks of assertions with pure conditions
4212     const AssertStmt *a = op->first.as<AssertStmt>();
4213     if (a && is_pure(a->condition)) {
4214         vector<const AssertStmt *> asserts;
4215         asserts.push_back(a);
4216         Stmt s = op->rest;
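             // Stop at 63 asserts: codegen_asserts packs the conditions into a
             // 64-bit bitmask with one bit reserved as a sentinel.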
4217         while ((op = s.as<Block>()) && (a = op->first.as<AssertStmt>()) && is_pure(a->condition) && asserts.size() < 63) {
4218             asserts.push_back(a);
4219             s = op->rest;
4220         }
4221         codegen_asserts(asserts);
4222         codegen(s);
4223     } else {
4224         codegen(op->first);
4225         codegen(op->rest);
4226     }
4227 }
4228 
4229 void CodeGen_LLVM::visit(const Realize *op) {
4230     internal_error << "Realize encountered during codegen\n";
4231 }
4232 
4233 void CodeGen_LLVM::visit(const Provide *op) {
4234     internal_error << "Provide encountered during codegen\n";
4235 }
4236 
4237 void CodeGen_LLVM::visit(const IfThenElse *op) {
4238     BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
4239     BasicBlock *false_bb = BasicBlock::Create(*context, "false_bb", function);
4240     BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
4241     builder->CreateCondBr(codegen(op->condition), true_bb, false_bb);
4242 
4243     builder->SetInsertPoint(true_bb);
4244     codegen(op->then_case);
4245     builder->CreateBr(after_bb);
4246 
4247     builder->SetInsertPoint(false_bb);
4248     if (op->else_case.defined()) {
4249         codegen(op->else_case);
4250     }
4251     builder->CreateBr(after_bb);
4252 
4253     builder->SetInsertPoint(after_bb);
4254 }
4255 
4256 void CodeGen_LLVM::visit(const Evaluate *op) {
4257     codegen(op->value);
4258 
4259     // Discard result
4260     value = nullptr;
4261 }
4262 
4263 void CodeGen_LLVM::visit(const Shuffle *op) {
4264     if (op->is_interleave()) {
4265         vector<Value *> vecs;
4266         for (Expr i : op->vectors) {
4267             vecs.push_back(codegen(i));
4268         }
4269         value = interleave_vectors(vecs);
4270     } else {
4271         vector<Value *> vecs;
4272         for (Expr i : op->vectors) {
4273             vecs.push_back(codegen(i));
4274         }
4275         value = concat_vectors(vecs);
4276         if (op->is_concat()) {
4277             // If this is just a concat, we're done.
4278         } else if (op->is_slice() && op->slice_stride() == 1) {
4279             value = slice_vector(value, op->indices[0], op->indices.size());
4280         } else {
4281             value = shuffle_vectors(value, op->indices);
4282         }
4283     }
4284 
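         // A shuffle that selects a single lane has a scalar result type, but the
         // code above still produced a one-element vector; extract lane 0.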
4285     if (op->type.is_scalar() && value->getType()->isVectorTy()) {
4286         value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0));
4287     }
4288 }
4289 
4290 void CodeGen_LLVM::visit(const VectorReduce *op) {
4291     codegen_vector_reduce(op, Expr());
4292 }
4293 
4294 void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
4295     Expr val = op->value;
4296     const int output_lanes = op->type.lanes();
4297     const int native_lanes = native_vector_bits() / op->type.bits();
4298     const int factor = val.type().lanes() / output_lanes;
4299 
4300     Expr (*binop)(Expr, Expr) = nullptr;
4301     switch (op->op) {
4302     case VectorReduce::Add:
4303         binop = Add::make;
4304         break;
4305     case VectorReduce::Mul:
4306         binop = Mul::make;
4307         break;
4308     case VectorReduce::Min:
4309         binop = Min::make;
4310         break;
4311     case VectorReduce::Max:
4312         binop = Max::make;
4313         break;
4314     case VectorReduce::And:
4315         binop = And::make;
4316         break;
4317     case VectorReduce::Or:
4318         binop = Or::make;
4319         break;
4320     }
4321 
4322     if (op->type.is_bool() && op->op == VectorReduce::Or) {
4323         // Cast to u8, use max, cast back to bool.
4324         Expr equiv = cast(op->value.type().with_bits(8), op->value);
4325         equiv = VectorReduce::make(VectorReduce::Max, equiv, op->type.lanes());
4326         if (init.defined()) {
4327             equiv = max(equiv, init);
4328         }
4329         equiv = cast(op->type, equiv);
4330         equiv.accept(this);
4331         return;
4332     }
4333 
4334     if (op->type.is_bool() && op->op == VectorReduce::And) {
4335         // Cast to u8, use min, cast back to bool.
4336         Expr equiv = cast(op->value.type().with_bits(8), op->value);
4337         equiv = VectorReduce::make(VectorReduce::Min, equiv, op->type.lanes());
4338         equiv = cast(op->type, equiv);
4339         if (init.defined()) {
4340             equiv = min(equiv, init);
4341         }
4342         equiv.accept(this);
4343         return;
4344     }
4345 
4346     if (op->type.element_of() == Float(16)) {
4347         Expr equiv = cast(op->value.type().with_bits(32), op->value);
4348         equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
4349         if (init.defined()) {
4350             equiv = binop(equiv, init);
4351         }
4352         equiv = cast(op->type, equiv);
4353         equiv.accept(this);
4354         return;
4355     }
4356 
4357 #if LLVM_VERSION >= 90
4358     if (output_lanes == 1 &&
4359         (target.arch != Target::ARM || LLVM_VERSION >= 100)) {
4360         const int input_lanes = val.type().lanes();
4361         const int input_bytes = input_lanes * val.type().bytes();
4362         const bool llvm_has_intrinsic =
4363             // Must be one of these ops
4364             ((op->op == VectorReduce::Add ||
4365               op->op == VectorReduce::Mul ||
4366               op->op == VectorReduce::Min ||
4367               op->op == VectorReduce::Max) &&
4368              // Must be a power of two lanes
4369              (input_lanes >= 2) &&
4370              ((input_lanes & (input_lanes - 1)) == 0) &&
4371              // int versions exist up to 1024 bits
4372              ((!op->type.is_float() && input_bytes <= 1024) ||
4373               // float versions exist up to 16 lanes
4374               input_lanes <= 16) &&
4375              // As of the release of llvm 10, the 64-bit experimental total
4376              // reductions don't seem to be done yet on arm.
4377              (val.type().bits() != 64 ||
4378               target.arch != Target::ARM));
4379 
4380         if (llvm_has_intrinsic) {
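                 // The names built below look like
                 // llvm.experimental.vector.reduce.add.v8i32, or
                 // llvm.experimental.vector.reduce.v2.fadd.f32.v4f32 for the float
                 // variants that take an initial-value operand.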
4381             std::stringstream name;
4382             name << "llvm.experimental.vector.reduce.";
4383             const int bits = op->type.bits();
4384             bool takes_initial_value = false;
4385             Expr initial_value = init;
4386             if (op->type.is_float()) {
4387                 switch (op->op) {
4388                 case VectorReduce::Add:
4389                     name << "v2.fadd.f" << bits;
4390                     takes_initial_value = true;
4391                     if (!initial_value.defined()) {
4392                         initial_value = make_zero(op->type);
4393                     }
4394                     break;
4395                 case VectorReduce::Mul:
4396                     name << "v2.fmul.f" << bits;
4397                     takes_initial_value = true;
4398                     if (!initial_value.defined()) {
4399                         initial_value = make_one(op->type);
4400                     }
4401                     break;
4402                 case VectorReduce::Min:
4403                     name << "fmin";
4404                     break;
4405                 case VectorReduce::Max:
4406                     name << "fmax";
4407                     break;
4408                 default:
4409                     break;
4410                 }
4411             } else if (op->type.is_int() || op->type.is_uint()) {
4412                 switch (op->op) {
4413                 case VectorReduce::Add:
4414                     name << "add";
4415                     break;
4416                 case VectorReduce::Mul:
4417                     name << "mul";
4418                     break;
4419                 case VectorReduce::Min:
4420                     name << (op->type.is_int() ? 's' : 'u') << "min";
4421                     break;
4422                 case VectorReduce::Max:
4423                     name << (op->type.is_int() ? 's' : 'u') << "max";
4424                     break;
4425                 default:
4426                     break;
4427                 }
4428             }
4429             name << ".v" << val.type().lanes() << (op->type.is_float() ? 'f' : 'i') << bits;
4430 
4431             string intrin_name = name.str();
4432 
4433             vector<Expr> args;
4434             if (takes_initial_value) {
4435                 args.push_back(initial_value);
4436                 initial_value = Expr();
4437             }
4438             args.push_back(op->value);
4439 
4440             // Make sure the declaration exists, or the codegen for
4441             // call will assume that the args should scalarize.
4442             if (!module->getFunction(intrin_name)) {
4443                 vector<llvm::Type *> arg_types;
4444                 for (const Expr &e : args) {
4445                     arg_types.push_back(llvm_type_of(e.type()));
4446                 }
4447                 FunctionType *func_t = FunctionType::get(llvm_type_of(op->type), arg_types, false);
4448                 llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get());
4449             }
4450 
4451             Expr equiv = Call::make(op->type, intrin_name, args, Call::PureExtern);
4452             if (initial_value.defined()) {
4453                 equiv = binop(initial_value, equiv);
4454             }
4455             equiv.accept(this);
4456             return;
4457         }
4458     }
4459 #endif
4460 
4461     if (output_lanes == 1 &&
4462         factor > native_lanes &&
4463         factor % native_lanes == 0) {
4464         // It's a total reduction of multiple native
4465         // vectors. Start by adding the vectors together.
4466         Expr equiv;
4467         for (int i = 0; i < factor / native_lanes; i++) {
4468             Expr next = Shuffle::make_slice(val, i * native_lanes, 1, native_lanes);
4469             if (equiv.defined()) {
4470                 equiv = binop(equiv, next);
4471             } else {
4472                 equiv = next;
4473             }
4474         }
4475         equiv = VectorReduce::make(op->op, equiv, 1);
4476         if (init.defined()) {
4477             equiv = binop(equiv, init);
4478         }
4479         equiv = common_subexpression_elimination(equiv);
4480         equiv.accept(this);
4481         return;
4482     }
4483 
4484     if (factor > 2 && ((factor & 1) == 0)) {
4485         // Factor the reduce into multiple stages. If we're going to
4486         // be widening the type by 4x or more we should also factor the
4487         // widening into multiple stages.
4488         Type intermediate_type = op->value.type().with_lanes(op->value.type().lanes() / 2);
4489         Expr equiv = VectorReduce::make(op->op, op->value, intermediate_type.lanes());
4490         if (op->op == VectorReduce::Add &&
4491             (op->type.is_int() || op->type.is_uint()) &&
4492             op->type.bits() >= 32) {
4493             Type narrower_type = op->value.type().with_bits(op->type.bits() / 4);
4494             Expr narrower = lossless_cast(narrower_type, op->value);
4495             if (!narrower.defined() && narrower_type.is_int()) {
4496                 // Maybe we can narrow to an unsigned int instead.
4497                 narrower_type = narrower_type.with_code(Type::UInt);
4498                 narrower = lossless_cast(narrower_type, op->value);
4499             }
4500             if (narrower.defined()) {
4501                 // Widen it by 2x before the horizontal add
4502                 narrower = cast(narrower.type().with_bits(narrower.type().bits() * 2), narrower);
4503                 equiv = VectorReduce::make(op->op, narrower, intermediate_type.lanes());
4504                 // Then widen it by 2x again afterwards
4505                 equiv = cast(intermediate_type, equiv);
4506             }
4507         }
4508         equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
4509         if (init.defined()) {
4510             equiv = binop(equiv, init);
4511         }
4512         equiv = common_subexpression_elimination(equiv);
4513         codegen(equiv);
4514         return;
4515     }
4516 
4517     // Extract each slice and combine
4518     Expr equiv = init;
4519     for (int i = 0; i < factor; i++) {
4520         Expr next = Shuffle::make_slice(val, i, factor, val.type().lanes() / factor);
4521         if (equiv.defined()) {
4522             equiv = binop(equiv, next);
4523         } else {
4524             equiv = next;
4525         }
4526     }
4527     equiv = common_subexpression_elimination(equiv);
4528     codegen(equiv);
4529 }
4530 
4531 void CodeGen_LLVM::visit(const Atomic *op) {
4532     if (op->mutex_name != "") {
4533         internal_assert(!inside_atomic_mutex_node)
4534             << "Nested atomic mutex locks detected. This might cause a deadlock.\n";
4535         ScopedValue<bool> old_inside_atomic_mutex_node(inside_atomic_mutex_node, true);
4536         // Mutex locking & unlocking are handled by function calls generated by previous lowering passes.
4537         codegen(op->body);
4538     } else {
4539         // Issue atomic stores.
4540         ScopedValue<bool> old_emit_atomic_stores(emit_atomic_stores, true);
4541         codegen(op->body);
4542     }
4543 }
4544 
4545 Value *CodeGen_LLVM::create_alloca_at_entry(llvm::Type *t, int n, bool zero_initialize, const string &name) {
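         // Emit the alloca in the function's entry block so LLVM treats it as a
         // static stack allocation.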
4546     IRBuilderBase::InsertPoint here = builder->saveIP();
4547     BasicBlock *entry = &builder->GetInsertBlock()->getParent()->getEntryBlock();
4548     if (entry->empty()) {
4549         builder->SetInsertPoint(entry);
4550     } else {
4551         builder->SetInsertPoint(entry, entry->getFirstInsertionPt());
4552     }
4553     Value *size = ConstantInt::get(i32_t, n);
4554     AllocaInst *ptr = builder->CreateAlloca(t, size, name);
4555     int align = native_vector_bits() / 8;
4556     llvm::DataLayout d(module.get());
4557     int allocated_size = n * (int)d.getTypeAllocSize(t);
4558     if (t->isVectorTy() || n > 1) {
4559         ptr->setAlignment(make_alignment(align));
4560     }
4561     requested_alloca_total += allocated_size;
4562 
4563     if (zero_initialize) {
4564         if (n == 1) {
4565             builder->CreateStore(Constant::getNullValue(t), ptr);
4566         } else {
4567             builder->CreateMemSet(ptr, Constant::getNullValue(t), n, make_alignment(align));
4568         }
4569     }
4570     builder->restoreIP(here);
4571     return ptr;
4572 }
4573 
4574 Value *CodeGen_LLVM::get_user_context() const {
4575     Value *ctx = sym_get("__user_context", false);
4576     if (!ctx) {
4577         ctx = ConstantPointerNull::get(i8_t->getPointerTo());  // void*
4578     }
4579     return ctx;
4580 }
4581 
4582 Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes,
4583                                  const string &name, vector<Expr> args) {
4584     vector<Value *> arg_values(args.size());
4585     for (size_t i = 0; i < args.size(); i++) {
4586         arg_values[i] = codegen(args[i]);
4587     }
4588 
4589     llvm::Type *t = llvm_type_of(result_type);
4590 
4591     return call_intrin(t,
4592                        intrin_lanes,
4593                        name, arg_values);
4594 }
4595 
4596 Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes,
4597                                  const string &name, vector<Value *> arg_values) {
4598     int arg_lanes = 1;
4599     if (result_type->isVectorTy()) {
4600         arg_lanes = get_vector_num_elements(result_type);
4601     }
4602 
4603     if (intrin_lanes != arg_lanes) {
4604         // Cut up each arg into appropriately-sized pieces, call the
4605         // intrinsic on each, then splice together the results.
4606         vector<Value *> results;
4607         for (int start = 0; start < arg_lanes; start += intrin_lanes) {
4608             vector<Value *> args;
4609             for (size_t i = 0; i < arg_values.size(); i++) {
4610                 int arg_i_lanes = 1;
4611                 if (arg_values[i]->getType()->isVectorTy()) {
4612                     arg_i_lanes = get_vector_num_elements(arg_values[i]->getType());
4613                 }
4614                 if (arg_i_lanes >= arg_lanes) {
4615                     // Horizontally reducing intrinsics may have
4616                     // arguments that have more lanes than the
4617                     // result. Assume that they horizontally reduce
4618                     // neighboring elements...
4619                     int reduce = arg_i_lanes / arg_lanes;
4620                     args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce));
4621                 } else if (arg_i_lanes == 1) {
4622                     // It's a scalar arg to an intrinsic that returns
4623                     // a vector. Replicate it over the slices.
4624                     args.push_back(arg_values[i]);
4625                 } else {
4626                     internal_error << "Argument in call_intrin has " << arg_i_lanes
4627                                    << " lanes, but the result type has " << arg_lanes << " lanes\n";
4628                 }
4629             }
4630 
4631             llvm::Type *result_slice_type =
4632                 get_vector_type(result_type->getScalarType(), intrin_lanes);
4633 
4634             results.push_back(call_intrin(result_slice_type, intrin_lanes, name, args));
4635         }
4636         Value *result = concat_vectors(results);
4637         return slice_vector(result, 0, arg_lanes);
4638     }
4639 
4640     vector<llvm::Type *> arg_types(arg_values.size());
4641     for (size_t i = 0; i < arg_values.size(); i++) {
4642         arg_types[i] = arg_values[i]->getType();
4643     }
4644 
4645     llvm::Function *fn = module->getFunction(name);
4646 
4647     if (!fn) {
4648         llvm::Type *intrinsic_result_type = result_type->getScalarType();
4649         if (intrin_lanes > 1) {
4650             intrinsic_result_type = get_vector_type(result_type->getScalarType(), intrin_lanes);
4651         }
4652         FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false);
4653         fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());
4654         fn->setCallingConv(CallingConv::C);
4655     }
4656 
4657     CallInst *call = builder->CreateCall(fn, arg_values);
4658 
4659     call->setDoesNotAccessMemory();
4660     call->setDoesNotThrow();
4661 
4662     return call;
4663 }
4664 
4665 Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) {
4666     // Force the arg to be an actual vector
4667     if (!vec->getType()->isVectorTy()) {
4668         vec = create_broadcast(vec, 1);
4669     }
4670 
4671     int vec_lanes = get_vector_num_elements(vec->getType());
4672 
4673     if (start == 0 && size == vec_lanes) {
4674         return vec;
4675     }
4676 
4677     if (size == 1) {
4678         return builder->CreateExtractElement(vec, (uint64_t)start);
4679     }
4680 
4681     vector<int> indices(size);
4682     for (int i = 0; i < size; i++) {
4683         int idx = start + i;
4684         if (idx >= 0 && idx < vec_lanes) {
4685             indices[i] = idx;
4686         } else {
4687             indices[i] = -1;
4688         }
4689     }
4690     return shuffle_vectors(vec, indices);
4691 }
4692 
4693 Value *CodeGen_LLVM::concat_vectors(const vector<Value *> &v) {
4694     if (v.size() == 1) return v[0];
4695 
4696     internal_assert(!v.empty());
4697 
4698     vector<Value *> vecs = v;
4699 
4700     // Force them all to be actual vectors
4701     for (Value *&val : vecs) {
4702         if (!val->getType()->isVectorTy()) {
4703             val = create_broadcast(val, 1);
4704         }
4705     }
4706 
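         // Merge the vectors pairwise so the resulting shuffle tree has
         // logarithmic depth instead of one long chain.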
4707     while (vecs.size() > 1) {
4708         vector<Value *> new_vecs;
4709 
4710         for (size_t i = 0; i < vecs.size() - 1; i += 2) {
4711             Value *v1 = vecs[i];
4712             Value *v2 = vecs[i + 1];
4713 
4714             int w1 = get_vector_num_elements(v1->getType());
4715             int w2 = get_vector_num_elements(v2->getType());
4716 
4717             // Possibly pad one of the vectors to match widths.
4718             if (w1 < w2) {
4719                 v1 = slice_vector(v1, 0, w2);
4720             } else if (w2 < w1) {
4721                 v2 = slice_vector(v2, 0, w1);
4722             }
4723             int w_matched = std::max(w1, w2);
4724 
4725             internal_assert(v1->getType() == v2->getType());
4726 
4727             vector<int> indices(w1 + w2);
4728             for (int i = 0; i < w1; i++) {
4729                 indices[i] = i;
4730             }
4731             for (int i = 0; i < w2; i++) {
4732                 indices[w1 + i] = w_matched + i;
4733             }
4734 
4735             Value *merged = shuffle_vectors(v1, v2, indices);
4736 
4737             new_vecs.push_back(merged);
4738         }
4739 
4740         // If there were an odd number of them, we need to also push
4741         // the one that didn't get merged.
4742         if (vecs.size() & 1) {
4743             new_vecs.push_back(vecs.back());
4744         }
4745 
4746         vecs.swap(new_vecs);
4747     }
4748 
4749     return vecs[0];
4750 }
4751 
4752 Value *CodeGen_LLVM::shuffle_vectors(Value *a, Value *b,
4753                                      const std::vector<int> &indices) {
4754     internal_assert(a->getType() == b->getType());
4755     vector<Constant *> llvm_indices(indices.size());
4756     for (size_t i = 0; i < llvm_indices.size(); i++) {
4757         if (indices[i] >= 0) {
4758             internal_assert(indices[i] < get_vector_num_elements(a->getType()) * 2);
4759             llvm_indices[i] = ConstantInt::get(i32_t, indices[i]);
4760         } else {
4761             // Only let -1 be undef.
4762             internal_assert(indices[i] == -1);
4763             llvm_indices[i] = UndefValue::get(i32_t);
4764         }
4765     }
4766 
4767     return builder->CreateShuffleVector(a, b, ConstantVector::get(llvm_indices));
4768 }
4769 
4770 Value *CodeGen_LLVM::shuffle_vectors(Value *a, const std::vector<int> &indices) {
4771     Value *b = UndefValue::get(a->getType());
4772     return shuffle_vectors(a, b, indices);
4773 }
4774 
4775 std::pair<llvm::Function *, int> CodeGen_LLVM::find_vector_runtime_function(const std::string &name, int lanes) {
4776     // Check if a vector version of the function already
4777     // exists at some useful width. We use the naming
4778     // convention that a N-wide version of a function foo is
4779     // called fooxN. All of our intrinsics are power-of-two
4780     // sized, so starting at the first power of two >= the
4781     // vector width, we'll try all powers of two in decreasing
4782     // order.
4783     vector<int> sizes_to_try;
4784     int l = 1;
4785     while (l < lanes)
4786         l *= 2;
4787     for (int i = l; i > 1; i /= 2) {
4788         sizes_to_try.push_back(i);
4789     }
4790 
4791     // If none of those match, we'll also try doubling
4792     // the lanes up to the next power of two (this is to catch
4793     // cases where we're a 64-bit vector and have a 128-bit
4794     // vector implementation).
4795     sizes_to_try.push_back(l * 2);
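         // e.g. for lanes == 6 the search order is foox8, foox4, foox2, and
         // finally foox16.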
4796 
4797     for (size_t i = 0; i < sizes_to_try.size(); i++) {
4798         int l = sizes_to_try[i];
4799         llvm::Function *vec_fn = module->getFunction(name + "x" + std::to_string(l));
4800         if (vec_fn) {
4801             return {vec_fn, l};
4802         }
4803     }
4804 
4805     return {nullptr, 0};
4806 }
4807 
4808 bool CodeGen_LLVM::supports_atomic_add(const Type &t) const {
4809     return t.is_int_or_uint();
4810 }
4811 
4812 bool CodeGen_LLVM::use_pic() const {
4813     return true;
4814 }
4815 
4816 }  // namespace Internal
4817 }  // namespace Halide
4818