#include <iostream>
#include <limits>
#include <mutex>
#include <sstream>
#include <utility>

#include "CPlusPlusMangle.h"
#include "CSE.h"
#include "CodeGen_ARM.h"
#include "CodeGen_GPU_Host.h"
#include "CodeGen_Hexagon.h"
#include "CodeGen_Internal.h"
#include "CodeGen_LLVM.h"
#include "CodeGen_MIPS.h"
#include "CodeGen_PowerPC.h"
#include "CodeGen_RISCV.h"
#include "CodeGen_WebAssembly.h"
#include "CodeGen_X86.h"
#include "CompilerLogger.h"
#include "Debug.h"
#include "Deinterleave.h"
#include "EmulateFloat16Math.h"
#include "ExprUsesVar.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "IntegerDivisionTable.h"
#include "JITModule.h"
#include "LLVM_Headers.h"
#include "LLVM_Runtime_Linker.h"
#include "Lerp.h"
#include "MatlabWrapper.h"
#include "Pipeline.h"
#include "Simplify.h"
#include "Util.h"

#if !(__cplusplus > 199711L || _MSC_VER >= 1800)

// VS2013 isn't fully C++11 compatible, but it supports enough of what Halide
// needs for now to be an acceptable minimum for Windows.
#error "Halide requires C++11 or VS2013+; please upgrade your compiler."

#endif

namespace Halide {
std::unique_ptr<llvm::Module> codegen_llvm(const Module &module, llvm::LLVMContext &context) {
    std::unique_ptr<Internal::CodeGen_LLVM> cg(Internal::CodeGen_LLVM::new_for_target(module.target(), context));
    return cg->compile(module);
}

namespace Internal {

using namespace llvm;
using std::map;
using std::ostringstream;
using std::pair;
using std::string;
using std::vector;
// Define an empty inline init function for each target, so that targets not
// linked into this build of LLVM become no-ops. These are overridden below by
// macros for the targets that are supported.
#define LLVM_TARGET(target)                        \
    inline void Initialize##target##Target() {     \
    }
#include <llvm/Config/Targets.def>
#undef LLVM_TARGET

#define LLVM_ASM_PARSER(target)                    \
    inline void Initialize##target##AsmParser() {  \
    }
#include <llvm/Config/AsmParsers.def>
#undef LLVM_ASM_PARSER

#define LLVM_ASM_PRINTER(target)                   \
    inline void Initialize##target##AsmPrinter() { \
    }
#include <llvm/Config/AsmPrinters.def>
#undef LLVM_ASM_PRINTER

#define InitializeTarget(target)          \
    LLVMInitialize##target##Target();     \
    LLVMInitialize##target##TargetInfo(); \
    LLVMInitialize##target##TargetMC();   \
    llvm_##target##_enabled = true;

#define InitializeAsmParser(target) \
    LLVMInitialize##target##AsmParser();

#define InitializeAsmPrinter(target) \
    LLVMInitialize##target##AsmPrinter();

// Override the empty init functions above with the real initializers for the
// supported targets.
#ifdef WITH_ARM
#define InitializeARMTarget() InitializeTarget(ARM)
#define InitializeARMAsmParser() InitializeAsmParser(ARM)
#define InitializeARMAsmPrinter() InitializeAsmPrinter(ARM)
#endif

#ifdef WITH_NVPTX
#define InitializeNVPTXTarget() InitializeTarget(NVPTX)
#define InitializeNVPTXAsmParser() InitializeAsmParser(NVPTX)
#define InitializeNVPTXAsmPrinter() InitializeAsmPrinter(NVPTX)
#endif

#ifdef WITH_AMDGPU
#define InitializeAMDGPUTarget() InitializeTarget(AMDGPU)
#define InitializeAMDGPUAsmParser() InitializeAsmParser(AMDGPU)
#define InitializeAMDGPUAsmPrinter() InitializeAsmPrinter(AMDGPU)
#endif

#ifdef WITH_AARCH64
#define InitializeAArch64Target() InitializeTarget(AArch64)
#define InitializeAArch64AsmParser() InitializeAsmParser(AArch64)
#define InitializeAArch64AsmPrinter() InitializeAsmPrinter(AArch64)
#endif

#ifdef WITH_HEXAGON
#define InitializeHexagonTarget() InitializeTarget(Hexagon)
#define InitializeHexagonAsmParser() InitializeAsmParser(Hexagon)
#define InitializeHexagonAsmPrinter() InitializeAsmPrinter(Hexagon)
#endif

#ifdef WITH_MIPS
#define InitializeMipsTarget() InitializeTarget(Mips)
#define InitializeMipsAsmParser() InitializeAsmParser(Mips)
#define InitializeMipsAsmPrinter() InitializeAsmPrinter(Mips)
#endif

#ifdef WITH_POWERPC
#define InitializePowerPCTarget() InitializeTarget(PowerPC)
#define InitializePowerPCAsmParser() InitializeAsmParser(PowerPC)
#define InitializePowerPCAsmPrinter() InitializeAsmPrinter(PowerPC)
#endif

#ifdef WITH_RISCV
#define InitializeRISCVTarget() InitializeTarget(RISCV)
#define InitializeRISCVAsmParser() InitializeAsmParser(RISCV)
#define InitializeRISCVAsmPrinter() InitializeAsmPrinter(RISCV)
#endif

#ifdef WITH_X86
#define InitializeX86Target() InitializeTarget(X86)
#define InitializeX86AsmParser() InitializeAsmParser(X86)
#define InitializeX86AsmPrinter() InitializeAsmPrinter(X86)
#endif

#ifdef WITH_WEBASSEMBLY
#define InitializeWebAssemblyTarget() InitializeTarget(WebAssembly)
#define InitializeWebAssemblyAsmParser() InitializeAsmParser(WebAssembly)
#define InitializeWebAssemblyAsmPrinter() InitializeAsmPrinter(WebAssembly)
#endif

namespace {

// Get the LLVM linkage corresponding to a Halide linkage type.
llvm::GlobalValue::LinkageTypes llvm_linkage(LinkageType t) {
    // TODO(dsharlet): For some reason, marking internal functions as
    // private linkage on OSX is causing some of the static tests to
    // fail. Figure out why so we can remove this.
    return llvm::GlobalValue::ExternalLinkage;

    // switch (t) {
    // case LinkageType::ExternalPlusMetadata:
    // case LinkageType::External:
    //     return llvm::GlobalValue::ExternalLinkage;
    // default:
    //     return llvm::GlobalValue::PrivateLinkage;
    // }
}

// A local helper to make an llvm value representing an alignment. Can't be
// declared in a header without introducing a dependence on the LLVM headers.
#if LLVM_VERSION >= 100
llvm::Align make_alignment(int a) {
    return llvm::Align(a);
}
#else
int make_alignment(int a) {
    return a;
}
#endif

} // namespace

CodeGen_LLVM::CodeGen_LLVM(Target t)
    : function(nullptr), context(nullptr),
      builder(nullptr),
      value(nullptr),
      very_likely_branch(nullptr),
      default_fp_math_md(nullptr),
      strict_fp_math_md(nullptr),
      target(t),
      void_t(nullptr), i1_t(nullptr), i8_t(nullptr),
      i16_t(nullptr), i32_t(nullptr), i64_t(nullptr),
      f16_t(nullptr), f32_t(nullptr), f64_t(nullptr),
      halide_buffer_t_type(nullptr),
      metadata_t_type(nullptr),
      argument_t_type(nullptr),
      scalar_value_t_type(nullptr),
      device_interface_t_type(nullptr),
      pseudostack_slot_t_type(nullptr),

      // Vector types. These need an LLVMContext before they can be initialized.
      i8x8(nullptr),
      i8x16(nullptr),
      i8x32(nullptr),
      i16x4(nullptr),
      i16x8(nullptr),
      i16x16(nullptr),
      i32x2(nullptr),
      i32x4(nullptr),
      i32x8(nullptr),
      i64x2(nullptr),
      i64x4(nullptr),
      f32x2(nullptr),
      f32x4(nullptr),
      f32x8(nullptr),
      f64x2(nullptr),
      f64x4(nullptr),

      // Wildcards for pattern matching
      wild_i8x8(Variable::make(Int(8, 8), "*")),
      wild_i16x4(Variable::make(Int(16, 4), "*")),
      wild_i32x2(Variable::make(Int(32, 2), "*")),

      wild_u8x8(Variable::make(UInt(8, 8), "*")),
      wild_u16x4(Variable::make(UInt(16, 4), "*")),
      wild_u32x2(Variable::make(UInt(32, 2), "*")),

      wild_i8x16(Variable::make(Int(8, 16), "*")),
      wild_i16x8(Variable::make(Int(16, 8), "*")),
      wild_i32x4(Variable::make(Int(32, 4), "*")),
      wild_i64x2(Variable::make(Int(64, 2), "*")),

      wild_u8x16(Variable::make(UInt(8, 16), "*")),
      wild_u16x8(Variable::make(UInt(16, 8), "*")),
      wild_u32x4(Variable::make(UInt(32, 4), "*")),
      wild_u64x2(Variable::make(UInt(64, 2), "*")),

      wild_i8x32(Variable::make(Int(8, 32), "*")),
      wild_i16x16(Variable::make(Int(16, 16), "*")),
      wild_i32x8(Variable::make(Int(32, 8), "*")),
      wild_i64x4(Variable::make(Int(64, 4), "*")),

      wild_u8x32(Variable::make(UInt(8, 32), "*")),
      wild_u16x16(Variable::make(UInt(16, 16), "*")),
      wild_u32x8(Variable::make(UInt(32, 8), "*")),
      wild_u64x4(Variable::make(UInt(64, 4), "*")),

      wild_f32x2(Variable::make(Float(32, 2), "*")),

      wild_f32x4(Variable::make(Float(32, 4), "*")),
      wild_f64x2(Variable::make(Float(64, 2), "*")),

      wild_f32x8(Variable::make(Float(32, 8), "*")),
      wild_f64x4(Variable::make(Float(64, 4), "*")),

      wild_u1x_(Variable::make(UInt(1, 0), "*")),
      wild_i8x_(Variable::make(Int(8, 0), "*")),
      wild_u8x_(Variable::make(UInt(8, 0), "*")),
      wild_i16x_(Variable::make(Int(16, 0), "*")),
      wild_u16x_(Variable::make(UInt(16, 0), "*")),
      wild_i32x_(Variable::make(Int(32, 0), "*")),
      wild_u32x_(Variable::make(UInt(32, 0), "*")),
      wild_i64x_(Variable::make(Int(64, 0), "*")),
      wild_u64x_(Variable::make(UInt(64, 0), "*")),
      wild_f32x_(Variable::make(Float(32, 0), "*")),
      wild_f64x_(Variable::make(Float(64, 0), "*")),

      // Bounds of types
      min_i8(Int(8).min()),
      max_i8(Int(8).max()),
      max_u8(UInt(8).max()),

      min_i16(Int(16).min()),
      max_i16(Int(16).max()),
      max_u16(UInt(16).max()),

      min_i32(Int(32).min()),
      max_i32(Int(32).max()),
      max_u32(UInt(32).max()),

      min_i64(Int(64).min()),
      max_i64(Int(64).max()),
      max_u64(UInt(64).max()),

      min_f32(Float(32).min()),
      max_f32(Float(32).max()),

      min_f64(Float(64).min()),
      max_f64(Float(64).max()),

      inside_atomic_mutex_node(false),
      emit_atomic_stores(false),

      destructor_block(nullptr),
      strict_float(t.has_feature(Target::StrictFloat)) {
    initialize_llvm();
}

namespace {

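// Helper to construct a code generator of the given subclass and bind it to
// an LLVM context.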
template<typename T>
CodeGen_LLVM *make_codegen(const Target &target,
                           llvm::LLVMContext &context) {
    CodeGen_LLVM *ret = new T(target);
    ret->set_context(context);
    return ret;
}

} // namespace

void CodeGen_LLVM::set_context(llvm::LLVMContext &context) {
    this->context = &context;
}

CodeGen_LLVM *CodeGen_LLVM::new_for_target(const Target &target,
                                           llvm::LLVMContext &context) {
    // The awkward mapping from targets to code generators
    if (target.features_any_of({Target::CUDA,
                                Target::OpenCL,
                                Target::OpenGL,
                                Target::OpenGLCompute,
                                Target::Metal,
                                Target::D3D12Compute})) {
#ifdef WITH_X86
        if (target.arch == Target::X86) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_X86>>(target, context);
        }
#endif
#if defined(WITH_ARM) || defined(WITH_AARCH64)
        if (target.arch == Target::ARM) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_ARM>>(target, context);
        }
#endif
#ifdef WITH_MIPS
        if (target.arch == Target::MIPS) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_MIPS>>(target, context);
        }
#endif
#ifdef WITH_POWERPC
        if (target.arch == Target::POWERPC) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_PowerPC>>(target, context);
        }
#endif
#ifdef WITH_WEBASSEMBLY
        if (target.arch == Target::WebAssembly) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_WebAssembly>>(target, context);
        }
#endif
#ifdef WITH_RISCV
        if (target.arch == Target::RISCV) {
            return make_codegen<CodeGen_GPU_Host<CodeGen_RISCV>>(target, context);
        }
#endif
        user_error << "Invalid target architecture for GPU backend: "
                   << target.to_string() << "\n";
        return nullptr;

    } else if (target.arch == Target::X86) {
        return make_codegen<CodeGen_X86>(target, context);
    } else if (target.arch == Target::ARM) {
        return make_codegen<CodeGen_ARM>(target, context);
    } else if (target.arch == Target::MIPS) {
        return make_codegen<CodeGen_MIPS>(target, context);
    } else if (target.arch == Target::POWERPC) {
        return make_codegen<CodeGen_PowerPC>(target, context);
    } else if (target.arch == Target::Hexagon) {
        return make_codegen<CodeGen_Hexagon>(target, context);
    } else if (target.arch == Target::WebAssembly) {
        return make_codegen<CodeGen_WebAssembly>(target, context);
    } else if (target.arch == Target::RISCV) {
        return make_codegen<CodeGen_RISCV>(target, context);
    }

    user_error << "Unknown target architecture: "
               << target.to_string() << "\n";
    return nullptr;
}

void CodeGen_LLVM::initialize_llvm() {
    static std::once_flag init_llvm_once;
    std::call_once(init_llvm_once, []() {
        // You can hack in command-line args to llvm with the
        // environment variable HL_LLVM_ARGS, e.g. HL_LLVM_ARGS="-print-after-all"
        std::string args = get_env_variable("HL_LLVM_ARGS");
        if (!args.empty()) {
            vector<std::string> arg_vec = split_string(args, " ");
            vector<const char *> c_arg_vec;
            c_arg_vec.push_back("llc");
            for (const std::string &s : arg_vec) {
                c_arg_vec.push_back(s.c_str());
            }
            cl::ParseCommandLineOptions((int)(c_arg_vec.size()), &c_arg_vec[0], "Halide compiler\n");
        }

        InitializeNativeTarget();
        InitializeNativeTargetAsmPrinter();
        InitializeNativeTargetAsmParser();

#define LLVM_TARGET(target) \
    Initialize##target##Target();
#include <llvm/Config/Targets.def>
#undef LLVM_TARGET

#define LLVM_ASM_PARSER(target) \
    Initialize##target##AsmParser();
#include <llvm/Config/AsmParsers.def>
#undef LLVM_ASM_PARSER

#define LLVM_ASM_PRINTER(target) \
    Initialize##target##AsmPrinter();
#include <llvm/Config/AsmPrinters.def>
#undef LLVM_ASM_PRINTER
    });
}

void CodeGen_LLVM::init_context() {
    // Ensure our IRBuilder is using the current context.
    delete builder;
    builder = new IRBuilder<>(*context);

    // Branch weights for very likely branches
    llvm::MDBuilder md_builder(*context);
    very_likely_branch = md_builder.createBranchWeights(1 << 30, 0);
    default_fp_math_md = md_builder.createFPMath(0.0);
    strict_fp_math_md = md_builder.createFPMath(0.0);
    builder->setDefaultFPMathTag(default_fp_math_md);
    llvm::FastMathFlags fast_flags;
    fast_flags.setNoNaNs();
    fast_flags.setNoInfs();
    fast_flags.setNoSignedZeros();
    // Don't use approximate reciprocals for division. It's too inaccurate even for Halide.
    // fast_flags.setAllowReciprocal();
    // Theoretically, setAllowReassoc could be setUnsafeAlgebra for earlier versions, but that
    // turns on all the flags.
    fast_flags.setAllowReassoc();
    fast_flags.setAllowContract(true);
    fast_flags.setApproxFunc();
    builder->setFastMathFlags(fast_flags);

    // Define some types
    void_t = llvm::Type::getVoidTy(*context);
    i1_t = llvm::Type::getInt1Ty(*context);
    i8_t = llvm::Type::getInt8Ty(*context);
    i16_t = llvm::Type::getInt16Ty(*context);
    i32_t = llvm::Type::getInt32Ty(*context);
    i64_t = llvm::Type::getInt64Ty(*context);
    f16_t = llvm::Type::getHalfTy(*context);
    f32_t = llvm::Type::getFloatTy(*context);
    f64_t = llvm::Type::getDoubleTy(*context);

    i8x8 = get_vector_type(i8_t, 8);
    i8x16 = get_vector_type(i8_t, 16);
    i8x32 = get_vector_type(i8_t, 32);
    i16x4 = get_vector_type(i16_t, 4);
    i16x8 = get_vector_type(i16_t, 8);
    i16x16 = get_vector_type(i16_t, 16);
    i32x2 = get_vector_type(i32_t, 2);
    i32x4 = get_vector_type(i32_t, 4);
    i32x8 = get_vector_type(i32_t, 8);
    i64x2 = get_vector_type(i64_t, 2);
    i64x4 = get_vector_type(i64_t, 4);
    f32x2 = get_vector_type(f32_t, 2);
    f32x4 = get_vector_type(f32_t, 4);
    f32x8 = get_vector_type(f32_t, 8);
    f64x2 = get_vector_type(f64_t, 2);
    f64x4 = get_vector_type(f64_t, 4);
}

void CodeGen_LLVM::init_module() {
    init_context();

    // Start from the initial module for this target.
    module = get_initial_module_for_target(target, context);
}

void CodeGen_LLVM::add_external_code(const Module &halide_module) {
    for (const ExternalCode &code_blob : halide_module.external_code()) {
        if (code_blob.is_for_cpu_target(get_target())) {
            add_bitcode_to_module(context, *module, code_blob.contents(), code_blob.name());
        }
    }
}

CodeGen_LLVM::~CodeGen_LLVM() {
    delete builder;
}

bool CodeGen_LLVM::llvm_X86_enabled = false;
bool CodeGen_LLVM::llvm_ARM_enabled = false;
bool CodeGen_LLVM::llvm_Hexagon_enabled = false;
bool CodeGen_LLVM::llvm_AArch64_enabled = false;
bool CodeGen_LLVM::llvm_NVPTX_enabled = false;
bool CodeGen_LLVM::llvm_Mips_enabled = false;
bool CodeGen_LLVM::llvm_PowerPC_enabled = false;
bool CodeGen_LLVM::llvm_AMDGPU_enabled = false;
bool CodeGen_LLVM::llvm_WebAssembly_enabled = false;
bool CodeGen_LLVM::llvm_RISCV_enabled = false;

namespace {

struct MangledNames {
    string simple_name;
    string extern_name;
    string argv_name;
    string metadata_name;
};

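// For example (illustrative): for a function named "ns::f", simple_name is
// "f", argv_name is "f_argv", and metadata_name is "f_metadata". extern_name
// is also "f", unless C++ mangling is requested, in which case it becomes the
// mangled name of ns::f.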
MangledNames get_mangled_names(const std::string &name,
                               LinkageType linkage,
                               NameMangling mangling,
                               const std::vector<LoweredArgument> &args,
                               const Target &target) {
    std::vector<std::string> namespaces;
    MangledNames names;
    names.simple_name = extract_namespaces(name, namespaces);
    names.extern_name = names.simple_name;
    names.argv_name = names.simple_name + "_argv";
    names.metadata_name = names.simple_name + "_metadata";

    if (linkage != LinkageType::Internal &&
        ((mangling == NameMangling::Default &&
          target.has_feature(Target::CPlusPlusMangling)) ||
         mangling == NameMangling::CPlusPlus)) {
        std::vector<ExternFuncArgument> mangle_args;
        for (const auto &arg : args) {
            if (arg.kind == Argument::InputScalar) {
                mangle_args.emplace_back(make_zero(arg.type));
            } else if (arg.kind == Argument::InputBuffer ||
                       arg.kind == Argument::OutputBuffer) {
                mangle_args.emplace_back(Buffer<>());
            }
        }
        names.extern_name = cplusplus_function_mangled_name(names.simple_name, namespaces, type_of<int>(), mangle_args, target);
        halide_handle_cplusplus_type inner_type(halide_cplusplus_type_name(halide_cplusplus_type_name::Simple, "void"), {}, {},
                                                {halide_handle_cplusplus_type::Pointer, halide_handle_cplusplus_type::Pointer});
        Type void_star_star(Handle(1, &inner_type));
        names.argv_name = cplusplus_function_mangled_name(names.argv_name, namespaces, type_of<int>(), {ExternFuncArgument(make_zero(void_star_star))}, target);
        names.metadata_name = cplusplus_function_mangled_name(names.metadata_name, namespaces, type_of<const struct halide_filter_metadata_t *>(), {}, target);
    }
    return names;
}

MangledNames get_mangled_names(const LoweredFunc &f, const Target &target) {
    return get_mangled_names(f.name, f.linkage, f.name_mangling, f.args, target);
}

} // namespace

llvm::FunctionType *CodeGen_LLVM::signature_to_type(const ExternSignature &signature) {
    internal_assert(void_t != nullptr && halide_buffer_t_type != nullptr);
    llvm::Type *ret_type =
        signature.is_void_return() ? void_t : llvm_type_of(upgrade_type_for_argument_passing(signature.ret_type()));
    std::vector<llvm::Type *> llvm_arg_types;
    for (const Type &t : signature.arg_types()) {
        if (t == type_of<struct halide_buffer_t *>()) {
            llvm_arg_types.push_back(halide_buffer_t_type->getPointerTo());
        } else {
            llvm_arg_types.push_back(llvm_type_of(upgrade_type_for_argument_passing(t)));
        }
    }

    return llvm::FunctionType::get(ret_type, llvm_arg_types, false);
}

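// Build a module containing an argv-style "trampoline" wrapper for each of
// the given extern functions, so the JIT can call them all through a uniform
// void-pointer-array interface (see add_argv_wrapper below).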
/*static*/
std::unique_ptr<llvm::Module> CodeGen_LLVM::compile_trampolines(
    const Target &target,
    llvm::LLVMContext &context,
    const std::string &suffix,
    const std::vector<std::pair<std::string, ExternSignature>> &externs) {
    std::unique_ptr<CodeGen_LLVM> codegen(new_for_target(target, context));
    codegen->init_codegen("trampolines" + suffix);
    for (const std::pair<std::string, ExternSignature> &e : externs) {
        const std::string &callee_name = e.first;
        const std::string wrapper_name = callee_name + suffix;
        llvm::FunctionType *fn_type = codegen->signature_to_type(e.second);
        // callee might already be present for builtins, e.g. halide_print
        llvm::Function *callee = codegen->module->getFunction(callee_name);
        if (!callee) {
            callee = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage, callee_name, codegen->module.get());
        }
        codegen->add_argv_wrapper(callee, wrapper_name, /*result_in_argv*/ true);
    }
    return codegen->finish_codegen();
}

void CodeGen_LLVM::init_codegen(const std::string &name, bool any_strict_float) {
    init_module();

    internal_assert(module && context);

    debug(1) << "Target triple of initial module: " << module->getTargetTriple() << "\n";

    module->setModuleIdentifier(name);

    // Add some target specific info to the module as metadata.
    module->addModuleFlag(llvm::Module::Warning, "halide_use_soft_float_abi", use_soft_float_abi() ? 1 : 0);
    module->addModuleFlag(llvm::Module::Warning, "halide_mcpu", MDString::get(*context, mcpu()));
    module->addModuleFlag(llvm::Module::Warning, "halide_mattrs", MDString::get(*context, mattrs()));
    module->addModuleFlag(llvm::Module::Warning, "halide_use_pic", use_pic() ? 1 : 0);
    module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", any_strict_float);

    // Ensure some types we need are defined
    halide_buffer_t_type = module->getTypeByName("struct.halide_buffer_t");
    internal_assert(halide_buffer_t_type) << "Did not find halide_buffer_t in initial module";

    type_t_type = module->getTypeByName("struct.halide_type_t");
    internal_assert(type_t_type) << "Did not find halide_type_t in initial module";

    dimension_t_type = module->getTypeByName("struct.halide_dimension_t");
    internal_assert(dimension_t_type) << "Did not find halide_dimension_t in initial module";

    metadata_t_type = module->getTypeByName("struct.halide_filter_metadata_t");
    internal_assert(metadata_t_type) << "Did not find halide_filter_metadata_t in initial module";

    argument_t_type = module->getTypeByName("struct.halide_filter_argument_t");
    internal_assert(argument_t_type) << "Did not find halide_filter_argument_t in initial module";

    scalar_value_t_type = module->getTypeByName("struct.halide_scalar_value_t");
    internal_assert(scalar_value_t_type) << "Did not find halide_scalar_value_t in initial module";

    device_interface_t_type = module->getTypeByName("struct.halide_device_interface_t");
    internal_assert(device_interface_t_type) << "Did not find halide_device_interface_t in initial module";

    pseudostack_slot_t_type = module->getTypeByName("struct.halide_pseudostack_slot_t");
    internal_assert(pseudostack_slot_t_type) << "Did not find halide_pseudostack_slot_t in initial module";

    semaphore_t_type = module->getTypeByName("struct.halide_semaphore_t");
    internal_assert(semaphore_t_type) << "Did not find halide_semaphore_t in initial module";

    semaphore_acquire_t_type = module->getTypeByName("struct.halide_semaphore_acquire_t");
    internal_assert(semaphore_acquire_t_type) << "Did not find halide_semaphore_acquire_t in initial module";

    parallel_task_t_type = module->getTypeByName("struct.halide_parallel_task_t");
    internal_assert(parallel_task_t_type) << "Did not find halide_parallel_task_t in initial module";
}

std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
    init_codegen(input.name(), input.any_strict_float());

    internal_assert(module && context && builder)
        << "The CodeGen_LLVM subclass should have made an initial module before calling CodeGen_LLVM::compile\n";

    add_external_code(input);

    // Generate the code for this module.
    debug(1) << "Generating llvm bitcode...\n";
    for (const auto &b : input.buffers()) {
        compile_buffer(b);
    }
    for (const auto &f : input.functions()) {
        const auto names = get_mangled_names(f, get_target());

        compile_func(f, names.simple_name, names.extern_name);

        // If the Func is externally visible, also create the argv wrapper and metadata
        // (useful for calling from JIT and other machine interfaces).
        if (f.linkage == LinkageType::ExternalPlusMetadata) {
            llvm::Function *wrapper = add_argv_wrapper(function, names.argv_name);
            llvm::Function *metadata_getter = embed_metadata_getter(names.metadata_name,
                                                                    names.simple_name, f.args, input.get_metadata_name_map());

            if (target.has_feature(Target::Matlab)) {
                define_matlab_wrapper(module.get(), wrapper, metadata_getter);
            }
        }
    }

    debug(2) << module.get() << "\n";

    return finish_codegen();
}

std::unique_ptr<llvm::Module> CodeGen_LLVM::finish_codegen() {
    // Verify the module is ok
    internal_assert(!verifyModule(*module, &llvm::errs()));
    debug(2) << "Done generating llvm bitcode\n";

    // Optimize
    CodeGen_LLVM::optimize_module();

    if (target.has_feature(Target::EmbedBitcode)) {
        std::string halide_command = "halide target=" + target.to_string();
        embed_bitcode(module.get(), halide_command);
    }

    // Disown the module and return it.
    return std::move(module);
}

void CodeGen_LLVM::begin_func(LinkageType linkage, const std::string &name,
                              const std::string &extern_name, const std::vector<LoweredArgument> &args) {
    current_function_args = args;

    // Deduce the types of the arguments to our function
    vector<llvm::Type *> arg_types(args.size());
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            arg_types[i] = halide_buffer_t_type->getPointerTo();
        } else {
            arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(args[i].type));
        }
    }
    FunctionType *func_t = FunctionType::get(i32_t, arg_types, false);

    // Make our function. There may already be a declaration of it.
    function = module->getFunction(extern_name);
    if (!function) {
        function = llvm::Function::Create(func_t, llvm_linkage(linkage), extern_name, module.get());
    } else {
        user_assert(function->isDeclaration())
            << "Another function with the name " << extern_name
            << " already exists in the same module\n";
        if (func_t != function->getFunctionType()) {
            std::cerr << "Desired function type for " << extern_name << ":\n";
            func_t->print(dbgs(), true);
            std::cerr << "Declared function type of " << extern_name << ":\n";
            function->getFunctionType()->print(dbgs(), true);
            user_error << "Cannot create a function with a declaration of mismatched type.\n";
        }
    }
    set_function_attributes_for_target(function, target);

    // Mark the buffer args as no alias
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            function->addParamAttr(i, Attribute::NoAlias);
        }
    }

    debug(1) << "Generating llvm bitcode prolog for function " << name << "...\n";

    // Null out the destructor block.
    destructor_block = nullptr;

    // Make the initial basic block
    BasicBlock *block = BasicBlock::Create(*context, "entry", function);
    builder->SetInsertPoint(block);

    // Put the arguments in the symbol table
    {
        size_t i = 0;
        for (auto &arg : function->args()) {
            if (args[i].is_buffer()) {
                // Track this buffer name so that loads and stores from it
                // don't try to be too aligned.
                external_buffer.insert(args[i].name);
                sym_push(args[i].name + ".buffer", &arg);
            } else {
                Type passed_type = upgrade_type_for_argument_passing(args[i].type);
                if (args[i].type != passed_type) {
                    llvm::Value *a = builder->CreateBitCast(&arg, llvm_type_of(args[i].type));
                    sym_push(args[i].name, a);
                } else {
                    sym_push(args[i].name, &arg);
                }
            }

            i++;
        }
    }
}

void CodeGen_LLVM::end_func(const std::vector<LoweredArgument> &args) {
    return_with_error_code(ConstantInt::get(i32_t, 0));

    // Remove the arguments from the symbol table
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            sym_pop(args[i].name + ".buffer");
        } else {
            sym_pop(args[i].name);
        }
    }

    internal_assert(!verifyFunction(*function, &llvm::errs()));

    current_function_args.clear();
}

void CodeGen_LLVM::compile_func(const LoweredFunc &f, const std::string &simple_name,
                                const std::string &extern_name) {
    // Generate the function declaration and argument unpacking code.
    begin_func(f.linkage, simple_name, extern_name, f.args);

    // If building with MSAN, ensure that calls to halide_msan_annotate_buffer_is_initialized()
    // happen for every output buffer if the function succeeds.
    if (f.linkage != LinkageType::Internal &&
        target.has_feature(Target::MSAN)) {
        llvm::Function *annotate_buffer_fn =
            module->getFunction("halide_msan_annotate_buffer_is_initialized_as_destructor");
        internal_assert(annotate_buffer_fn)
            << "Could not find halide_msan_annotate_buffer_is_initialized_as_destructor in module\n";
        annotate_buffer_fn->addParamAttr(0, Attribute::NoAlias);
        for (const auto &arg : f.args) {
            if (arg.kind == Argument::OutputBuffer) {
                register_destructor(annotate_buffer_fn, sym_get(arg.name + ".buffer"), OnSuccess);
            }
        }
    }

    // Generate the function body.
    debug(1) << "Generating llvm bitcode for function " << f.name << "...\n";
    f.body.accept(this);

    // Clean up and return.
    end_func(f.args);
}

// Given a range of iterators of constant ints, get a corresponding vector of llvm::Constant.
template<typename It>
std::vector<llvm::Constant *> get_constants(llvm::Type *t, It begin, It end) {
    std::vector<llvm::Constant *> ret;
    for (It i = begin; i != end; i++) {
        ret.push_back(ConstantInt::get(t, *i));
    }
    return ret;
}

BasicBlock *CodeGen_LLVM::get_destructor_block() {
    if (!destructor_block) {
        // Create it if it doesn't exist.
        IRBuilderBase::InsertPoint here = builder->saveIP();
        destructor_block = BasicBlock::Create(*context, "destructor_block", function);
        builder->SetInsertPoint(destructor_block);
        // The first instruction in the destructor block is a phi node
        // that collects the error code.
        PHINode *error_code = builder->CreatePHI(i32_t, 0);

        // Calls to destructors will get inserted here.

        // The last instruction is the return op that returns the error code.
        builder->CreateRet(error_code);

        // Jump back to where we were.
        builder->restoreIP(here);
    }
    internal_assert(destructor_block->getParent() == function);
    return destructor_block;
}

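// Register an object to be cleaned up by calling destructor_fn when the
// function exits, depending on 'when': always, only on error, or only on
// success. Returns the stack slot that tracks the object, so the object can
// also be released early via trigger_destructor.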
Value *CodeGen_LLVM::register_destructor(llvm::Function *destructor_fn, Value *obj, DestructorType when) {

    // Create a null-initialized stack slot to track this object
    llvm::Type *void_ptr = i8_t->getPointerTo();
    llvm::Value *stack_slot = create_alloca_at_entry(void_ptr, 1, true);

    // Cast the object to llvm's representation of void *
    obj = builder->CreatePointerCast(obj, void_ptr);

    // Put it in the stack slot
    builder->CreateStore(obj, stack_slot);

    // A constant-null object means the destructor would never get called, so
    // registering one must be a mistake.
    {
        llvm::Constant *c = dyn_cast<llvm::Constant>(obj);
        if (c && c->isNullValue()) {
            internal_error << "Destructors must take a non-null object\n";
        }
    }

    // Switch to the destructor block, and add code that cleans up
    // this object if the contents of the stack slot is not nullptr.
    IRBuilderBase::InsertPoint here = builder->saveIP();
    BasicBlock *dtors = get_destructor_block();

    builder->SetInsertPoint(dtors->getFirstNonPHI());

    PHINode *error_code = dyn_cast<PHINode>(dtors->begin());
    internal_assert(error_code) << "The destructor block is supposed to start with a phi node\n";

    llvm::Value *should_call = nullptr;
    switch (when) {
    case Always:
        should_call = ConstantInt::get(i1_t, 1);
        break;
    case OnError:
        should_call = builder->CreateIsNotNull(error_code);
        break;
    case OnSuccess:
        should_call = builder->CreateIsNull(error_code);
        break;
    }
    llvm::Function *call_destructor = module->getFunction("call_destructor");
    internal_assert(call_destructor);
    internal_assert(destructor_fn);
    internal_assert(should_call);
    Value *args[] = {get_user_context(), destructor_fn, stack_slot, should_call};
    builder->CreateCall(call_destructor, args);

    // Switch back to the original location
    builder->restoreIP(here);

    // Return the stack slot so that it's possible to cleanup the object early.
    return stack_slot;
}

void CodeGen_LLVM::trigger_destructor(llvm::Function *destructor_fn, Value *stack_slot) {
    llvm::Function *call_destructor = module->getFunction("call_destructor");
    internal_assert(call_destructor);
    internal_assert(destructor_fn);
    stack_slot = builder->CreatePointerCast(stack_slot, i8_t->getPointerTo()->getPointerTo());
    Value *should_call = ConstantInt::get(i1_t, 1);
    Value *args[] = {get_user_context(), destructor_fn, stack_slot, should_call};
    builder->CreateCall(call_destructor, args);
}

void CodeGen_LLVM::compile_buffer(const Buffer<> &buf) {
    // Embed the buffer declaration as a global.
    internal_assert(buf.defined());

    user_assert(buf.data())
        << "Can't embed buffer " << buf.name() << " because it has a null host pointer.\n";
    user_assert(!buf.device_dirty())
        << "Can't embed Image \"" << buf.name() << "\""
        << " because it has a dirty device pointer\n";

    Constant *type_fields[] = {
        ConstantInt::get(i8_t, buf.type().code()),
        ConstantInt::get(i8_t, buf.type().bits()),
        ConstantInt::get(i16_t, buf.type().lanes())};

    Constant *shape = nullptr;
    if (buf.dimensions()) {
        size_t shape_size = buf.dimensions() * sizeof(halide_dimension_t);
        vector<char> shape_blob((char *)buf.raw_buffer()->dim, (char *)buf.raw_buffer()->dim + shape_size);
        shape = create_binary_blob(shape_blob, buf.name() + ".shape");
        shape = ConstantExpr::getPointerCast(shape, dimension_t_type->getPointerTo());
    } else {
        shape = ConstantPointerNull::get(dimension_t_type->getPointerTo());
    }

    // For now, we assume buffers that aren't scalar are constant,
    // while scalars can be mutated. This accommodates all our existing
    // use cases: all buffers are constant, except those used to store
    // stateful module information in offloading runtimes.
    bool constant = buf.dimensions() != 0;

    vector<char> data_blob((const char *)buf.data(), (const char *)buf.data() + buf.size_in_bytes());

    Constant *fields[] = {
        ConstantInt::get(i64_t, 0),                                         // device
        ConstantPointerNull::get(device_interface_t_type->getPointerTo()), // device_interface
        create_binary_blob(data_blob, buf.name() + ".data", constant),     // host
        ConstantInt::get(i64_t, halide_buffer_flag_host_dirty),            // flags
        ConstantStruct::get(type_t_type, type_fields),                     // type
        ConstantInt::get(i32_t, buf.dimensions()),                         // dimensions
        shape,                                                             // dim
        ConstantPointerNull::get(i8_t->getPointerTo()),                    // padding
    };
    Constant *buffer_struct = ConstantStruct::get(halide_buffer_t_type, fields);

    // Embed the halide_buffer_t and make it point to the data array.
    GlobalVariable *global = new GlobalVariable(*module, halide_buffer_t_type,
                                                false, GlobalValue::PrivateLinkage,
                                                0, buf.name() + ".buffer");
    global->setInitializer(buffer_struct);

    // Finally, dump it in the symbol table
    Constant *zero[] = {ConstantInt::get(i32_t, 0)};
    Constant *global_ptr = ConstantExpr::getInBoundsGetElementPtr(halide_buffer_t_type, global, zero);
    sym_push(buf.name() + ".buffer", global_ptr);
}

Constant *CodeGen_LLVM::embed_constant_scalar_value_t(const Expr &e) {
    if (!e.defined()) {
        return Constant::getNullValue(scalar_value_t_type->getPointerTo());
    }

    internal_assert(!e.type().is_handle()) << "Should never see Handle types here.";

    llvm::Value *val = codegen(e);
    llvm::Constant *constant = dyn_cast<llvm::Constant>(val);
    internal_assert(constant);

    // Verify that the size of the LLVM value is the size we expected.
    internal_assert((uint64_t)constant->getType()->getPrimitiveSizeInBits() == (uint64_t)e.type().bits());

    // It's important that we allocate a full scalar_value_t_type here,
    // even if the type of the value is smaller; downstream consumers should
    // be able to correctly load an entire scalar_value_t_type regardless of its
    // type, and if we emit just (say) a uint8 value here, the pointer may be
    // misaligned and/or the storage after may be unmapped. LLVM doesn't support
    // unions directly, so we'll fake it by making a constant array of the elements
    // we need, setting the first to the constant we want, and setting the rest
    // to all-zeros. (This happens to work because sizeof(halide_scalar_value_t) is evenly
    // divisible by sizeof(any union field).)

    const size_t value_size = e.type().bytes();
    internal_assert(value_size > 0 && value_size <= sizeof(halide_scalar_value_t));

    const size_t array_size = sizeof(halide_scalar_value_t) / value_size;
    internal_assert(array_size * value_size == sizeof(halide_scalar_value_t));

    vector<Constant *> array_entries(array_size, Constant::getNullValue(constant->getType()));
    array_entries[0] = constant;

    llvm::ArrayType *array_type = ArrayType::get(constant->getType(), array_size);
    GlobalVariable *storage = new GlobalVariable(
        *module,
        array_type,
        /*isConstant*/ true,
        GlobalValue::PrivateLinkage,
        ConstantArray::get(array_type, array_entries));

    // Ensure that the storage is aligned for halide_scalar_value_t
    storage->setAlignment(make_alignment((int)sizeof(halide_scalar_value_t)));

    Constant *zero[] = {ConstantInt::get(i32_t, 0)};
    return ConstantExpr::getBitCast(
        ConstantExpr::getInBoundsGetElementPtr(array_type, storage, zero),
        scalar_value_t_type->getPointerTo());
}

Constant *CodeGen_LLVM::embed_constant_expr(Expr e, llvm::Type *t) {
    internal_assert(t != scalar_value_t_type);

    if (!e.defined()) {
        return Constant::getNullValue(t->getPointerTo());
    }

    internal_assert(!e.type().is_handle()) << "Should never see Handle types here.";
    if (!is_const(e)) {
        e = simplify(e);
        internal_assert(is_const(e)) << "Should only see constant values for estimates.";
    }

    llvm::Value *val = codegen(e);
    llvm::Constant *constant = dyn_cast<llvm::Constant>(val);
    internal_assert(constant);

    GlobalVariable *storage = new GlobalVariable(
        *module,
        constant->getType(),
        /*isConstant*/ true,
        GlobalValue::PrivateLinkage,
        constant);

    Constant *zero[] = {ConstantInt::get(i32_t, 0)};
    return ConstantExpr::getBitCast(
        ConstantExpr::getInBoundsGetElementPtr(constant->getType(), storage, zero),
        t->getPointerTo());
}

// Make a wrapper to call the function with an array of pointer
// args. This is easier for the JIT to call than a function with an
// unknown (at compile time) argument list. If result_in_argv is false,
// the internal function result is returned as the wrapper function
// result; if result_in_argv is true, the internal function result
// is stored as the last item in the argv list (which must be one
// longer than the number of arguments), and the wrapper's actual
// return type is always 'void'.
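//
// For example (an illustrative sketch, with result_in_argv = true): a wrapper
// generated for int f(halide_buffer_t *in, int x) would be called roughly as
//
//     int result;
//     const void *argv[] = {in_buf, &x, &result};
//     f_argv(argv);
//
// where buffer arguments are passed as the halide_buffer_t pointer itself,
// and scalar arguments (and the result) as pointers to their storage.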
llvm::Function *CodeGen_LLVM::add_argv_wrapper(llvm::Function *fn,
                                               const std::string &name,
                                               bool result_in_argv) {
    llvm::Type *wrapper_result_type = result_in_argv ? void_t : i32_t;
    llvm::Type *wrapper_args_t[] = {i8_t->getPointerTo()->getPointerTo()};
    llvm::FunctionType *wrapper_func_t = llvm::FunctionType::get(wrapper_result_type, wrapper_args_t, false);
    llvm::Function *wrapper_func = llvm::Function::Create(wrapper_func_t, llvm::GlobalValue::ExternalLinkage, name, module.get());
    llvm::BasicBlock *wrapper_block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper_func);
    builder->SetInsertPoint(wrapper_block);

    llvm::Value *arg_array = iterator_to_pointer(wrapper_func->arg_begin());
    std::vector<llvm::Value *> wrapper_args;
    for (llvm::Function::arg_iterator i = fn->arg_begin(); i != fn->arg_end(); i++) {
        // Get the address of the nth argument
        llvm::Value *ptr = builder->CreateConstGEP1_32(arg_array, wrapper_args.size());
        ptr = builder->CreateLoad(ptr);
        if (i->getType() == halide_buffer_t_type->getPointerTo()) {
            // Cast the argument to a halide_buffer_t *
            wrapper_args.push_back(builder->CreatePointerCast(ptr, halide_buffer_t_type->getPointerTo()));
        } else {
            // Cast to the appropriate type and load
            ptr = builder->CreatePointerCast(ptr, i->getType()->getPointerTo());
            wrapper_args.push_back(builder->CreateLoad(ptr));
        }
    }
    debug(4) << "Creating call from wrapper to actual function\n";
    llvm::CallInst *result = builder->CreateCall(fn, wrapper_args);
    // This call should never inline
    result->setIsNoInline();

    if (result_in_argv) {
        llvm::Value *result_in_argv_ptr = builder->CreateConstGEP1_32(arg_array, wrapper_args.size());
        if (fn->getReturnType() != void_t) {
            result_in_argv_ptr = builder->CreateLoad(result_in_argv_ptr);
            // Cast to the appropriate type and store
            result_in_argv_ptr = builder->CreatePointerCast(result_in_argv_ptr, fn->getReturnType()->getPointerTo());
            builder->CreateStore(result, result_in_argv_ptr);
        }
        builder->CreateRetVoid();
    } else {
        // We could probably support other types as return values,
        // but int32 results are all that have actually been tested.
        internal_assert(fn->getReturnType() == i32_t);
        builder->CreateRet(result);
    }
    internal_assert(!verifyFunction(*wrapper_func, &llvm::errs()));
    return wrapper_func;
}

llvm::Function *CodeGen_LLVM::embed_metadata_getter(const std::string &metadata_name,
                                                    const std::string &function_name, const std::vector<LoweredArgument> &args,
                                                    const std::map<std::string, std::string> &metadata_name_map) {
    Constant *zero = ConstantInt::get(i32_t, 0);

    const int num_args = (int)args.size();

    auto map_string = [&metadata_name_map](const std::string &from) -> std::string {
        auto it = metadata_name_map.find(from);
        return it == metadata_name_map.end() ? from : it->second;
    };

    vector<Constant *> arguments_array_entries;
    for (int arg = 0; arg < num_args; ++arg) {

        StructType *type_t_type = module->getTypeByName("struct.halide_type_t");
        internal_assert(type_t_type) << "Did not find halide_type_t in module.\n";

        Constant *type_fields[] = {
            ConstantInt::get(i8_t, args[arg].type.code()),
            ConstantInt::get(i8_t, args[arg].type.bits()),
            ConstantInt::get(i16_t, 1)};
        Constant *type = ConstantStruct::get(type_t_type, type_fields);

        auto argument_estimates = args[arg].argument_estimates;
        if (args[arg].type.is_handle()) {
            // Handle values are always emitted into metadata as "undefined", regardless of
            // what sort of Expr is provided.
            argument_estimates = ArgumentEstimates{};
        }

        Constant *buffer_estimates_array_ptr;
        if (args[arg].is_buffer() && !argument_estimates.buffer_estimates.empty()) {
            internal_assert((int)argument_estimates.buffer_estimates.size() == args[arg].dimensions);
            vector<Constant *> buffer_estimates_array_entries;
            for (const auto &be : argument_estimates.buffer_estimates) {
                Expr min = be.min;
                if (min.defined()) min = cast<int64_t>(min);
                Expr extent = be.extent;
                if (extent.defined()) extent = cast<int64_t>(extent);
                buffer_estimates_array_entries.push_back(embed_constant_expr(min, i64_t));
                buffer_estimates_array_entries.push_back(embed_constant_expr(extent, i64_t));
            }

            llvm::ArrayType *buffer_estimates_array = ArrayType::get(i64_t->getPointerTo(), buffer_estimates_array_entries.size());
            GlobalVariable *buffer_estimates_array_storage = new GlobalVariable(
                *module,
                buffer_estimates_array,
                /*isConstant*/ true,
                GlobalValue::PrivateLinkage,
                ConstantArray::get(buffer_estimates_array, buffer_estimates_array_entries));

            Value *zeros[] = {zero, zero};
            buffer_estimates_array_ptr = ConstantExpr::getInBoundsGetElementPtr(buffer_estimates_array, buffer_estimates_array_storage, zeros);
        } else {
            buffer_estimates_array_ptr = Constant::getNullValue(i64_t->getPointerTo()->getPointerTo());
        }

        Constant *argument_fields[] = {
            create_string_constant(map_string(args[arg].name)),
            ConstantInt::get(i32_t, args[arg].kind),
            ConstantInt::get(i32_t, args[arg].dimensions),
            type,
            embed_constant_scalar_value_t(argument_estimates.scalar_def),
            embed_constant_scalar_value_t(argument_estimates.scalar_min),
            embed_constant_scalar_value_t(argument_estimates.scalar_max),
            embed_constant_scalar_value_t(argument_estimates.scalar_estimate),
            buffer_estimates_array_ptr};
        arguments_array_entries.push_back(ConstantStruct::get(argument_t_type, argument_fields));
    }
    llvm::ArrayType *arguments_array = ArrayType::get(argument_t_type, num_args);
    GlobalVariable *arguments_array_storage = new GlobalVariable(
        *module,
        arguments_array,
        /*isConstant*/ true,
        GlobalValue::PrivateLinkage,
        ConstantArray::get(arguments_array, arguments_array_entries));

    Constant *version = ConstantInt::get(i32_t, halide_filter_metadata_t::VERSION);

    Value *zeros[] = {zero, zero};
    Constant *metadata_fields[] = {
        /* version */ version,
        /* num_arguments */ ConstantInt::get(i32_t, num_args),
        /* arguments */ ConstantExpr::getInBoundsGetElementPtr(arguments_array, arguments_array_storage, zeros),
        /* target */ create_string_constant(map_string(target.to_string())),
        /* name */ create_string_constant(map_string(function_name))};

    GlobalVariable *metadata_storage = new GlobalVariable(
        *module,
        metadata_t_type,
        /*isConstant*/ true,
        GlobalValue::PrivateLinkage,
        ConstantStruct::get(metadata_t_type, metadata_fields),
        metadata_name + "_storage");

    llvm::FunctionType *func_t = llvm::FunctionType::get(metadata_t_type->getPointerTo(), false);
    llvm::Function *metadata_getter = llvm::Function::Create(func_t, llvm::GlobalValue::ExternalLinkage, metadata_name, module.get());
    llvm::BasicBlock *block = llvm::BasicBlock::Create(module->getContext(), "entry", metadata_getter);
    builder->SetInsertPoint(block);
    builder->CreateRet(metadata_storage);
    internal_assert(!verifyFunction(*metadata_getter, &llvm::errs()));

    return metadata_getter;
}

llvm::Type *CodeGen_LLVM::llvm_type_of(const Type &t) const {
    return Internal::llvm_type_of(context, t);
}

void CodeGen_LLVM::optimize_module() {
    debug(3) << "Optimizing module\n";

    auto time_start = std::chrono::high_resolution_clock::now();

    if (debug::debug_level() >= 3) {
        module->print(dbgs(), nullptr, false, true);
    }

    std::unique_ptr<TargetMachine> tm = make_target_machine(*module);

    // At present, we default to *enabling* LLVM loop optimization,
    // unless DisableLLVMLoopOpt is set; we're going to flip this to defaulting
    // to *not* enabling these optimizations (and removing the DisableLLVMLoopOpt feature).
    // See https://github.com/halide/Halide/issues/4113 for more info.
    // (Note that setting EnableLLVMLoopOpt always enables loop opt, regardless
    // of the setting of DisableLLVMLoopOpt.)
    const bool do_loop_opt = !get_target().has_feature(Target::DisableLLVMLoopOpt) ||
                             get_target().has_feature(Target::EnableLLVMLoopOpt);

    PipelineTuningOptions pto;
    pto.LoopInterleaving = do_loop_opt;
    pto.LoopVectorization = do_loop_opt;
    pto.SLPVectorization = true;  // Note: SLP vectorization has no analogue in the Halide scheduling model
    pto.LoopUnrolling = do_loop_opt;
    // Clear ScEv info for all loops. Certain Halide applications spend a very
    // long time compiling in forgetLoop, and prefer to forget everything
    // and rebuild SCEV (aka "Scalar Evolution") from scratch.
    // Sample difference in compile time reduction at the time of this change was
    // 21.04 -> 14.78 using current ToT release build. (See also https://reviews.llvm.org/rL358304)
    pto.ForgetAllSCEVInLoopUnroll = true;

    llvm::PassBuilder pb(tm.get(), pto);

    bool debug_pass_manager = false;
    // These analysis managers have to be declared in this order.
    llvm::LoopAnalysisManager lam(debug_pass_manager);
    llvm::FunctionAnalysisManager fam(debug_pass_manager);
    llvm::CGSCCAnalysisManager cgam(debug_pass_manager);
    llvm::ModuleAnalysisManager mam(debug_pass_manager);

    llvm::AAManager aa = pb.buildDefaultAAPipeline();
    fam.registerPass([&] { return std::move(aa); });

    // Register all the basic analyses with the managers.
    pb.registerModuleAnalyses(mam);
    pb.registerCGSCCAnalyses(cgam);
    pb.registerFunctionAnalyses(fam);
    pb.registerLoopAnalyses(lam);
    pb.crossRegisterProxies(lam, fam, cgam, mam);
    ModulePassManager mpm(debug_pass_manager);

    PassBuilder::OptimizationLevel level = PassBuilder::OptimizationLevel::O3;

    if (get_target().has_feature(Target::ASAN)) {
        pb.registerPipelineStartEPCallback([&](ModulePassManager &mpm) {
            mpm.addPass(
                RequireAnalysisPass<ASanGlobalsMetadataAnalysis, llvm::Module>());
        });
#if LLVM_VERSION >= 110
        pb.registerOptimizerLastEPCallback(
            [](ModulePassManager &mpm, PassBuilder::OptimizationLevel level) {
                constexpr bool compile_kernel = false;
                constexpr bool recover = false;
                constexpr bool use_after_scope = true;
                mpm.addPass(createModuleToFunctionPassAdaptor(AddressSanitizerPass(
                    compile_kernel, recover, use_after_scope)));
            });
#else
        pb.registerOptimizerLastEPCallback(
            [](FunctionPassManager &fpm, PassBuilder::OptimizationLevel level) {
                constexpr bool compile_kernel = false;
                constexpr bool recover = false;
                constexpr bool use_after_scope = true;
                fpm.addPass(AddressSanitizerPass(
                    compile_kernel, recover, use_after_scope));
            });
#endif
        pb.registerPipelineStartEPCallback(
            [](ModulePassManager &mpm) {
                constexpr bool compile_kernel = false;
                constexpr bool recover = false;
                constexpr bool module_use_after_scope = false;
                constexpr bool use_odr_indicator = true;
                mpm.addPass(ModuleAddressSanitizerPass(
                    compile_kernel, recover, module_use_after_scope,
                    use_odr_indicator));
            });
    }

    if (get_target().has_feature(Target::TSAN)) {
#if LLVM_VERSION >= 110
        pb.registerOptimizerLastEPCallback(
            [](ModulePassManager &mpm, PassBuilder::OptimizationLevel level) {
                mpm.addPass(
                    createModuleToFunctionPassAdaptor(ThreadSanitizerPass()));
            });
#else
        pb.registerOptimizerLastEPCallback(
            [](FunctionPassManager &fpm, PassBuilder::OptimizationLevel level) {
                fpm.addPass(ThreadSanitizerPass());
            });
#endif
    }

    for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
        if (get_target().has_feature(Target::ASAN)) {
            i->addFnAttr(Attribute::SanitizeAddress);
        }
        if (get_target().has_feature(Target::TSAN)) {
            // Do not annotate any of Halide's low-level synchronization code as it has
            // tsan interface calls to mark its behavior and is much faster if
            // it is not analyzed instruction by instruction.
            if (!(i->getName().startswith("_ZN6Halide7Runtime8Internal15Synchronization") ||
                  // TODO: this is a benign data race that re-initializes the detected features;
                  // we should really fix it properly inside the implementation, rather than disabling
                  // it here as a band-aid.
                  i->getName().startswith("halide_default_can_use_target_features") ||
                  i->getName().startswith("halide_mutex_") ||
                  i->getName().startswith("halide_cond_"))) {
                i->addFnAttr(Attribute::SanitizeThread);
            }
        }
    }

    mpm = pb.buildPerModuleDefaultPipeline(level, debug_pass_manager);
    mpm.run(*module, mam);

    if (llvm::verifyModule(*module, &errs())) {
        report_fatal_error("Transformation resulted in an invalid module\n");
    }

    debug(3) << "After LLVM optimizations:\n";
    if (debug::debug_level() >= 2) {
        module->print(dbgs(), nullptr, false, true);
    }

    auto *logger = get_compiler_logger();
    if (logger) {
        auto time_end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> diff = time_end - time_start;
        logger->record_compilation_time(CompilerLogger::Phase::LLVM, diff.count());
    }
}

void CodeGen_LLVM::sym_push(const string &name, llvm::Value *value) {
    if (!value->getType()->isVoidTy()) {
        value->setName(name);
    }
    symbol_table.push(name, value);
}

void CodeGen_LLVM::sym_pop(const string &name) {
    symbol_table.pop(name);
}

llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const {
    // look in the symbol table
    if (!symbol_table.contains(name)) {
        if (must_succeed) {
            std::ostringstream err;
            err << "Symbol not found: " << name << "\n";

            if (debug::debug_level() > 0) {
                err << "The following names are in scope:\n"
                    << symbol_table << "\n";
            }

            internal_error << err.str();
        } else {
            return nullptr;
        }
    }
    return symbol_table.get(name);
}

bool CodeGen_LLVM::sym_exists(const string &name) const {
    return symbol_table.contains(name);
}

Value *CodeGen_LLVM::codegen(const Expr &e) {
    internal_assert(e.defined());
    debug(4) << "Codegen: " << e.type() << ", " << e << "\n";
    value = nullptr;
    e.accept(this);
    internal_assert(value) << "Codegen of an expr did not produce an llvm value\n";

    // Halide's type system doesn't distinguish between scalars and
    // vectors of size 1, so if a codegen method returned a vector of
    // size one, just extract it out as a scalar.
    if (e.type().is_scalar() &&
        value->getType()->isVectorTy()) {
        internal_assert(get_vector_num_elements(value->getType()) == 1);
        value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0));
    }

    // TODO: skip this correctness check for bool vectors,
    // as eliminate_bool_vectors() will cause a discrepancy for some backends
    // (eg OpenCL, HVX, WASM); for now we're just ignoring the assert, but
    // in the long run we should improve the smarts. See https://github.com/halide/Halide/issues/4194.
    const bool is_bool_vector = e.type().is_bool() && e.type().lanes() > 1;
    // TODO: skip this correctness check for prefetch, because the return type
    // of prefetch indicates the type being prefetched, which does not match the
    // implementation of prefetch.
    // See https://github.com/halide/Halide/issues/4211.
    const bool is_prefetch = e.as<Call>() && e.as<Call>()->is_intrinsic(Call::prefetch);
    internal_assert(is_bool_vector || is_prefetch ||
                    e.type().is_handle() ||
                    value->getType()->isVoidTy() ||
                    value->getType() == llvm_type_of(e.type()))
        << "Codegen of Expr " << e
        << " of type " << e.type()
        << " did not produce llvm IR of the corresponding llvm type.\n";
    return value;
}

void CodeGen_LLVM::codegen(const Stmt &s) {
    internal_assert(s.defined());
    debug(3) << "Codegen: " << s << "\n";
    value = nullptr;
    s.accept(this);
}

upgrade_type_for_arithmetic(const Type & t) const1438 Type CodeGen_LLVM::upgrade_type_for_arithmetic(const Type &t) const {
1439 if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1440 return Float(32, t.lanes());
1441 } else {
1442 return t;
1443 }
1444 }
1445
upgrade_type_for_argument_passing(const Type & t) const1446 Type CodeGen_LLVM::upgrade_type_for_argument_passing(const Type &t) const {
1447 if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1448 return t.with_code(halide_type_uint);
1449 } else {
1450 return t;
1451 }
1452 }
1453
upgrade_type_for_storage(const Type & t) const1454 Type CodeGen_LLVM::upgrade_type_for_storage(const Type &t) const {
1455 if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
1456 return t.with_code(halide_type_uint);
1457 } else if (t.is_bool()) {
1458 return t.with_bits(8);
1459 } else if (t.is_handle()) {
1460 return UInt(64, t.lanes());
1461 } else {
1462 return t;
1463 }
1464 }
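
// A sketch of what the three upgrades above imply, in terms of Halide's type
// constructors: arithmetic on small floats is emulated in Float(32), while
// storage keeps the bit pattern in a same-width uint. For example:
//   upgrade_type_for_arithmetic(BFloat(16, 8))  ->  Float(32, 8)
//   upgrade_type_for_storage(Bool(4))           ->  UInt(8, 4)
//   upgrade_type_for_storage(Handle())          ->  UInt(64)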

void CodeGen_LLVM::visit(const IntImm *op) {
    value = ConstantInt::getSigned(llvm_type_of(op->type), op->value);
}

void CodeGen_LLVM::visit(const UIntImm *op) {
    value = ConstantInt::get(llvm_type_of(op->type), op->value);
}

void CodeGen_LLVM::visit(const FloatImm *op) {
    if (op->type.is_bfloat()) {
        codegen(reinterpret(BFloat(16), make_const(UInt(16), bfloat16_t(op->value).to_bits())));
    } else if (op->type.bits() == 16) {
        codegen(reinterpret(Float(16), make_const(UInt(16), float16_t(op->value).to_bits())));
    } else {
        value = ConstantFP::get(llvm_type_of(op->type), op->value);
    }
}

void CodeGen_LLVM::visit(const StringImm *op) {
    value = create_string_constant(op->value);
}

void CodeGen_LLVM::visit(const Cast *op) {
    Halide::Type src = op->value.type();
    Halide::Type dst = op->type;

    if (upgrade_type_for_arithmetic(src) != src ||
        upgrade_type_for_arithmetic(dst) != dst) {
        // Handle casts to and from types for which we don't have native support.
        debug(4) << "Emulating cast from " << src << " to " << dst << "\n";
        if ((src.is_float() && src.bits() < 32) ||
            (dst.is_float() && dst.bits() < 32)) {
            Expr equiv = lower_float16_cast(op);
            internal_assert(equiv.type() == op->type);
            codegen(equiv);
        } else {
            internal_error << "Cast from type: " << src
                           << " to " << dst
                           << " unimplemented\n";
        }
        return;
    }

    value = codegen(op->value);
    llvm::Type *llvm_dst = llvm_type_of(dst);

    if (dst.is_handle() && src.is_handle()) {
        value = builder->CreateBitCast(value, llvm_dst);
    } else if (dst.is_handle() || src.is_handle()) {
        internal_error << "Can't cast from " << src << " to " << dst << "\n";
    } else if (!src.is_float() && !dst.is_float()) {
        // Widening integer casts either zero extend or sign extend,
        // depending on the source type. Narrowing integer casts
        // always truncate.
        value = builder->CreateIntCast(value, llvm_dst, src.is_int());
    } else if (src.is_float() && dst.is_int()) {
        value = builder->CreateFPToSI(value, llvm_dst);
    } else if (src.is_float() && dst.is_uint()) {
        // fptoui has undefined behavior on overflow. Seems reasonable
        // to get an unspecified uint on overflow, but because uint1s
        // are stored in uint8s for float->uint1 casts this undefined
        // behavior manifests itself as uint1 values greater than 1,
        // which could in turn break our bounds inference
        // guarantees. So go via uint8 in this case.
        if (dst.bits() < 8) {
            value = builder->CreateFPToUI(value, llvm_type_of(dst.with_bits(8)));
            value = builder->CreateIntCast(value, llvm_dst, false);
        } else {
            value = builder->CreateFPToUI(value, llvm_dst);
        }
    } else if (src.is_int() && dst.is_float()) {
        value = builder->CreateSIToFP(value, llvm_dst);
    } else if (src.is_uint() && dst.is_float()) {
        value = builder->CreateUIToFP(value, llvm_dst);
    } else {
        internal_assert(src.is_float() && dst.is_float());
        // Float widening or narrowing
        value = builder->CreateFPCast(value, llvm_dst);
    }
}

void CodeGen_LLVM::visit(const Variable *op) {
    value = sym_get(op->name);
}

template<typename Op>
bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) {
    const VectorReduce *red = op->a.template as<VectorReduce>();
    Expr b = op->b;
    if (!red) {
        red = op->b.template as<VectorReduce>();
        b = op->a;
    }
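    // For example (a sketch): given x + vector_reduce_add(v), the VectorReduce
    // may be either operand, so "red" ends up pointing at the reduce node and
    // "b" at the other operand; codegen_vector_reduce can then emit a fused
    // horizontal-reduction-with-initial-value on backends that support one.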
    if (red &&
        ((std::is_same<Op, Add>::value && red->op == VectorReduce::Add) ||
         (std::is_same<Op, Min>::value && red->op == VectorReduce::Min) ||
         (std::is_same<Op, Max>::value && red->op == VectorReduce::Max) ||
         (std::is_same<Op, Mul>::value && red->op == VectorReduce::Mul) ||
         (std::is_same<Op, And>::value && red->op == VectorReduce::And) ||
         (std::is_same<Op, Or>::value && red->op == VectorReduce::Or))) {
        codegen_vector_reduce(red, b);
        return true;
    }
    return false;
}

void CodeGen_LLVM::visit(const Add *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Add::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    // Some backends can fold the add into a vector reduce
    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (op->type.is_float()) {
        value = builder->CreateFAdd(a, b);
    } else if (op->type.is_int() && op->type.bits() >= 32) {
        // We tell llvm integers don't wrap, so that it generates good
        // code for loop indices.
        value = builder->CreateNSWAdd(a, b);
    } else {
        value = builder->CreateAdd(a, b);
    }
}

void CodeGen_LLVM::visit(const Sub *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Sub::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (op->type.is_float()) {
        value = builder->CreateFSub(a, b);
    } else if (op->type.is_int() && op->type.bits() >= 32) {
        // We tell llvm integers don't wrap, so that it generates good
        // code for loop indices.
        value = builder->CreateNSWSub(a, b);
    } else {
        value = builder->CreateSub(a, b);
    }
}

void CodeGen_LLVM::visit(const Mul *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Mul::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (op->type.is_float()) {
        value = builder->CreateFMul(a, b);
    } else if (op->type.is_int() && op->type.bits() >= 32) {
        // We tell llvm integers don't wrap, so that it generates good
        // code for loop indices.
        value = builder->CreateNSWMul(a, b);
    } else {
        value = builder->CreateMul(a, b);
    }
}

void CodeGen_LLVM::visit(const Div *op) {
    user_assert(!is_zero(op->b)) << "Division by constant zero in expression: " << Expr(op) << "\n";

    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Div::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    if (op->type.is_float()) {
        // Don't call codegen() multiple times within an argument list:
        // order-of-evaluation isn't guaranteed and can vary by compiler,
        // leading to different LLVM IR ordering, which makes comparing
        // output hard.
        Value *a = codegen(op->a);
        Value *b = codegen(op->b);
        value = builder->CreateFDiv(a, b);
    } else {
        value = codegen(lower_int_uint_div(op->a, op->b));
    }
}

void CodeGen_LLVM::visit(const Mod *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Mod::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    if (op->type.is_float()) {
        value = codegen(simplify(op->a - op->b * floor(op->a / op->b)));
    } else {
        value = codegen(lower_int_uint_mod(op->a, op->b));
    }
}

void CodeGen_LLVM::visit(const Min *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Min::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    string a_name = unique_name('a');
    string b_name = unique_name('b');
    Expr a = Variable::make(op->a.type(), a_name);
    Expr b = Variable::make(op->b.type(), b_name);
    value = codegen(Let::make(a_name, op->a,
                              Let::make(b_name, op->b,
                                        select(a < b, a, b))));
}

void CodeGen_LLVM::visit(const Max *op) {
    Type t = upgrade_type_for_arithmetic(op->type);
    if (t != op->type) {
        codegen(cast(op->type, Max::make(cast(t, op->a), cast(t, op->b))));
        return;
    }

    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    string a_name = unique_name('a');
    string b_name = unique_name('b');
    Expr a = Variable::make(op->a.type(), a_name);
    Expr b = Variable::make(op->b.type(), b_name);
    value = codegen(Let::make(a_name, op->a,
                              Let::make(b_name, op->b,
                                        select(a > b, a, b))));
}

void CodeGen_LLVM::visit(const EQ *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(EQ::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (t.is_float()) {
        value = builder->CreateFCmpOEQ(a, b);
    } else {
        value = builder->CreateICmpEQ(a, b);
    }
}

void CodeGen_LLVM::visit(const NE *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(NE::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (t.is_float()) {
        value = builder->CreateFCmpONE(a, b);
    } else {
        value = builder->CreateICmpNE(a, b);
    }
}

void CodeGen_LLVM::visit(const LT *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(LT::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (t.is_float()) {
        value = builder->CreateFCmpOLT(a, b);
    } else if (t.is_int()) {
        value = builder->CreateICmpSLT(a, b);
    } else {
        value = builder->CreateICmpULT(a, b);
    }
}

void CodeGen_LLVM::visit(const LE *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(LE::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (t.is_float()) {
        value = builder->CreateFCmpOLE(a, b);
    } else if (t.is_int()) {
        value = builder->CreateICmpSLE(a, b);
    } else {
        value = builder->CreateICmpULE(a, b);
    }
}

void CodeGen_LLVM::visit(const GT *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(GT::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);

    if (t.is_float()) {
        value = builder->CreateFCmpOGT(a, b);
    } else if (t.is_int()) {
        value = builder->CreateICmpSGT(a, b);
    } else {
        value = builder->CreateICmpUGT(a, b);
    }
}

void CodeGen_LLVM::visit(const GE *op) {
    Type t = upgrade_type_for_arithmetic(op->a.type());
    if (t != op->a.type()) {
        codegen(GE::make(cast(t, op->a), cast(t, op->b)));
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    if (t.is_float()) {
        value = builder->CreateFCmpOGE(a, b);
    } else if (t.is_int()) {
        value = builder->CreateICmpSGE(a, b);
    } else {
        value = builder->CreateICmpUGE(a, b);
    }
}

void CodeGen_LLVM::visit(const And *op) {
    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    value = builder->CreateAnd(a, b);
}

void CodeGen_LLVM::visit(const Or *op) {
    if (try_to_fold_vector_reduce(op)) {
        return;
    }

    Value *a = codegen(op->a);
    Value *b = codegen(op->b);
    value = builder->CreateOr(a, b);
}

void CodeGen_LLVM::visit(const Not *op) {
    Value *a = codegen(op->a);
    value = builder->CreateNot(a);
}

void CodeGen_LLVM::visit(const Select *op) {
    Value *cmp = codegen(op->condition);
    Value *a = codegen(op->true_value);
    Value *b = codegen(op->false_value);
    value = builder->CreateSelect(cmp, a, b);
}

namespace {
Expr promote_64(const Expr &e) {
    if (const Add *a = e.as<Add>()) {
        return Add::make(promote_64(a->a), promote_64(a->b));
    } else if (const Sub *s = e.as<Sub>()) {
        return Sub::make(promote_64(s->a), promote_64(s->b));
    } else if (const Mul *m = e.as<Mul>()) {
        return Mul::make(promote_64(m->a), promote_64(m->b));
    } else if (const Min *m = e.as<Min>()) {
        return Min::make(promote_64(m->a), promote_64(m->b));
    } else if (const Max *m = e.as<Max>()) {
        return Max::make(promote_64(m->a), promote_64(m->b));
    } else {
        return cast(Int(64), e);
    }
}
}  // namespace

Value *CodeGen_LLVM::codegen_buffer_pointer(const string &buffer, Halide::Type type, Expr index) {
    // Find the base address from the symbol table
    Value *base_address = symbol_table.get(buffer);
    return codegen_buffer_pointer(base_address, type, std::move(index));
}

Value *CodeGen_LLVM::codegen_buffer_pointer(Value *base_address, Halide::Type type, Expr index) {
    // Promote index to 64-bit on targets that use 64-bit pointers.
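    // (promote_64 rebuilds the arithmetic in Int(64) instead of just casting
    // the whole index, so the Add/Sub/Mul structure survives for the
    // constant-offset peeling below; e.g., a sketch: promote_64(x + 7) gives
    // cast(Int(64), x) + 7 rather than cast(Int(64), x + 7).)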
    llvm::DataLayout d(module.get());
    if (promote_indices() && d.getPointerSize() == 8) {
        index = promote_64(index);
    }

    // Peel off a constant offset as a second GEP. This helps LLVM's
    // aliasing analysis, especially for backends that do address
    // computation in 32 bits but use 64-bit pointers.
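    // For example (a sketch): a load at index x + 7 becomes
    // GEP(GEP(base, x), 7), so accesses at x + 7 and x + 8 visibly share a
    // base and differ only by a constant offset.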
    if (const Add *add = index.as<Add>()) {
        if (const int64_t *offset = as_const_int(add->b)) {
            Value *base = codegen_buffer_pointer(base_address, type, add->a);
            Value *off = codegen(make_const(Int(8 * d.getPointerSize()), *offset));
            return builder->CreateInBoundsGEP(base, off);
        }
    }

    return codegen_buffer_pointer(base_address, type, codegen(index));
}

Value *CodeGen_LLVM::codegen_buffer_pointer(const string &buffer, Halide::Type type, Value *index) {
    // Find the base address from the symbol table
    Value *base_address = symbol_table.get(buffer);
    return codegen_buffer_pointer(base_address, type, index);
}

Value *CodeGen_LLVM::codegen_buffer_pointer(Value *base_address, Halide::Type type, Value *index) {
    llvm::Type *base_address_type = base_address->getType();
    unsigned address_space = base_address_type->getPointerAddressSpace();

    type = upgrade_type_for_storage(type);

    llvm::Type *load_type = llvm_type_of(type)->getPointerTo(address_space);

    // If the type doesn't match the expected type, we need to pointer cast
    if (load_type != base_address_type) {
        base_address = builder->CreatePointerCast(base_address, load_type);
    }

    llvm::Constant *constant_index = dyn_cast<llvm::Constant>(index);
    if (constant_index && constant_index->isZeroValue()) {
        return base_address;
    }

    // Promote index to 64-bit on targets that use 64-bit pointers.
    llvm::DataLayout d(module.get());
    if (d.getPointerSize() == 8) {
        index = builder->CreateIntCast(index, i64_t, true);
    }

    return builder->CreateInBoundsGEP(base_address, index);
}

namespace {
int next_power_of_two(int x) {
    for (int p2 = 1;; p2 *= 2) {
        if (p2 >= x) {
            return p2;
        }
    }
    // unreachable.
}
}  // namespace

void CodeGen_LLVM::add_tbaa_metadata(llvm::Instruction *inst, string buffer, const Expr &index) {

    // Get the unique name for the block of memory this allocate node
    // is using.
    buffer = get_allocation_name(buffer);

    // If the index is constant, we generate some TBAA info that helps
    // LLVM understand our loads/stores aren't aliased.
    bool constant_index = false;
    int64_t base = 0;
    int64_t width = 1;

    if (index.defined()) {
        if (const Ramp *ramp = index.as<Ramp>()) {
            const int64_t *pstride = as_const_int(ramp->stride);
            const int64_t *pbase = as_const_int(ramp->base);
            if (pstride && pbase) {
                // We want to find the smallest aligned width and offset
                // that contains this ramp.
                int64_t stride = *pstride;
                base = *pbase;
                internal_assert(base >= 0);
                width = next_power_of_two(ramp->lanes * stride);

                while (base % width) {
                    base -= base % width;
                    width *= 2;
                }
                constant_index = true;
            }
        } else {
            const int64_t *pbase = as_const_int(index);
            if (pbase) {
                base = *pbase;
                constant_index = true;
            }
        }
    }

    llvm::MDBuilder builder(*context);

    // Add type-based-alias-analysis metadata to the pointer, so that
    // loads and stores to different buffers can get reordered.
    MDNode *tbaa = builder.createTBAARoot("Halide buffer");

    tbaa = builder.createTBAAScalarTypeNode(buffer, tbaa);

    // We also add metadata for constant indices to allow loads and
    // stores to the same buffer to get reordered.
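    // For example (a sketch): a 2-lane dense load at base 6 (so width = 2)
    // gets the node chain
    //   buf -> buf.width1024.base0 -> ... -> buf.width4.base4 -> buf.width2.base6
    // while a 2-lane load at base 0 gets one ending in buf.width2.base0; the
    // chains diverge, so LLVM can prove the two accesses don't overlap.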
    if (constant_index) {
        for (int w = 1024; w >= width; w /= 2) {
            int64_t b = (base / w) * w;

            std::stringstream level;
            level << buffer << ".width" << w << ".base" << b;
            tbaa = builder.createTBAAScalarTypeNode(level.str(), tbaa);
        }
    }

    tbaa = builder.createTBAAStructTagNode(tbaa, tbaa, 0);

    inst->setMetadata("tbaa", tbaa);
}

void CodeGen_LLVM::visit(const Load *op) {
    // If the type should be stored as some other type, insert a reinterpret cast.
    Type storage_type = upgrade_type_for_storage(op->type);
    if (op->type != storage_type) {
        codegen(reinterpret(op->type, Load::make(storage_type, op->name,
                                                 op->index, op->image,
                                                 op->param, op->predicate, op->alignment)));
        return;
    }

    // Predicated load
    if (!is_one(op->predicate)) {
        codegen_predicated_vector_load(op);
        return;
    }

    // There are several cases. Different architectures may wish to override some.
    if (op->type.is_scalar()) {
        // Scalar loads
        Value *ptr = codegen_buffer_pointer(op->name, op->type, op->index);
        LoadInst *load = builder->CreateAlignedLoad(ptr, make_alignment(op->type.bytes()));
        add_tbaa_metadata(load, op->name, op->index);
        value = load;
    } else {
        const Ramp *ramp = op->index.as<Ramp>();
        const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;

        if (ramp && stride && stride->value == 1) {
            value = codegen_dense_vector_load(op);
        } else if (ramp && stride && stride->value == 2) {
            // Load two vectors worth and then shuffle
            Expr base_a = ramp->base, base_b = ramp->base + ramp->lanes;
            Expr stride_a = make_one(base_a.type());
            Expr stride_b = make_one(base_b.type());

            ModulusRemainder align_a = op->alignment;
            ModulusRemainder align_b = align_a + ramp->lanes;

            // False indicates we should take the even-numbered lanes
            // from the load, true indicates we should take the
            // odd-numbered-lanes.
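            // For example (a sketch): ramp(x, 2, 4) wants {x, x+2, x+4, x+6};
            // we load the dense vectors [x .. x+3] and [x+4 .. x+7] and
            // shuffle out the even lanes of each, which is typically cheaper
            // than four scalar loads.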
            bool shifted_a = false, shifted_b = false;

            bool external = op->param.defined() || op->image.defined();

            // Don't read beyond the end of an external buffer.
            // (In ASAN mode, don't read beyond the end of internal buffers either,
            // as ASAN will complain even about harmless stack overreads.)
            if (external || target.has_feature(Target::ASAN)) {
                base_b -= 1;
                align_b = align_b - 1;
                shifted_b = true;
            } else {
                // If the base ends in an odd constant, then subtract one
                // and do a different shuffle. This helps expressions like
                // (f(2*x) + f(2*x+1)) share loads
                const Add *add = ramp->base.as<Add>();
                const IntImm *offset = add ? add->b.as<IntImm>() : nullptr;
                if (offset && offset->value & 1) {
                    base_a -= 1;
                    align_a = align_a - 1;
                    shifted_a = true;
                    base_b -= 1;
                    align_b = align_b - 1;
                    shifted_b = true;
                }
            }

            // Do each load.
            Expr ramp_a = Ramp::make(base_a, stride_a, ramp->lanes);
            Expr ramp_b = Ramp::make(base_b, stride_b, ramp->lanes);
            Expr load_a = Load::make(op->type, op->name, ramp_a, op->image, op->param, op->predicate, align_a);
            Expr load_b = Load::make(op->type, op->name, ramp_b, op->image, op->param, op->predicate, align_b);
            Value *vec_a = codegen(load_a);
            Value *vec_b = codegen(load_b);

            // Shuffle together the results.
            vector<int> indices(ramp->lanes);
            for (int i = 0; i < (ramp->lanes + 1) / 2; i++) {
                indices[i] = i * 2 + (shifted_a ? 1 : 0);
            }
            for (int i = (ramp->lanes + 1) / 2; i < ramp->lanes; i++) {
                indices[i] = i * 2 + (shifted_b ? 1 : 0);
            }

            value = shuffle_vectors(vec_a, vec_b, indices);
        } else if (ramp && stride && stride->value == -1) {
            // Load the vector and then flip it in-place
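            // (e.g., a sketch: ramp(x, -1, 4) wants {x, x-1, x-2, x-3}, so we
            // densely load [x-3 .. x] and reverse it with a shuffle.)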
            Expr flipped_base = ramp->base - ramp->lanes + 1;
            Expr flipped_stride = make_one(flipped_base.type());
            Expr flipped_index = Ramp::make(flipped_base, flipped_stride, ramp->lanes);
            ModulusRemainder align = op->alignment;
            // Switch to the alignment of the last lane
            align = align - (ramp->lanes - 1);
            Expr flipped_load = Load::make(op->type, op->name, flipped_index, op->image, op->param, op->predicate, align);

            Value *flipped = codegen(flipped_load);

            vector<int> indices(ramp->lanes);
            for (int i = 0; i < ramp->lanes; i++) {
                indices[i] = ramp->lanes - 1 - i;
            }

            value = shuffle_vectors(flipped, indices);
        } else if (ramp) {
            // Gather without generating the indices as a vector
            Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), ramp->base);
            Value *stride = codegen(ramp->stride);
            value = UndefValue::get(llvm_type_of(op->type));
            for (int i = 0; i < ramp->lanes; i++) {
                Value *lane = ConstantInt::get(i32_t, i);
                LoadInst *val = builder->CreateLoad(ptr);
                add_tbaa_metadata(val, op->name, op->index);
                value = builder->CreateInsertElement(value, val, lane);
                ptr = builder->CreateInBoundsGEP(ptr, stride);
            }
        } else if ((false)) { /* should_scalarize(op->index) */
            // TODO: put something sensible in for
            // should_scalarize. Probably a good idea if there are no
            // loads in it, and it's all int32.

            // Compute the index as scalars, and then do a gather
            Value *vec = UndefValue::get(llvm_type_of(op->type));
            for (int i = 0; i < op->type.lanes(); i++) {
                Expr idx = extract_lane(op->index, i);
                Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx);
                LoadInst *val = builder->CreateLoad(ptr);
                add_tbaa_metadata(val, op->name, op->index);
                vec = builder->CreateInsertElement(vec, val, ConstantInt::get(i32_t, i));
            }
            value = vec;
        } else {
            // General gathers
            Value *index = codegen(op->index);
            Value *vec = UndefValue::get(llvm_type_of(op->type));
            for (int i = 0; i < op->type.lanes(); i++) {
                Value *idx = builder->CreateExtractElement(index, ConstantInt::get(i32_t, i));
                Value *ptr = codegen_buffer_pointer(op->name, op->type.element_of(), idx);
                LoadInst *val = builder->CreateLoad(ptr);
                add_tbaa_metadata(val, op->name, op->index);
                vec = builder->CreateInsertElement(vec, val, ConstantInt::get(i32_t, i));
            }
            value = vec;
        }
    }
}

void CodeGen_LLVM::visit(const Ramp *op) {
    if (is_const(op->stride) && !is_const(op->base)) {
        // If the stride is const and the base is not (e.g. ramp(x, 1,
        // 4)), we can lift out the stride and broadcast the base so
        // we can do a single vector broadcast and add instead of
        // repeated insertion
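        // (i.e., a sketch: ramp(x, 1, 4) becomes broadcast(x, 4) + ramp(0, 1, 4),
        // where the constant ramp <0, 1, 2, 3> folds to a constant vector.)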
        Expr broadcast = Broadcast::make(op->base, op->lanes);
        Expr ramp = Ramp::make(make_zero(op->base.type()), op->stride, op->lanes);
        value = codegen(broadcast + ramp);
    } else {
        // Otherwise we generate element by element by adding the stride to the base repeatedly

        Value *base = codegen(op->base);
        Value *stride = codegen(op->stride);

        value = UndefValue::get(llvm_type_of(op->type));
        for (int i = 0; i < op->type.lanes(); i++) {
            if (i > 0) {
                if (op->type.is_float()) {
                    base = builder->CreateFAdd(base, stride);
                } else if (op->type.is_int() && op->type.bits() >= 32) {
                    base = builder->CreateNSWAdd(base, stride);
                } else {
                    base = builder->CreateAdd(base, stride);
                }
            }
            value = builder->CreateInsertElement(value, base, ConstantInt::get(i32_t, i));
        }
    }
}

llvm::Value *CodeGen_LLVM::create_broadcast(llvm::Value *v, int lanes) {
    Constant *undef = UndefValue::get(get_vector_type(v->getType(), lanes));
    Constant *zero = ConstantInt::get(i32_t, 0);
    v = builder->CreateInsertElement(undef, v, zero);
    Constant *zeros = ConstantVector::getSplat(element_count(lanes), zero);
    return builder->CreateShuffleVector(v, undef, zeros);
}

void CodeGen_LLVM::visit(const Broadcast *op) {
    Value *v = codegen(op->value);
    value = create_broadcast(v, op->lanes);
}

Value *CodeGen_LLVM::interleave_vectors(const std::vector<Value *> &vecs) {
    internal_assert(!vecs.empty());
    for (size_t i = 1; i < vecs.size(); i++) {
        internal_assert(vecs[0]->getType() == vecs[i]->getType());
    }
    int vec_elements = get_vector_num_elements(vecs[0]->getType());

    if (vecs.size() == 1) {
        return vecs[0];
    } else if (vecs.size() == 2) {
        Value *a = vecs[0];
        Value *b = vecs[1];
        vector<int> indices(vec_elements * 2);
        for (int i = 0; i < vec_elements * 2; i++) {
            indices[i] = i % 2 == 0 ? i / 2 : i / 2 + vec_elements;
        }
        return shuffle_vectors(a, b, indices);
    } else {
        // Grab the even and odd elements of vecs.
        vector<Value *> even_vecs;
        vector<Value *> odd_vecs;
        for (size_t i = 0; i < vecs.size(); i++) {
            if (i % 2 == 0) {
                even_vecs.push_back(vecs[i]);
            } else {
                odd_vecs.push_back(vecs[i]);
            }
        }

        // If the number of vecs is odd, save the last one for later.
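        // For example (a sketch): interleaving {a, b, c} splits into
        // even = {a, c} and odd = {b}; c is held back so both halves have
        // equal counts, a and b are interleaved recursively, and c's lanes
        // are then shuffled into every third slot:
        //   a0 b0 c0 a1 b1 c1 ...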
        Value *last = nullptr;
        if (even_vecs.size() > odd_vecs.size()) {
            last = even_vecs.back();
            even_vecs.pop_back();
        }
        internal_assert(even_vecs.size() == odd_vecs.size());

        // Interleave the even and odd parts.
        Value *even = interleave_vectors(even_vecs);
        Value *odd = interleave_vectors(odd_vecs);

        if (last) {
            int result_elements = vec_elements * vecs.size();

            // Interleave even and odd, leaving a space for the last element.
            vector<int> indices(result_elements, -1);
            for (int i = 0, idx = 0; i < result_elements; i++) {
                if (i % vecs.size() < vecs.size() - 1) {
                    indices[i] = idx % 2 == 0 ? idx / 2 : idx / 2 + vec_elements * even_vecs.size();
                    idx++;
                }
            }
            Value *even_odd = shuffle_vectors(even, odd, indices);

            // Interleave the last vector into the result.
            last = slice_vector(last, 0, result_elements);
            for (int i = 0; i < result_elements; i++) {
                if (i % vecs.size() < vecs.size() - 1) {
                    indices[i] = i;
                } else {
                    indices[i] = i / vecs.size() + result_elements;
                }
            }

            return shuffle_vectors(even_odd, last, indices);
        } else {
            return interleave_vectors({even, odd});
        }
    }
}

void CodeGen_LLVM::scalarize(const Expr &e) {
    llvm::Type *result_type = llvm_type_of(e.type());

    Value *result = UndefValue::get(result_type);

    for (int i = 0; i < e.type().lanes(); i++) {
        Value *v = codegen(extract_lane(e, i));
        result = builder->CreateInsertElement(result, v, ConstantInt::get(i32_t, i));
    }
    value = result;
}

void CodeGen_LLVM::codegen_predicated_vector_store(const Store *op) {
    const Ramp *ramp = op->index.as<Ramp>();
    if (ramp && is_one(ramp->stride)) {  // Dense vector store
        debug(4) << "Predicated dense vector store\n\t" << Stmt(op) << "\n";
        Value *vpred = codegen(op->predicate);
        Halide::Type value_type = op->value.type();
        Value *val = codegen(op->value);
        bool is_external = (external_buffer.find(op->name) != external_buffer.end());
        int alignment = value_type.bytes();
        int native_bits = native_vector_bits();
        int native_bytes = native_bits / 8;

        // Boost the alignment if possible, up to the native vector width.
        ModulusRemainder mod_rem = op->alignment;
        while ((mod_rem.remainder & 1) == 0 &&
               (mod_rem.modulus & 1) == 0 &&
               alignment < native_bytes) {
            mod_rem.modulus /= 2;
            mod_rem.remainder /= 2;
            alignment *= 2;
        }
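        // For example (a sketch): 4-byte elements whose index is known to
        // satisfy index % 4 == 0 start at alignment = 4; two trips through
        // the loop double it to 16 bytes (assuming a native vector width of
        // at least 128 bits).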

        // If it is an external buffer, then we cannot assume that the host pointer
        // is aligned to at least the native vector width. However, we may be able to do
        // better than just assuming that it is unaligned.
        if (is_external && op->param.defined()) {
            int host_alignment = op->param.host_alignment();
            alignment = gcd(alignment, host_alignment);
        }

        // For dense vector stores wider than the native vector
        // width, bust them up into native vectors.
        int store_lanes = value_type.lanes();
        int native_lanes = native_bits / value_type.bits();

        for (int i = 0; i < store_lanes; i += native_lanes) {
            int slice_lanes = std::min(native_lanes, store_lanes - i);
            Expr slice_base = simplify(ramp->base + i);
            Expr slice_stride = make_one(slice_base.type());
            Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
            Value *slice_val = slice_vector(val, i, slice_lanes);
            Value *elt_ptr = codegen_buffer_pointer(op->name, value_type.element_of(), slice_base);
            Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_val->getType()->getPointerTo());

            Value *slice_mask = slice_vector(vpred, i, slice_lanes);
#if LLVM_VERSION >= 110
            Instruction *store_inst =
                builder->CreateMaskedStore(slice_val, vec_ptr, make_alignment(alignment), slice_mask);
#else
            Instruction *store_inst =
                builder->CreateMaskedStore(slice_val, vec_ptr, alignment, slice_mask);
#endif
            add_tbaa_metadata(store_inst, op->name, slice_index);
        }
    } else {  // It's not a dense vector store; we need to scalarize it
        debug(4) << "Scalarize predicated vector store\n";
        Type value_type = op->value.type().element_of();
        Value *vpred = codegen(op->predicate);
        Value *vval = codegen(op->value);
        Value *vindex = codegen(op->index);
        for (int i = 0; i < op->index.type().lanes(); i++) {
            Constant *lane = ConstantInt::get(i32_t, i);
            Value *p = builder->CreateExtractElement(vpred, lane);
            if (p->getType() != i1_t) {
                p = builder->CreateIsNotNull(p);
            }

            Value *v = builder->CreateExtractElement(vval, lane);
            Value *idx = builder->CreateExtractElement(vindex, lane);
            internal_assert(p && v && idx);

            BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
            BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
            builder->CreateCondBr(p, true_bb, after_bb);

            builder->SetInsertPoint(true_bb);

            // Scalar
            Value *ptr = codegen_buffer_pointer(op->name, value_type, idx);
            builder->CreateAlignedStore(v, ptr, make_alignment(value_type.bytes()));

            builder->CreateBr(after_bb);
            builder->SetInsertPoint(after_bb);
        }
    }
}

Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) {
    debug(4) << "Vectorize predicated dense vector load:\n\t" << Expr(load) << "\n";

    const Ramp *ramp = load->index.as<Ramp>();
    internal_assert(ramp && is_one(ramp->stride)) << "Should be dense vector load\n";

    bool is_external = (external_buffer.find(load->name) != external_buffer.end());
    int alignment = load->type.bytes();  // The size of a single element

    int native_bits = native_vector_bits();
    int native_bytes = native_bits / 8;

    // We assume halide_malloc for the platform returns buffers
    // aligned to at least the native vector width. So this is the
    // maximum alignment we can infer based on the index alone.

    // Boost the alignment if possible, up to the native vector width.
    ModulusRemainder mod_rem = load->alignment;
    while ((mod_rem.remainder & 1) == 0 &&
           (mod_rem.modulus & 1) == 0 &&
           alignment < native_bytes) {
        mod_rem.modulus /= 2;
        mod_rem.remainder /= 2;
        alignment *= 2;
    }

    // If it is an external buffer, then we cannot assume that the host pointer
    // is aligned to at least the native vector width. However, we may be able to do
    // better than just assuming that it is unaligned.
    if (is_external) {
        if (load->param.defined()) {
            int host_alignment = load->param.host_alignment();
            alignment = gcd(alignment, host_alignment);
        } else if (get_target().has_feature(Target::JIT) && load->image.defined()) {
            // If we're JITting, use the actual pointer value to determine alignment for embedded buffers.
            alignment = gcd(alignment, (int)(((uintptr_t)load->image.data()) & std::numeric_limits<int>::max()));
        }
    }

    // For dense vector loads wider than the native vector
    // width, bust them up into native vectors
    int load_lanes = load->type.lanes();
    int native_lanes = std::max(1, native_bits / load->type.bits());
    vector<Value *> slices;
    for (int i = 0; i < load_lanes; i += native_lanes) {
        int slice_lanes = std::min(native_lanes, load_lanes - i);
        Expr slice_base = simplify(ramp->base + i);
        Expr slice_stride = make_one(slice_base.type());
        Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
        llvm::Type *slice_type = get_vector_type(llvm_type_of(load->type.element_of()), slice_lanes);
        Value *elt_ptr = codegen_buffer_pointer(load->name, load->type.element_of(), slice_base);
        Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_type->getPointerTo());

        Instruction *load_inst;
        if (vpred != nullptr) {
            Value *slice_mask = slice_vector(vpred, i, slice_lanes);
#if LLVM_VERSION >= 110
            load_inst = builder->CreateMaskedLoad(vec_ptr, make_alignment(alignment), slice_mask);
#else
            load_inst = builder->CreateMaskedLoad(vec_ptr, alignment, slice_mask);
#endif
        } else {
            load_inst = builder->CreateAlignedLoad(vec_ptr, make_alignment(alignment));
        }
        add_tbaa_metadata(load_inst, load->name, slice_index);
        slices.push_back(load_inst);
    }
    value = concat_vectors(slices);
    return value;
}

void CodeGen_LLVM::codegen_predicated_vector_load(const Load *op) {
    const Ramp *ramp = op->index.as<Ramp>();
    const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;

    if (ramp && is_one(ramp->stride)) {  // Dense vector load
        Value *vpred = codegen(op->predicate);
        value = codegen_dense_vector_load(op, vpred);
    } else if (ramp && stride && stride->value == -1) {
        debug(4) << "Predicated dense vector load with stride -1\n\t" << Expr(op) << "\n";
        vector<int> indices(ramp->lanes);
        for (int i = 0; i < ramp->lanes; i++) {
            indices[i] = ramp->lanes - 1 - i;
        }

        // Flip the predicate
        Value *vpred = codegen(op->predicate);
        vpred = shuffle_vectors(vpred, indices);

        // Load the vector and then flip it in-place
        Expr flipped_base = ramp->base - ramp->lanes + 1;
        Expr flipped_stride = make_one(flipped_base.type());
        Expr flipped_index = Ramp::make(flipped_base, flipped_stride, ramp->lanes);
        ModulusRemainder align = op->alignment;
        align = align - (ramp->lanes - 1);

        Expr flipped_load = Load::make(op->type, op->name, flipped_index, op->image,
                                       op->param, const_true(op->type.lanes()), align);

        Value *flipped = codegen_dense_vector_load(flipped_load.as<Load>(), vpred);
        value = shuffle_vectors(flipped, indices);
    } else {  // It's not a dense vector load; we need to scalarize it
        Expr load_expr = Load::make(op->type, op->name, op->index, op->image,
                                    op->param, const_true(op->type.lanes()), op->alignment);
        debug(4) << "Scalarize predicated vector load\n\t" << load_expr << "\n";
        Expr pred_load = Call::make(load_expr.type(),
                                    Call::if_then_else,
                                    {op->predicate, load_expr, make_zero(load_expr.type())},
                                    Internal::Call::Intrinsic);
        value = codegen(pred_load);
    }
}

void CodeGen_LLVM::codegen_atomic_store(const Store *op) {
    // TODO: predicated store (see https://github.com/halide/Halide/issues/4298).
    user_assert(is_one(op->predicate)) << "Atomic predicated store is not supported.\n";

    // Detect whether we can describe this as an atomic-read-modify-write,
    // otherwise fall back to a compare-and-swap loop.
    // Currently we only test for atomicAdd.
    Expr val_expr = op->value;
    Halide::Type value_type = op->value.type();
    // For atomicAdd, we check if op->value - store[index] is independent of store.
    // For llvm version < 9, the atomicRMW operations only support integers so we also check that.
    Expr equiv_load = Load::make(value_type, op->name,
                                 op->index,
                                 Buffer<>(),
                                 op->param,
                                 op->predicate,
                                 op->alignment);
    Expr delta = simplify(common_subexpression_elimination(op->value - equiv_load));
    bool is_atomic_add = supports_atomic_add(value_type) && !expr_uses_var(delta, op->name);
    if (is_atomic_add) {
        Value *val = codegen(delta);
        if (value_type.is_scalar()) {
            Value *ptr = codegen_buffer_pointer(op->name,
                                                op->value.type(),
                                                op->index);
            // llvm 9 has FAdd which can be used for atomic floats.
            if (value_type.is_float()) {
                builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, val, AtomicOrdering::Monotonic);
            } else {
                builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, val, AtomicOrdering::Monotonic);
            }
        } else {
            Value *index = codegen(op->index);
            // Scalarize vector store.
            for (int i = 0; i < value_type.lanes(); i++) {
                Value *lane = ConstantInt::get(i32_t, i);
                Value *idx = builder->CreateExtractElement(index, lane);
                Value *v = builder->CreateExtractElement(val, lane);
                Value *ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
                if (value_type.is_float()) {
                    builder->CreateAtomicRMW(AtomicRMWInst::FAdd, ptr, v, AtomicOrdering::Monotonic);
                } else {
                    builder->CreateAtomicRMW(AtomicRMWInst::Add, ptr, v, AtomicOrdering::Monotonic);
                }
            }
        }
    } else {
        // We want to create the following CAS loop:
        // entry:
        //   %orig = load atomic op->name[op->index]
        //   br label %casloop.start
        // casloop.start:
        //   %cmp = phi [%orig, %entry], [%val_loaded, %casloop.start]
        //   %val = ...
        //   %val_success = cmpxchg %ptr, %cmp, %val, monotonic
        //   %val_loaded = extractvalue %val_success, 0
        //   %success = extractvalue %val_success, 1
        //   br %success, label %casloop.end, label %casloop.start
        // casloop.end:
        Value *vec_index = nullptr;
        if (!value_type.is_scalar()) {
            // Precompute index for vector store.
            vec_index = codegen(op->index);
        }
        // Scalarize vector store.
        for (int lane_id = 0; lane_id < value_type.lanes(); lane_id++) {
            LLVMContext &ctx = builder->getContext();
            BasicBlock *bb = builder->GetInsertBlock();
            llvm::Function *f = bb->getParent();
            BasicBlock *loop_bb =
                BasicBlock::Create(ctx, "casloop.start", f);
            // Load the old value for compare and swap test.
            Value *ptr = nullptr;
            if (value_type.is_scalar()) {
                ptr = codegen_buffer_pointer(op->name, value_type, op->index);
            } else {
                Value *idx = builder->CreateExtractElement(vec_index, ConstantInt::get(i32_t, lane_id));
                ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
            }
            LoadInst *orig = builder->CreateAlignedLoad(ptr, make_alignment(value_type.bytes()));
            orig->setOrdering(AtomicOrdering::Monotonic);
            add_tbaa_metadata(orig, op->name, op->index);
            // Explicit fall through from the current block to the cas loop body.
            builder->CreateBr(loop_bb);

            // CAS loop body:
            builder->SetInsertPoint(loop_bb);
            llvm::Type *ptr_type = ptr->getType();
            PHINode *cmp = builder->CreatePHI(ptr_type->getPointerElementType(), 2, "loaded");
            Value *cmp_val = cmp;
            cmp->addIncoming(orig, bb);
            Value *val = nullptr;
            if (value_type.is_scalar()) {
                val = codegen(op->value);
            } else {
                val = codegen(extract_lane(op->value, lane_id));
            }
            llvm::Type *val_type = val->getType();
            bool need_bit_cast = val_type->isFloatingPointTy();
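            // cmpxchg only operates on integer (or pointer) operands, so we
            // round-trip float values through a same-width integer type.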
            if (need_bit_cast) {
                IntegerType *int_type = builder->getIntNTy(val_type->getPrimitiveSizeInBits());
                unsigned int addr_space = ptr_type->getPointerAddressSpace();
                ptr = builder->CreateBitCast(ptr, int_type->getPointerTo(addr_space));
                val = builder->CreateBitCast(val, int_type);
                cmp_val = builder->CreateBitCast(cmp_val, int_type);
            }
            Value *cmpxchg_pair = builder->CreateAtomicCmpXchg(
                ptr, cmp_val, val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
            Value *val_loaded = builder->CreateExtractValue(cmpxchg_pair, 0, "val_loaded");
            Value *success = builder->CreateExtractValue(cmpxchg_pair, 1, "success");
            if (need_bit_cast) {
                val_loaded = builder->CreateBitCast(val_loaded, val_type);
            }
            cmp->addIncoming(val_loaded, loop_bb);
            BasicBlock *exit_bb =
                BasicBlock::Create(ctx, "casloop.end", f);
            builder->CreateCondBr(success, exit_bb, loop_bb);
            builder->SetInsertPoint(exit_bb);
        }
    }
}

void CodeGen_LLVM::visit(const Call *op) {
    internal_assert(op->is_extern() || op->is_intrinsic())
        << "Can only codegen extern calls and intrinsics\n";

    // Some call nodes are actually injected at various stages as a
    // cue for llvm to generate particular ops. In general these are
    // handled in the standard library, but ones with e.g. varying
    // types are handled here.
    if (op->is_intrinsic(Call::debug_to_file)) {
        internal_assert(op->args.size() == 3);
        const StringImm *filename = op->args[0].as<StringImm>();
        internal_assert(filename) << "Malformed debug_to_file node\n";
        // Grab the function from the initial module
        llvm::Function *debug_to_file = module->getFunction("halide_debug_to_file");
        internal_assert(debug_to_file) << "Could not find halide_debug_to_file function in initial module\n";

        // Make the filename a global string constant
        Value *user_context = get_user_context();
        Value *char_ptr = codegen(Expr(filename));
        vector<Value *> args = {user_context, char_ptr, codegen(op->args[1])};

        Value *buffer = codegen(op->args[2]);
        buffer = builder->CreatePointerCast(buffer, debug_to_file->getFunctionType()->getParamType(3));
        args.push_back(buffer);

        value = builder->CreateCall(debug_to_file, args);

    } else if (op->is_intrinsic(Call::bitwise_and)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        value = builder->CreateAnd(a, b);
    } else if (op->is_intrinsic(Call::bitwise_xor)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        value = builder->CreateXor(a, b);
    } else if (op->is_intrinsic(Call::bitwise_or)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        value = builder->CreateOr(a, b);
    } else if (op->is_intrinsic(Call::bitwise_not)) {
        internal_assert(op->args.size() == 1);
        Value *a = codegen(op->args[0]);
        value = builder->CreateNot(a);
    } else if (op->is_intrinsic(Call::reinterpret)) {
        internal_assert(op->args.size() == 1);
        Type dst = op->type;
        Type src = op->args[0].type();
        llvm::Type *llvm_dst = llvm_type_of(dst);
        value = codegen(op->args[0]);
        if (src.is_handle() && !dst.is_handle()) {
            internal_assert(dst.is_uint() && dst.bits() == 64);

            // Handle -> UInt64
            llvm::DataLayout d(module.get());
            if (d.getPointerSize() == 4) {
                llvm::Type *intermediate = llvm_type_of(UInt(32, dst.lanes()));
                value = builder->CreatePtrToInt(value, intermediate);
                value = builder->CreateZExt(value, llvm_dst);
            } else if (d.getPointerSize() == 8) {
                value = builder->CreatePtrToInt(value, llvm_dst);
            } else {
                internal_error << "Pointer size is neither 4 nor 8 bytes\n";
            }

        } else if (dst.is_handle() && !src.is_handle()) {
            internal_assert(src.is_uint() && src.bits() == 64);

            // UInt64 -> Handle
            llvm::DataLayout d(module.get());
            if (d.getPointerSize() == 4) {
                llvm::Type *intermediate = llvm_type_of(UInt(32, src.lanes()));
                value = builder->CreateTrunc(value, intermediate);
                value = builder->CreateIntToPtr(value, llvm_dst);
            } else if (d.getPointerSize() == 8) {
                value = builder->CreateIntToPtr(value, llvm_dst);
            } else {
                internal_error << "Pointer size is neither 4 nor 8 bytes\n";
            }

        } else {
            value = builder->CreateBitCast(value, llvm_dst);
        }
    } else if (op->is_intrinsic(Call::shift_left)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        if (op->args[1].type().is_uint()) {
            value = builder->CreateShl(a, b);
        } else {
            value = codegen(lower_signed_shift_left(op->args[0], op->args[1]));
        }
    } else if (op->is_intrinsic(Call::shift_right)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        if (op->args[1].type().is_uint()) {
            if (op->type.is_int()) {
                value = builder->CreateAShr(a, b);
            } else {
                value = builder->CreateLShr(a, b);
            }
        } else {
            value = codegen(lower_signed_shift_right(op->args[0], op->args[1]));
        }
    } else if (op->is_intrinsic(Call::abs)) {

        internal_assert(op->args.size() == 1);

        // Check if an appropriate vector abs for this type exists in the initial module
        Type t = op->args[0].type();
        string name = (t.is_float() ? "abs_f" : "abs_i") + std::to_string(t.bits());
        llvm::Function *builtin_abs =
            find_vector_runtime_function(name, op->type.lanes()).first;

        if (t.is_vector() && builtin_abs) {
            codegen(Call::make(op->type, name, op->args, Call::Extern));
        } else {
            // Generate select(x >= 0, x, -x) instead
            string x_name = unique_name('x');
            Expr x = Variable::make(op->args[0].type(), x_name);
            value = codegen(Let::make(x_name, op->args[0], select(x >= 0, x, -x)));
        }
    } else if (op->is_intrinsic(Call::absd)) {

        internal_assert(op->args.size() == 2);

        Expr a = op->args[0];
        Expr b = op->args[1];

        // Check if an appropriate vector absd for this type exists in the initial module
        Type t = a.type();
        string name;
        if (t.is_float()) {
            codegen(abs(a - b));
            return;
        } else if (t.is_int()) {
            name = "absd_i" + std::to_string(t.bits());
        } else {
            name = "absd_u" + std::to_string(t.bits());
        }

        llvm::Function *builtin_absd =
            find_vector_runtime_function(name, op->type.lanes()).first;

        if (t.is_vector() && builtin_absd) {
            codegen(Call::make(op->type, name, op->args, Call::Extern));
        } else {
            // Use a select instead
            string a_name = unique_name('a');
            string b_name = unique_name('b');
            Expr a_var = Variable::make(op->args[0].type(), a_name);
            Expr b_var = Variable::make(op->args[1].type(), b_name);
            codegen(Let::make(a_name, op->args[0],
                              Let::make(b_name, op->args[1],
                                        Select::make(a_var < b_var, b_var - a_var, a_var - b_var))));
        }
    } else if (op->is_intrinsic(Call::div_round_to_zero)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        if (op->type.is_int()) {
            value = builder->CreateSDiv(a, b);
        } else if (op->type.is_uint()) {
            value = builder->CreateUDiv(a, b);
        } else {
            internal_error << "div_round_to_zero of non-integer type.\n";
        }
    } else if (op->is_intrinsic(Call::mod_round_to_zero)) {
        internal_assert(op->args.size() == 2);
        Value *a = codegen(op->args[0]);
        Value *b = codegen(op->args[1]);
        if (op->type.is_int()) {
            value = builder->CreateSRem(a, b);
        } else if (op->type.is_uint()) {
            value = builder->CreateURem(a, b);
        } else {
            internal_error << "mod_round_to_zero of non-integer type.\n";
        }
    } else if (op->is_intrinsic(Call::mulhi_shr)) {
        internal_assert(op->args.size() == 3);

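        // mulhi_shr(a, b, shift) computes (a * b) >> (bits + shift) by doing
        // the multiply at double width and keeping the high half; this is the
        // form strength-reduced division by a constant takes. For example
        // (a sketch, for uint32): x / 3 == mulhi_shr(x, 0xAAAAAAAB, 1), i.e.
        // the high 32 bits of the 64-bit product x * 0xAAAAAAAB, shifted
        // right by one more bit.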
        Type ty = op->type;
        Type wide_ty = ty.with_bits(ty.bits() * 2);

        Expr p_wide = cast(wide_ty, op->args[0]) * cast(wide_ty, op->args[1]);
        const UIntImm *shift = op->args[2].as<UIntImm>();
        internal_assert(shift != nullptr) << "Third argument to mulhi_shr intrinsic must be an unsigned integer immediate.\n";
        value = codegen(cast(ty, p_wide >> (shift->value + ty.bits())));
    } else if (op->is_intrinsic(Call::sorted_avg)) {
        internal_assert(op->args.size() == 2);
        // b > a, so the following works without widening:
        // a + (b - a)/2
        value = codegen(op->args[0] + (op->args[1] - op->args[0]) / 2);
    } else if (op->is_intrinsic(Call::lerp)) {
        internal_assert(op->args.size() == 3);
        // If we need to upgrade the type, do the entire lerp in the
        // upgraded type for better precision.
        Type t = upgrade_type_for_arithmetic(op->type);
        Type wt = upgrade_type_for_arithmetic(op->args[2].type());
        Expr e = lower_lerp(cast(t, op->args[0]),
                            cast(t, op->args[1]),
                            cast(wt, op->args[2]));
        e = cast(op->type, e);
        codegen(e);
    } else if (op->is_intrinsic(Call::popcount)) {
        internal_assert(op->args.size() == 1);
        std::vector<llvm::Type *> arg_type(1);
        arg_type[0] = llvm_type_of(op->args[0].type());
        llvm::Function *fn = Intrinsic::getDeclaration(module.get(), Intrinsic::ctpop, arg_type);
        Value *a = codegen(op->args[0]);
        CallInst *call = builder->CreateCall(fn, a);
        value = call;
    } else if (op->is_intrinsic(Call::count_leading_zeros) ||
               op->is_intrinsic(Call::count_trailing_zeros)) {
        internal_assert(op->args.size() == 1);
        std::vector<llvm::Type *> arg_type(1);
        arg_type[0] = llvm_type_of(op->args[0].type());
        llvm::Function *fn = Intrinsic::getDeclaration(module.get(),
                                                       (op->is_intrinsic(Call::count_leading_zeros)) ? Intrinsic::ctlz : Intrinsic::cttz,
                                                       arg_type);
        llvm::Value *is_zero_undef = llvm::ConstantInt::getFalse(*context);
        llvm::Value *args[2] = {codegen(op->args[0]), is_zero_undef};
        CallInst *call = builder->CreateCall(fn, args);
        value = call;
    } else if (op->is_intrinsic(Call::return_second)) {
        internal_assert(op->args.size() == 2);
        codegen(op->args[0]);
        value = codegen(op->args[1]);
    } else if (op->is_intrinsic(Call::if_then_else)) {
        Expr cond = op->args[0];
        if (const Broadcast *b = cond.as<Broadcast>()) {
            cond = b->value;
        }
        if (cond.type().is_vector()) {
            scalarize(op);
        } else {

            internal_assert(op->args.size() == 3);

            BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
            BasicBlock *false_bb = BasicBlock::Create(*context, "false_bb", function);
            BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
            Value *c = codegen(cond);
            if (c->getType() != i1_t) {
                c = builder->CreateIsNotNull(c);
            }
            builder->CreateCondBr(c, true_bb, false_bb);
            builder->SetInsertPoint(true_bb);
            Value *true_value = codegen(op->args[1]);
            builder->CreateBr(after_bb);
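            // Codegen of the arm may have opened new basic blocks, so take
            // the PHI's incoming edge from the builder's current block rather
            // than from true_bb itself (likewise for the false arm below).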
            BasicBlock *true_pred = builder->GetInsertBlock();

            builder->SetInsertPoint(false_bb);
            Value *false_value = codegen(op->args[2]);
            builder->CreateBr(after_bb);
            BasicBlock *false_pred = builder->GetInsertBlock();

            builder->SetInsertPoint(after_bb);
            PHINode *phi = builder->CreatePHI(true_value->getType(), 2);
            phi->addIncoming(true_value, true_pred);
            phi->addIncoming(false_value, false_pred);

            value = phi;
        }
    } else if (op->is_intrinsic(Call::require)) {
        internal_assert(op->args.size() == 3);
        Expr cond = op->args[0];
        if (cond.type().is_vector()) {
            scalarize(op);
        } else {
            Value *c = codegen(cond);
            create_assertion(c, op->args[2]);
            value = codegen(op->args[1]);
        }
    } else if (op->is_intrinsic(Call::make_struct)) {
        if (op->type.is_vector()) {
            // Make a vector of pointers to distinct structs
            scalarize(op);
        } else if (op->args.empty()) {
            // Empty structs can be emitted for arrays of size zero
            // (e.g. the shape of a zero-dimensional buffer). We
2887 // generate a null in this situation. */
2888 value = ConstantPointerNull::get(dyn_cast<PointerType>(llvm_type_of(op->type)));
2889 } else {
2890 // Codegen each element.
2891 bool all_same_type = true;
2892 vector<llvm::Value *> args(op->args.size());
2893 vector<llvm::Type *> types(op->args.size());
2894 for (size_t i = 0; i < op->args.size(); i++) {
2895 args[i] = codegen(op->args[i]);
2896 types[i] = args[i]->getType();
2897 all_same_type &= (types[0] == types[i]);
2898 }
2899
2900 // Use either a single scalar, a fixed-size array, or a
2901 // struct. The struct type would always be correct, but
2902 // the array or scalar type produce slightly simpler IR.
2903 if (args.size() == 1) {
2904 value = create_alloca_at_entry(types[0], 1);
2905 builder->CreateStore(args[0], value);
2906 } else {
2907 llvm::Type *aggregate_t = (all_same_type ? (llvm::Type *)ArrayType::get(types[0], types.size()) : (llvm::Type *)StructType::get(*context, types));
2908
2909 value = create_alloca_at_entry(aggregate_t, 1);
2910 for (size_t i = 0; i < args.size(); i++) {
2911 Value *elem_ptr = builder->CreateConstInBoundsGEP2_32(aggregate_t, value, 0, i);
2912 builder->CreateStore(args[i], elem_ptr);
2913 }
2914 }
2915 }
2916
2917 } else if (op->is_intrinsic(Call::stringify)) {
2918 internal_assert(!op->args.empty());
2919
2920 if (op->type.is_vector()) {
2921 scalarize(op);
2922 } else {
2923
2924 // Compute the maximum possible size of the message.
2925 int buf_size = 1; // One for the terminating zero.
2926 for (size_t i = 0; i < op->args.size(); i++) {
2927 Type t = op->args[i].type();
2928 if (op->args[i].as<StringImm>()) {
2929 buf_size += op->args[i].as<StringImm>()->value.size();
2930 } else if (t.is_int() || t.is_uint()) {
2931 buf_size += 19; // 2^64 = 18446744073709551616
2932 } else if (t.is_float()) {
2933 if (t.bits() == 32) {
2934 buf_size += 47; // %f format of max negative float
2935 } else {
2936 buf_size += 14; // Scientific notation with 6 decimal places.
2937 }
2938 } else if (t == type_of<halide_buffer_t *>()) {
2939 // Not a strict upper bound (there isn't one), but ought to be enough for most buffers.
2940 buf_size += 512;
2941 } else {
2942 internal_assert(t.is_handle());
2943 buf_size += 18; // 0x0123456789abcdef
2944 }
2945 }
2946 // Round up to a multiple of 16 bytes.
2947 buf_size = ((buf_size + 15) / 16) * 16;
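            // e.g. stringify("x = ", some_int32) budgets 4 (string) + 19 (integer)
            // + 1 (terminating zero) = 24 bytes, which rounds up to 32 here.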

            // Clamp to at most 8k.
            if (buf_size > 8 * 1024) buf_size = 8 * 1024;

            // Allocate a stack array to hold the message.
            llvm::Value *buf = create_alloca_at_entry(i8_t, buf_size);

            llvm::Value *dst = buf;
            llvm::Value *buf_end = builder->CreateConstGEP1_32(buf, buf_size);

            llvm::Function *append_string = module->getFunction("halide_string_to_string");
            llvm::Function *append_int64 = module->getFunction("halide_int64_to_string");
            llvm::Function *append_uint64 = module->getFunction("halide_uint64_to_string");
            llvm::Function *append_double = module->getFunction("halide_double_to_string");
            llvm::Function *append_pointer = module->getFunction("halide_pointer_to_string");
            llvm::Function *append_buffer = module->getFunction("halide_buffer_to_string");

            internal_assert(append_string);
            internal_assert(append_int64);
            internal_assert(append_uint64);
            internal_assert(append_double);
            internal_assert(append_pointer);
            internal_assert(append_buffer);

            for (size_t i = 0; i < op->args.size(); i++) {
                const StringImm *s = op->args[i].as<StringImm>();
                Type t = op->args[i].type();
                internal_assert(t.lanes() == 1);
                vector<Value *> call_args(2);
                call_args[0] = dst;
                call_args[1] = buf_end;

                if (s) {
                    call_args.push_back(codegen(op->args[i]));
                    dst = builder->CreateCall(append_string, call_args);
                } else if (t.is_bool()) {
                    Value *a = codegen(op->args[i]);
                    Value *true_str = codegen(StringImm::make("true"));
                    Value *false_str = codegen(StringImm::make("false"));
                    call_args.push_back(builder->CreateSelect(a, true_str, false_str));
                    dst = builder->CreateCall(append_string, call_args);
                } else if (t.is_int()) {
                    call_args.push_back(codegen(Cast::make(Int(64), op->args[i])));
                    call_args.push_back(ConstantInt::get(i32_t, 1));
                    dst = builder->CreateCall(append_int64, call_args);
                } else if (t.is_uint()) {
                    call_args.push_back(codegen(Cast::make(UInt(64), op->args[i])));
                    call_args.push_back(ConstantInt::get(i32_t, 1));
                    dst = builder->CreateCall(append_uint64, call_args);
                } else if (t.is_float()) {
                    call_args.push_back(codegen(Cast::make(Float(64), op->args[i])));
                    // Use scientific notation for doubles
                    call_args.push_back(ConstantInt::get(i32_t, t.bits() == 64 ? 1 : 0));
                    dst = builder->CreateCall(append_double, call_args);
                } else if (t == type_of<halide_buffer_t *>()) {
                    Value *buf_arg = codegen(op->args[i]);
                    buf_arg = builder->CreatePointerCast(buf_arg, append_buffer->getFunctionType()->getParamType(2));
                    call_args.push_back(buf_arg);
                    dst = builder->CreateCall(append_buffer, call_args);
                } else {
                    internal_assert(t.is_handle());
                    call_args.push_back(codegen(op->args[i]));
                    dst = builder->CreateCall(append_pointer, call_args);
                }
            }
            if (get_target().has_feature(Target::MSAN)) {
                // Note that we mark the entire buffer as initialized;
                // it would be more accurate to just mark (dst - buf).
                llvm::Function *annotate = module->getFunction("halide_msan_annotate_memory_is_initialized");
                vector<Value *> annotate_args(3);
                annotate_args[0] = get_user_context();
                annotate_args[1] = buf;
                annotate_args[2] = codegen(Cast::make(Int(64), buf_size));
                builder->CreateCall(annotate, annotate_args);
            }
            value = buf;
        }
    } else if (op->is_intrinsic(Call::memoize_expr)) {
        // Used as an annotation for caching, should be invisible to
        // codegen. Ignore arguments beyond the first as they are only
        // used in the cache key.
        internal_assert(!op->args.empty());
        value = codegen(op->args[0]);
    } else if (op->is_intrinsic(Call::alloca)) {
        // The argument is the number of bytes. For now it must be
        // const, or a call to size_of_halide_buffer_t.
        internal_assert(op->args.size() == 1);

        // We can generate slightly cleaner IR with fewer alignment
        // restrictions if we recognize the most common types we
        // expect to get alloca'd.
        const Call *call = op->args[0].as<Call>();
        if (op->type == type_of<struct halide_buffer_t *>() &&
            call && call->is_intrinsic(Call::size_of_halide_buffer_t)) {
            value = create_alloca_at_entry(halide_buffer_t_type, 1);
        } else {
            const int64_t *sz = as_const_int(op->args[0]);
            internal_assert(sz);
            if (op->type == type_of<struct halide_dimension_t *>()) {
                value = create_alloca_at_entry(dimension_t_type, *sz / sizeof(halide_dimension_t));
            } else {
                // Just use an i8* and make the users bitcast it.
                value = create_alloca_at_entry(i8_t, *sz);
            }
        }
    } else if (op->is_intrinsic(Call::register_destructor)) {
        internal_assert(op->args.size() == 2);
        const StringImm *fn = op->args[0].as<StringImm>();
        internal_assert(fn);
        llvm::Function *f = module->getFunction(fn->value);
        if (!f) {
            llvm::Type *arg_types[] = {i8_t->getPointerTo(), i8_t->getPointerTo()};
            FunctionType *func_t = FunctionType::get(void_t, arg_types, false);
            f = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, fn->value, module.get());
            f->setCallingConv(CallingConv::C);
        }
        internal_assert(op->args[1].type().is_handle());
        Value *arg = codegen(op->args[1]);
        value = register_destructor(f, arg, Always);
    } else if (op->is_intrinsic(Call::call_cached_indirect_function)) {
        // Arguments to call_cached_indirect_function are of the form
        //
        //     cond_1, "sub_function_name_1",
        //     cond_2, "sub_function_name_2",
        //     ...
        //     cond_N, "sub_function_name_N"
        //
        // This will generate code that corresponds (roughly) to
        //
        //     static FunctionPtr f = []{
        //       if (cond_1) return sub_function_name_1;
        //       if (cond_2) return sub_function_name_2;
        //       ...
        //       if (cond_N) return sub_function_name_N;
        //     }();
        //     return f(args)
        //
        // i.e.: the conditions will be evaluated *in order*; the first one
        // evaluating to true will have its corresponding function cached,
        // which will be used to complete this (and all subsequent) calls.
        //
        // The final condition (cond_N) must evaluate to a constant TRUE
        // value (so that the final function will be selected if all others
        // fail); failure to do so will cause unpredictable results.
        //
        // There is currently no way to clear the cached function pointer.
        //
        // It is assumed/required that all of the conditions are "pure"; each
        // must evaluate to the same value (within a given runtime environment)
        // across multiple evaluations.
        //
        // It is assumed/required that all of the sub-functions have arguments
        // (and return values) that are identical to those of this->function.
        //
        // Note that we require >= 4 arguments: fewer would imply
        // only one condition+function pair, which is pointless to use
        // (the function should always be called directly).
        //
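        // For example (with hypothetical sub-function names):
        //
        //     call_cached_indirect_function(can_use_avx2, "fn_avx2",
        //                                   const_true(), "fn_generic")
        //
        // resolves to fn_avx2 or fn_generic exactly once, then reuses the
        // cached pointer for this and all subsequent calls.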
        internal_assert(op->args.size() >= 4);
        internal_assert(!(op->args.size() & 1));

        // Gather information we need about each function.
        struct SubFn {
            llvm::Function *fn;
            llvm::GlobalValue *fn_ptr;
            Expr cond;
        };
        vector<SubFn> sub_fns;
        for (size_t i = 0; i < op->args.size(); i += 2) {
            const string sub_fn_name = op->args[i + 1].as<StringImm>()->value;
            string extern_sub_fn_name = sub_fn_name;
            llvm::Function *sub_fn = module->getFunction(sub_fn_name);
            if (!sub_fn) {
                extern_sub_fn_name = get_mangled_names(sub_fn_name,
                                                       LinkageType::External,
                                                       NameMangling::Default,
                                                       current_function_args,
                                                       get_target())
                                         .extern_name;
                debug(1) << "Did not find function " << sub_fn_name
                         << ", assuming extern \"C\" " << extern_sub_fn_name << "\n";
                vector<llvm::Type *> arg_types;
                for (const auto &arg : function->args()) {
                    arg_types.push_back(arg.getType());
                }
                llvm::Type *result_type = llvm_type_of(upgrade_type_for_argument_passing(op->type));
                FunctionType *func_t = FunctionType::get(result_type, arg_types, false);
                sub_fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage,
                                                extern_sub_fn_name, module.get());
                sub_fn->setCallingConv(CallingConv::C);
            }

            llvm::GlobalValue *sub_fn_ptr = module->getNamedValue(extern_sub_fn_name);
            if (!sub_fn_ptr) {
                debug(1) << "Did not find function ptr " << extern_sub_fn_name << ", assuming extern \"C\".\n";
                sub_fn_ptr = new GlobalVariable(*module, sub_fn->getType(),
                                                /*isConstant*/ true, GlobalValue::ExternalLinkage,
                                                /*initializer*/ nullptr, extern_sub_fn_name);
            }
            auto cond = op->args[i];
            sub_fns.push_back({sub_fn, sub_fn_ptr, cond});
        }

        // Create a null-initialized global to track this object.
        const auto base_fn = sub_fns.back().fn;
        const string global_name = unique_name(base_fn->getName().str() + "_indirect_fn_ptr");
        GlobalVariable *global = new GlobalVariable(
            *module,
            base_fn->getType(),
            /*isConstant*/ false,
            GlobalValue::PrivateLinkage,
            ConstantPointerNull::get(base_fn->getType()),
            global_name);
        LoadInst *loaded_value = builder->CreateLoad(global);

        BasicBlock *global_inited_bb = BasicBlock::Create(*context, "global_inited_bb", function);
        BasicBlock *global_not_inited_bb = BasicBlock::Create(*context, "global_not_inited_bb", function);
        BasicBlock *call_fn_bb = BasicBlock::Create(*context, "call_fn_bb", function);

        // Only init the global if not already inited.
        //
        // Note that we deliberately do not attempt to make this threadsafe via (e.g.) mutexes;
        // the requirements of the conditions above mean that multiple writes *should* only
        // be able to re-write the same value, which is harmless for our purposes, and
        // avoiding such code simplifies and speeds the resulting code.
        //
        // (Note that if we ever need to add a way to clear the cached function pointer,
        // we may need to reconsider this, to avoid amusingly horrible race conditions.)
        builder->CreateCondBr(builder->CreateIsNotNull(loaded_value),
                              global_inited_bb, global_not_inited_bb, very_likely_branch);

        // Build the not-already-inited case
        builder->SetInsertPoint(global_not_inited_bb);
        llvm::Value *selected_value = nullptr;
        for (int i = sub_fns.size() - 1; i >= 0; i--) {
            const auto sub_fn = sub_fns[i];
            if (!selected_value) {
                selected_value = sub_fn.fn_ptr;
            } else {
                Value *c = codegen(sub_fn.cond);
                selected_value = builder->CreateSelect(c, sub_fn.fn_ptr, selected_value);
            }
        }
        builder->CreateStore(selected_value, global);
        builder->CreateBr(call_fn_bb);

        // Just an incoming edge for the Phi node
        builder->SetInsertPoint(global_inited_bb);
        builder->CreateBr(call_fn_bb);

        builder->SetInsertPoint(call_fn_bb);
        PHINode *phi = builder->CreatePHI(selected_value->getType(), 2);
        phi->addIncoming(selected_value, global_not_inited_bb);
        phi->addIncoming(loaded_value, global_inited_bb);

        std::vector<llvm::Value *> call_args;
        for (auto &arg : function->args()) {
            call_args.push_back(&arg);
        }

        llvm::CallInst *call = builder->CreateCall(base_fn->getFunctionType(), phi, call_args);
        value = call;
    } else if (op->is_intrinsic(Call::prefetch)) {
        user_assert((op->args.size() == 4) && is_one(op->args[2]))
            << "Only prefetch of 1 cache line is supported.\n";

        llvm::Function *prefetch_fn = module->getFunction("_halide_prefetch");
        internal_assert(prefetch_fn);

        vector<llvm::Value *> args;
        args.push_back(codegen_buffer_pointer(codegen(op->args[0]), op->type, op->args[1]));
        // The first argument is a pointer, which has type i8*. We
        // need to cast the argument, which might be a pointer to a
        // different type.
        llvm::Type *ptr_type = prefetch_fn->getFunctionType()->params()[0];
        args[0] = builder->CreateBitCast(args[0], ptr_type);

        value = builder->CreateCall(prefetch_fn, args);

    } else if (op->is_intrinsic(Call::signed_integer_overflow)) {
        user_error << "Signed integer overflow occurred during constant-folding. Signed"
                      " integer overflow for int32 and int64 is undefined behavior in"
                      " Halide.\n";
    } else if (op->is_intrinsic(Call::undef)) {
        value = UndefValue::get(llvm_type_of(op->type));
    } else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) {
        llvm::DataLayout d(module.get());
        value = ConstantInt::get(i32_t, (int)d.getTypeAllocSize(halide_buffer_t_type));
    } else if (op->is_intrinsic(Call::strict_float)) {
        IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
        llvm::FastMathFlags safe_flags;
        safe_flags.clear();
        builder->setFastMathFlags(safe_flags);
        builder->setDefaultFPMathTag(strict_fp_math_md);
        value = codegen(op->args[0]);
    } else if (is_float16_transcendental(op)) {
        value = codegen(lower_float16_transcendental_to_float32_equivalent(op));
    } else if (op->is_intrinsic()) {
        internal_error << "Unknown intrinsic: " << op->name << "\n";
    } else if (op->call_type == Call::PureExtern && op->name == "pow_f32") {
        internal_assert(op->args.size() == 2);
        Expr x = op->args[0];
        Expr y = op->args[1];
        Halide::Expr abs_x_pow_y = Internal::halide_exp(Internal::halide_log(abs(x)) * y);
        Halide::Expr nan_expr = Call::make(x.type(), "nan_f32", {}, Call::PureExtern);
        Expr iy = floor(y);
        Expr one = make_one(x.type());
        Expr zero = make_zero(x.type());
        Expr e = select(x > 0, abs_x_pow_y,        // Strictly positive x
                        y == 0.0f, one,            // x^0 == 1
                        x == 0.0f, zero,           // 0^y == 0
                        y != iy, nan_expr,         // negative x to a non-integer power
                        iy % 2 == 0, abs_x_pow_y,  // negative x to an even power
                        -abs_x_pow_y);             // negative x to an odd power
        e = common_subexpression_elimination(e);
        e.accept(this);
    } else if (op->call_type == Call::PureExtern && op->name == "log_f32") {
        internal_assert(op->args.size() == 1);
        Expr e = Internal::halide_log(op->args[0]);
        e.accept(this);
    } else if (op->call_type == Call::PureExtern && op->name == "exp_f32") {
        internal_assert(op->args.size() == 1);
        Expr e = Internal::halide_exp(op->args[0]);
        e.accept(this);
    } else if (op->call_type == Call::PureExtern &&
               (op->name == "is_nan_f32" || op->name == "is_nan_f64")) {
        internal_assert(op->args.size() == 1);
        Value *a = codegen(op->args[0]);

        /* NaNs are not supposed to exist in "no NaNs" compilation
         * mode, but it appears llvm special cases the unordered
         * compare instruction when the global NoNaNsFPMath option is
         * set and still checks for a NaN. However if the nnan flag is
         * set on the instruction itself, llvm treats the comparison
         * as always false. Thus we always turn off the per-instruction
         * fast-math flags for this instruction. I.e. it is always
         * treated as strict. Note that compilation may still be in
         * fast-math mode due to global options, but that's ok due to
         * the aforementioned special casing. */
        IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
        llvm::FastMathFlags safe_flags;
        safe_flags.clear();
        builder->setFastMathFlags(safe_flags);
        builder->setDefaultFPMathTag(strict_fp_math_md);

        value = builder->CreateFCmpUNO(a, a);
    } else if (op->call_type == Call::PureExtern &&
               (op->name == "is_inf_f32" || op->name == "is_inf_f64")) {
        internal_assert(op->args.size() == 1);

        IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
        llvm::FastMathFlags safe_flags;
        safe_flags.clear();
        builder->setFastMathFlags(safe_flags);
        builder->setDefaultFPMathTag(strict_fp_math_md);

        // isinf(e) -> (fabs(e) == infinity)
        Expr e = op->args[0];
        internal_assert(e.type().is_float());
        Expr inf = e.type().max();
        codegen(abs(e) == inf);
    } else if (op->call_type == Call::PureExtern &&
               (op->name == "is_finite_f32" || op->name == "is_finite_f64")) {
        internal_assert(op->args.size() == 1);
        internal_assert(op->args[0].type().is_float());

        IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>::FastMathFlagGuard guard(*builder);
        llvm::FastMathFlags safe_flags;
        safe_flags.clear();
        builder->setFastMathFlags(safe_flags);
        builder->setDefaultFPMathTag(strict_fp_math_md);

        // isfinite(e) -> (fabs(e) != infinity && !isnan(e)) -> (fabs(e) != infinity && e == e)
        Expr e = op->args[0];
        internal_assert(e.type().is_float());
        Expr inf = e.type().max();
        codegen(abs(e) != inf && e == e);
    } else {
        // It's an extern call.

        std::string name;
        if (op->call_type == Call::ExternCPlusPlus) {
            user_assert(get_target().has_feature(Target::CPlusPlusMangling)) << "Target must specify C++ name mangling (\"c_plus_plus_name_mangling\") in order to call C++ externs. (" << op->name << ")\n";

            std::vector<std::string> namespaces;
            name = extract_namespaces(op->name, namespaces);
            std::vector<ExternFuncArgument> mangle_args;
            for (const auto &arg : op->args) {
                mangle_args.emplace_back(arg);
            }
            name = cplusplus_function_mangled_name(name, namespaces, op->type, mangle_args, get_target());
        } else {
            name = op->name;
        }

        // Codegen the args
        vector<Value *> args(op->args.size());
        for (size_t i = 0; i < op->args.size(); i++) {
            args[i] = codegen(op->args[i]);
        }

        llvm::Function *fn = module->getFunction(name);

        llvm::Type *result_type = llvm_type_of(upgrade_type_for_argument_passing(op->type));

        // Add a user context arg as needed. It's never a vector.
        bool takes_user_context = function_takes_user_context(op->name);
        if (takes_user_context) {
            internal_assert(fn) << "External function " << op->name << " is marked as taking user_context, but is not in the runtime module. Check if runtime_api.cpp needs to be rebuilt.\n";
            debug(4) << "Adding user_context to " << op->name << " args\n";
            args.insert(args.begin(), get_user_context());
        }

        // If we can't find it, declare it extern "C"
        if (!fn) {
            vector<llvm::Type *> arg_types(args.size());
            for (size_t i = 0; i < args.size(); i++) {
                arg_types[i] = args[i]->getType();
                if (arg_types[i]->isVectorTy()) {
                    VectorType *vt = dyn_cast<VectorType>(arg_types[i]);
                    arg_types[i] = vt->getElementType();
                }
            }

            llvm::Type *scalar_result_type = result_type;
            if (result_type->isVectorTy()) {
                VectorType *vt = dyn_cast<VectorType>(result_type);
                scalar_result_type = vt->getElementType();
            }

            FunctionType *func_t = FunctionType::get(scalar_result_type, arg_types, false);

            fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());
            fn->setCallingConv(CallingConv::C);
            debug(4) << "Did not find " << op->name << ". Declared it extern \"C\".\n";
        } else {
            debug(4) << "Found " << op->name << "\n";

            // TODO: Say something more accurate here as there is now
            // partial information in the handle_type field, but it is
            // not clear it can be matched to the LLVM types and it is
            // not always there.
            // Halide's type system doesn't preserve pointer types
            // correctly (they just get called "Handle()"), so we may
            // need to pointer cast to the appropriate type. Only look at
            // fixed params (not varargs) in the llvm function.
            FunctionType *func_t = fn->getFunctionType();
            for (size_t i = takes_user_context ? 1 : 0;
                 i < std::min(args.size(), (size_t)(func_t->getNumParams()));
                 i++) {
                Expr halide_arg = takes_user_context ? op->args[i - 1] : op->args[i];
                if (halide_arg.type().is_handle()) {
                    llvm::Type *t = func_t->getParamType(i);

                    // Widen to vector-width as needed. If the
                    // function doesn't actually take a vector,
                    // individual lanes will be extracted below.
                    if (halide_arg.type().is_vector() &&
                        !t->isVectorTy()) {
                        t = get_vector_type(t, halide_arg.type().lanes());
                    }

                    if (t != args[i]->getType()) {
                        debug(4) << "Pointer casting argument to extern call: "
                                 << halide_arg << "\n";
                        args[i] = builder->CreatePointerCast(args[i], t);
                    }
                }
            }
        }

        if (op->type.is_scalar()) {
            CallInst *call = builder->CreateCall(fn, args);
            if (op->is_pure()) {
                call->setDoesNotAccessMemory();
            }
            call->setDoesNotThrow();
            value = call;
        } else {

            // Check if a vector version of the function already
            // exists at some useful width.
            pair<llvm::Function *, int> vec =
                find_vector_runtime_function(name, op->type.lanes());
            llvm::Function *vec_fn = vec.first;
            int w = vec.second;

            if (vec_fn) {
                value = call_intrin(llvm_type_of(op->type), w,
                                    get_llvm_function_name(vec_fn), args);
            } else {

                // No vector version found. Scalarize. Extract each simd
                // lane in turn and do one scalar call to the function.
                value = UndefValue::get(result_type);
                for (int i = 0; i < op->type.lanes(); i++) {
                    Value *idx = ConstantInt::get(i32_t, i);
                    vector<Value *> arg_lane(args.size());
                    for (size_t j = 0; j < args.size(); j++) {
                        if (args[j]->getType()->isVectorTy()) {
                            arg_lane[j] = builder->CreateExtractElement(args[j], idx);
                        } else {
                            arg_lane[j] = args[j];
                        }
                    }
                    CallInst *call = builder->CreateCall(fn, arg_lane);
                    if (op->is_pure()) {
                        call->setDoesNotAccessMemory();
                    }
                    call->setDoesNotThrow();
                    if (!call->getType()->isVoidTy()) {
                        value = builder->CreateInsertElement(value, call, idx);
                    }  // otherwise leave it as undef.
                }
            }
        }
    }
}

void CodeGen_LLVM::visit(const Prefetch *op) {
    internal_error << "Prefetch encountered during codegen\n";
}

void CodeGen_LLVM::visit(const Let *op) {
    sym_push(op->name, codegen(op->value));
    value = codegen(op->body);
    sym_pop(op->name);
}

void CodeGen_LLVM::visit(const LetStmt *op) {
    sym_push(op->name, codegen(op->value));
    codegen(op->body);
    sym_pop(op->name);
}

void CodeGen_LLVM::visit(const AssertStmt *op) {
    create_assertion(codegen(op->condition), op->message);
}

Constant *CodeGen_LLVM::create_string_constant(const string &s) {
    map<string, Constant *>::iterator iter = string_constants.find(s);
    if (iter == string_constants.end()) {
        vector<char> data;
        data.reserve(s.size() + 1);
        data.insert(data.end(), s.begin(), s.end());
        data.push_back(0);
        Constant *val = create_binary_blob(data, "str");
        string_constants[s] = val;
        return val;
    } else {
        return iter->second;
    }
}

Constant *CodeGen_LLVM::create_binary_blob(const vector<char> &data, const string &name, bool constant) {
    internal_assert(!data.empty());
    llvm::Type *type = ArrayType::get(i8_t, data.size());
    GlobalVariable *global = new GlobalVariable(*module, type,
                                                constant, GlobalValue::PrivateLinkage,
                                                0, name);
    ArrayRef<unsigned char> data_array((const unsigned char *)&data[0], data.size());
    global->setInitializer(ConstantDataArray::get(*context, data_array));
    size_t alignment = 32;
    size_t native_vector_bytes = (size_t)(native_vector_bits() / 8);
    if (data.size() > alignment && native_vector_bytes > alignment) {
        alignment = native_vector_bytes;
    }
    global->setAlignment(make_alignment(alignment));

    Constant *zero = ConstantInt::get(i32_t, 0);
    Constant *zeros[] = {zero, zero};
    Constant *ptr = ConstantExpr::getInBoundsGetElementPtr(type, global, zeros);
    return ptr;
}

void CodeGen_LLVM::create_assertion(Value *cond, const Expr &message, llvm::Value *error_code) {

    internal_assert(!message.defined() || message.type() == Int(32))
        << "Assertion result is not an int: " << message;

    if (target.has_feature(Target::NoAsserts)) return;

    // If the condition is a vector, fold it down to a scalar
    VectorType *vt = dyn_cast<VectorType>(cond->getType());
    if (vt) {
        Value *scalar_cond = builder->CreateExtractElement(cond, ConstantInt::get(i32_t, 0));
        for (int i = 1; i < get_vector_num_elements(vt); i++) {
            Value *lane = builder->CreateExtractElement(cond, ConstantInt::get(i32_t, i));
            scalar_cond = builder->CreateAnd(scalar_cond, lane);
        }
        cond = scalar_cond;
    }

    // Make a new basic block for the assert
    BasicBlock *assert_fails_bb = BasicBlock::Create(*context, "assert failed", function);
    BasicBlock *assert_succeeds_bb = BasicBlock::Create(*context, "assert succeeded", function);

    // If the condition fails, enter the assert body; otherwise, enter the block after
    builder->CreateCondBr(cond, assert_succeeds_bb, assert_fails_bb, very_likely_branch);

    // Build the failure case
    builder->SetInsertPoint(assert_fails_bb);

    // Call the error handler
    if (!error_code) error_code = codegen(message);

    return_with_error_code(error_code);

    // Continue on using the success case
    builder->SetInsertPoint(assert_succeeds_bb);
}

void CodeGen_LLVM::return_with_error_code(llvm::Value *error_code) {
    // Branch to the destructor block, which cleans up and then bails out.
    BasicBlock *dtors = get_destructor_block();

    // Hook up our error code to the phi node that the destructor block starts with.
    PHINode *phi = dyn_cast<PHINode>(dtors->begin());
    internal_assert(phi) << "The destructor block is supposed to start with a phi node\n";
    phi->addIncoming(error_code, builder->GetInsertBlock());

    builder->CreateBr(get_destructor_block());
}

void CodeGen_LLVM::visit(const ProducerConsumer *op) {
    string name;
    if (op->is_producer) {
        name = std::string("produce ") + op->name;
    } else {
        name = std::string("consume ") + op->name;
    }
    BasicBlock *produce = BasicBlock::Create(*context, name, function);
    builder->CreateBr(produce);
    builder->SetInsertPoint(produce);
    codegen(op->body);
}

void CodeGen_LLVM::visit(const For *op) {
    Value *min = codegen(op->min);
    Value *extent = codegen(op->extent);
    const Acquire *acquire = op->body.as<Acquire>();

    if (op->for_type == ForType::Parallel ||
        (op->for_type == ForType::Serial &&
         acquire &&
         !expr_uses_var(acquire->count, op->name))) {
        do_as_parallel_task(op);
    } else if (op->for_type == ForType::Serial) {

        Value *max = builder->CreateNSWAdd(min, extent);

        BasicBlock *preheader_bb = builder->GetInsertBlock();

        // Make a new basic block for the loop
        BasicBlock *loop_bb = BasicBlock::Create(*context, std::string("for ") + op->name, function);
        // Create the block that comes after the loop
        BasicBlock *after_bb = BasicBlock::Create(*context, std::string("end for ") + op->name, function);

        // If min < max, fall through to the loop bb
        Value *enter_condition = builder->CreateICmpSLT(min, max);
        builder->CreateCondBr(enter_condition, loop_bb, after_bb, very_likely_branch);
        builder->SetInsertPoint(loop_bb);

        // Make our phi node.
        PHINode *phi = builder->CreatePHI(i32_t, 2);
        phi->addIncoming(min, preheader_bb);

        // Within the loop, the variable is equal to the phi value
        sym_push(op->name, phi);

        // Emit the loop body
        codegen(op->body);

        // Update the counter
        Value *next_var = builder->CreateNSWAdd(phi, ConstantInt::get(i32_t, 1));

        // Add the back-edge to the phi node
        phi->addIncoming(next_var, builder->GetInsertBlock());

        // Maybe exit the loop
        Value *end_condition = builder->CreateICmpNE(next_var, max);
        builder->CreateCondBr(end_condition, loop_bb, after_bb);

        builder->SetInsertPoint(after_bb);

        // Pop the loop variable from the scope
        sym_pop(op->name);
    } else {
        internal_error << "Unknown type of For node. Only Serial and Parallel For nodes should survive down to codegen.\n";
    }
}

void CodeGen_LLVM::do_parallel_tasks(const vector<ParallelTask> &tasks) {
    Closure closure;
    for (const auto &t : tasks) {
        Stmt s = t.body;
        if (!t.loop_var.empty()) {
            s = LetStmt::make(t.loop_var, 0, s);
        }
        s.accept(&closure);
    }

    // Allocate a closure
    StructType *closure_t = build_closure_type(closure, halide_buffer_t_type, context);
    Value *closure_ptr = create_alloca_at_entry(closure_t, 1);

    // Fill in the closure
    pack_closure(closure_t, closure_ptr, closure, symbol_table, halide_buffer_t_type, builder);

    closure_ptr = builder->CreatePointerCast(closure_ptr, i8_t->getPointerTo());

    int num_tasks = (int)tasks.size();

    // Make space on the stack for the tasks
    llvm::Value *task_stack_ptr = create_alloca_at_entry(parallel_task_t_type, num_tasks);

    llvm::Type *args_t[] = {i8_t->getPointerTo(), i32_t, i8_t->getPointerTo()};
    FunctionType *task_t = FunctionType::get(i32_t, args_t, false);
    llvm::Type *loop_args_t[] = {i8_t->getPointerTo(), i32_t, i32_t, i8_t->getPointerTo(), i8_t->getPointerTo()};
    FunctionType *loop_task_t = FunctionType::get(i32_t, loop_args_t, false);

    Value *result = nullptr;

    for (int i = 0; i < num_tasks; i++) {
        ParallelTask t = tasks[i];

        // Analyze the task body
        class MayBlock : public IRVisitor {
            using IRVisitor::visit;
            void visit(const Acquire *op) override {
                result = true;
            }

        public:
            bool result = false;
        };

        // TODO(zvookin|abadams): This makes multiple passes over the
        // IR to cover each node. (One tree walk produces the min
        // thread count for all nodes, but we redo each subtree when
        // compiling a given node.) Ideally we'd move to a lowering pass
        // that converts our parallelism constructs to Call nodes, or
        // direct hardware operations in some cases.
        // Also, this code has to exactly mirror the logic in get_parallel_tasks.
        // It would be better to do one pass on the tree and centralize the task
        // deduction logic in one place.
        class MinThreads : public IRVisitor {
            using IRVisitor::visit;

            std::pair<Stmt, int> skip_acquires(Stmt first) {
                int count = 0;
                while (first.defined()) {
                    const Acquire *acq = first.as<Acquire>();
                    if (acq == nullptr) {
                        break;
                    }
                    count++;
                    first = acq->body;
                }
                return {first, count};
            }

            void visit(const Fork *op) override {
                int total_threads = 0;
                int direct_acquires = 0;
                // Take the sum of min threads across all
                // cascaded Fork nodes.
                const Fork *node = op;
                while (node != nullptr) {
                    result = 0;
                    auto after_acquires = skip_acquires(node->first);
                    direct_acquires += after_acquires.second;

                    after_acquires.first.accept(this);
                    total_threads += result;

                    const Fork *continued_branches = node->rest.as<Fork>();
                    if (continued_branches == nullptr) {
                        result = 0;
                        after_acquires = skip_acquires(node->rest);
                        direct_acquires += after_acquires.second;
                        after_acquires.first.accept(this);
                        total_threads += result;
                    }
                    node = continued_branches;
                }
                if (direct_acquires == 0 && total_threads == 0) {
                    result = 0;
                } else {
                    result = total_threads + 1;
                }
            }

            void visit(const For *op) override {
                result = 0;

                if (op->for_type == ForType::Parallel) {
                    IRVisitor::visit(op);
                    if (result > 0) {
                        result += 1;
                    }
                } else if (op->for_type == ForType::Serial) {
                    auto after_acquires = skip_acquires(op->body);
                    if (after_acquires.second > 0 &&
                        !expr_uses_var(op->body.as<Acquire>()->count, op->name)) {
                        after_acquires.first.accept(this);
                        result++;
                    } else {
                        IRVisitor::visit(op);
                    }
                } else {
                    IRVisitor::visit(op);
                }
            }

            // This is a "standalone" Acquire and will result in its own task.
            // Treat it as requiring one more thread than its body.
            void visit(const Acquire *op) override {
                result = 0;
                auto after_inner_acquires = skip_acquires(op);
                after_inner_acquires.first.accept(this);
                result = result + 1;
            }

            void visit(const Block *op) override {
                result = 0;
                op->first.accept(this);
                int result_first = result;
                result = 0;
                op->rest.accept(this);
                result = std::max(result, result_first);
            }

        public:
            int result = 0;
        };
        MinThreads min_threads;
        t.body.accept(&min_threads);

        // Decide if we're going to call do_par_for or
        // do_parallel_tasks. halide_do_par_for is simpler, but
        // assumes a bunch of things. Programs that don't use async
        // can also enter the task system via do_par_for.
        Value *task_parent = sym_get("__task_parent", false);
        bool use_do_par_for = (num_tasks == 1 &&
                               min_threads.result == 0 &&
                               t.semaphores.empty() &&
                               !task_parent);

        // Make the array of semaphore acquisitions this task needs to do before it runs.
        Value *semaphores;
        Value *num_semaphores = ConstantInt::get(i32_t, (int)t.semaphores.size());
        if (!t.semaphores.empty()) {
            semaphores = create_alloca_at_entry(semaphore_acquire_t_type, (int)t.semaphores.size());
            for (int i = 0; i < (int)t.semaphores.size(); i++) {
                Value *semaphore = codegen(t.semaphores[i].semaphore);
                semaphore = builder->CreatePointerCast(semaphore, semaphore_t_type->getPointerTo());
                Value *count = codegen(t.semaphores[i].count);
                Value *slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 0);
                builder->CreateStore(semaphore, slot_ptr);
                slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 1);
                builder->CreateStore(count, slot_ptr);
            }
        } else {
            semaphores = ConstantPointerNull::get(semaphore_acquire_t_type->getPointerTo());
        }

        FunctionType *fn_type = use_do_par_for ? task_t : loop_task_t;
        int closure_arg_idx = use_do_par_for ? 2 : 3;

        // Make a new function that does the body
        llvm::Function *containing_function = function;
        function = llvm::Function::Create(fn_type, llvm::Function::InternalLinkage,
                                          t.name, module.get());

        llvm::Value *task_ptr = builder->CreatePointerCast(function, fn_type->getPointerTo());

        function->addParamAttr(closure_arg_idx, Attribute::NoAlias);

        set_function_attributes_for_target(function, target);

        // Make the initial basic block and jump the builder into the new function
        IRBuilderBase::InsertPoint call_site = builder->saveIP();
        BasicBlock *block = BasicBlock::Create(*context, "entry", function);
        builder->SetInsertPoint(block);

        // Save the destructor block
        BasicBlock *parent_destructor_block = destructor_block;
        destructor_block = nullptr;

        // Make a new scope to use
        Scope<Value *> saved_symbol_table;
        symbol_table.swap(saved_symbol_table);

        // Get the function arguments

        // The user context is the first argument of the function; it's
        // important that we override the name to be "__user_context",
        // since the LLVM function has a random auto-generated name for
        // this argument.
        llvm::Function::arg_iterator iter = function->arg_begin();
        sym_push("__user_context", iterator_to_pointer(iter));

        if (use_do_par_for) {
            // Next is the loop variable.
            ++iter;
            sym_push(t.loop_var, iterator_to_pointer(iter));
        } else if (!t.loop_var.empty()) {
            // We peeled off a loop. Wrap a new loop around the body
            // that just does the slice given by the arguments.
            string loop_min_name = unique_name('t');
            string loop_extent_name = unique_name('t');
            t.body = For::make(t.loop_var,
                               Variable::make(Int(32), loop_min_name),
                               Variable::make(Int(32), loop_extent_name),
                               ForType::Serial,
                               DeviceAPI::None,
                               t.body);
            ++iter;
            sym_push(loop_min_name, iterator_to_pointer(iter));
            ++iter;
            sym_push(loop_extent_name, iterator_to_pointer(iter));
        } else {
            // This task is not any kind of loop, so skip these args.
            ++iter;
            ++iter;
        }

        // The closure pointer is either the last (for halide_do_par_for) or
        // second to last argument (for halide_do_parallel_tasks).
        ++iter;
        iter->setName("closure");
        Value *closure_handle = builder->CreatePointerCast(iterator_to_pointer(iter),
                                                           closure_t->getPointerTo());

        // Load everything from the closure into the new scope
        unpack_closure(closure, symbol_table, closure_t, closure_handle, builder);

        if (!use_do_par_for) {
            // For halide_do_parallel_tasks the threading runtime task parent
            // is the last argument.
            ++iter;
            iter->setName("task_parent");
            sym_push("__task_parent", iterator_to_pointer(iter));
        }

        // Generate the new function body
        codegen(t.body);

        // Return success
        return_with_error_code(ConstantInt::get(i32_t, 0));

        // Move the builder back to the main function.
        builder->restoreIP(call_site);

        // Now restore the scope
        symbol_table.swap(saved_symbol_table);
        function = containing_function;

        // Restore the destructor block
        destructor_block = parent_destructor_block;

        Value *min = codegen(t.min);
        Value *extent = codegen(t.extent);
        Value *serial = codegen(cast(UInt(8), t.serial));

        if (use_do_par_for) {
            llvm::Function *do_par_for = module->getFunction("halide_do_par_for");
            internal_assert(do_par_for) << "Could not find halide_do_par_for in initial module\n";
            do_par_for->addParamAttr(4, Attribute::NoAlias);
            Value *args[] = {get_user_context(), task_ptr, min, extent, closure_ptr};
            debug(4) << "Creating call to do_par_for\n";
            result = builder->CreateCall(do_par_for, args);
        } else {
            // Populate the task struct
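            // The GEP indices below assume the field order of the runtime's
            // halide_parallel_task_t: {fn, closure, name, semaphores,
            // num_semaphores, min, extent, min_threads, serial}. If that
            // struct layout changes, these stores must change with it.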
            Value *slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 0);
            builder->CreateStore(task_ptr, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 1);
            builder->CreateStore(closure_ptr, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 2);
            builder->CreateStore(create_string_constant(t.name), slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 3);
            builder->CreateStore(semaphores, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 4);
            builder->CreateStore(num_semaphores, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 5);
            builder->CreateStore(min, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 6);
            builder->CreateStore(extent, slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 7);
            builder->CreateStore(ConstantInt::get(i32_t, min_threads.result), slot_ptr);
            slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 8);
            builder->CreateStore(serial, slot_ptr);
        }
    }

    if (!result) {
        llvm::Function *do_parallel_tasks = module->getFunction("halide_do_parallel_tasks");
        internal_assert(do_parallel_tasks) << "Could not find halide_do_parallel_tasks in initial module\n";
        do_parallel_tasks->addParamAttr(2, Attribute::NoAlias);
        Value *task_parent = sym_get("__task_parent", false);
        if (!task_parent) {
            task_parent = ConstantPointerNull::get(i8_t->getPointerTo());  // void*
        }
        Value *args[] = {get_user_context(),
                         ConstantInt::get(i32_t, num_tasks),
                         task_stack_ptr,
                         task_parent};
        result = builder->CreateCall(do_parallel_tasks, args);
    }

    // Check for success
    Value *did_succeed = builder->CreateICmpEQ(result, ConstantInt::get(i32_t, 0));
    create_assertion(did_succeed, Expr(), result);
}

namespace {

string task_debug_name(const std::pair<string, int> &prefix) {
    if (prefix.second <= 1) {
        return prefix.first;
    } else {
        return prefix.first + "_" + std::to_string(prefix.second - 1);
    }
}

void add_fork(std::pair<string, int> &prefix) {
    if (prefix.second == 0) {
        prefix.first += ".fork";
    }
    prefix.second++;
}

void add_suffix(std::pair<string, int> &prefix, const string &suffix) {
    if (prefix.second > 1) {
        prefix.first += "_" + std::to_string(prefix.second - 1);
        prefix.second = 0;
    }
    prefix.first += suffix;
}

}  // namespace
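
// For example, tasks forked inside a function "f" get debug-name prefixes like
// "f.fork", "f.fork_1", "f.fork_2", ... (the counter advances with each
// cascaded Fork); get_parallel_tasks then appends a suffix such as
// ".par_for.<loop name>" or ".<semaphore name>" to form each task's name.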

void CodeGen_LLVM::get_parallel_tasks(const Stmt &s, vector<ParallelTask> &result, std::pair<string, int> prefix) {
    const For *loop = s.as<For>();
    const Acquire *acquire = loop ? loop->body.as<Acquire>() : s.as<Acquire>();
    if (const Fork *f = s.as<Fork>()) {
        add_fork(prefix);
        get_parallel_tasks(f->first, result, prefix);
        get_parallel_tasks(f->rest, result, prefix);
    } else if (!loop && acquire) {
        const Variable *v = acquire->semaphore.as<Variable>();
        internal_assert(v);
        add_suffix(prefix, "." + v->name);
        ParallelTask t{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)};
        while (acquire) {
            t.semaphores.push_back({acquire->semaphore, acquire->count});
            t.body = acquire->body;
            acquire = t.body.as<Acquire>();
        }
        result.push_back(t);
    } else if (loop && loop->for_type == ForType::Parallel) {
        add_suffix(prefix, ".par_for." + loop->name);
        result.push_back(ParallelTask{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix)});
    } else if (loop &&
               loop->for_type == ForType::Serial &&
               acquire &&
               !expr_uses_var(acquire->count, loop->name)) {
        const Variable *v = acquire->semaphore.as<Variable>();
        internal_assert(v);
        add_suffix(prefix, ".for." + v->name);
        ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix)};
        while (acquire) {
            t.semaphores.push_back({acquire->semaphore, acquire->count});
            t.body = acquire->body;
            acquire = t.body.as<Acquire>();
        }
        result.push_back(t);
    } else {
        add_suffix(prefix, "." + std::to_string(result.size()));
        result.push_back(ParallelTask{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)});
    }
}

void CodeGen_LLVM::do_as_parallel_task(const Stmt &s) {
    vector<ParallelTask> tasks;
    get_parallel_tasks(s, tasks, {function->getName().str(), 0});
    do_parallel_tasks(tasks);
}

void CodeGen_LLVM::visit(const Acquire *op) {
    do_as_parallel_task(op);
}

void CodeGen_LLVM::visit(const Fork *op) {
    do_as_parallel_task(op);
}

void CodeGen_LLVM::visit(const Store *op) {
    Halide::Type value_type = op->value.type();
    Halide::Type storage_type = upgrade_type_for_storage(value_type);
    if (value_type != storage_type) {
        Expr v = reinterpret(storage_type, op->value);
        codegen(Store::make(op->name, v, op->index, op->param, op->predicate, op->alignment));
        return;
    }

    if (inside_atomic_mutex_node) {
        user_assert(value_type.is_scalar())
            << "The vectorized atomic operation for the store " << op->name
            << " is lowered into a mutex lock, which does not support vectorization.\n";
    }

    // Issue atomic store if we are inside an atomic node.
    if (emit_atomic_stores) {
        codegen_atomic_store(op);
        return;
    }

    // Predicated store.
    if (!is_one(op->predicate)) {
        codegen_predicated_vector_store(op);
        return;
    }

    Value *val = codegen(op->value);
    bool is_external = (external_buffer.find(op->name) != external_buffer.end());
    // Scalar
    if (value_type.is_scalar()) {
        Value *ptr = codegen_buffer_pointer(op->name, value_type, op->index);
        StoreInst *store = builder->CreateAlignedStore(val, ptr, make_alignment(value_type.bytes()));
        add_tbaa_metadata(store, op->name, op->index);
    } else if (const Let *let = op->index.as<Let>()) {
        Stmt s = Store::make(op->name, op->value, let->body, op->param, op->predicate, op->alignment);
        codegen(LetStmt::make(let->name, let->value, s));
    } else {
        int alignment = value_type.bytes();
        const Ramp *ramp = op->index.as<Ramp>();
        if (ramp && is_one(ramp->stride)) {

            int native_bits = native_vector_bits();
            int native_bytes = native_bits / 8;

            // Boost the alignment if possible, up to the native vector width.
            ModulusRemainder mod_rem = op->alignment;
            while ((mod_rem.remainder & 1) == 0 &&
                   (mod_rem.modulus & 1) == 0 &&
                   alignment < native_bytes) {
                mod_rem.modulus /= 2;
                mod_rem.remainder /= 2;
                alignment *= 2;
            }
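            // Worked example: a float32 store (alignment starts at 4) with
            // op->alignment = {modulus = 8, remainder = 0} and 128-bit native
            // vectors (native_bytes = 16) takes two trips through this loop,
            // ending with alignment = 16: the store is known 16-byte aligned.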

            // If it is an external buffer, then we cannot assume that the host pointer
            // is aligned to at least the native vector width. However, we may be able to do
            // better than just assuming that it is unaligned.
            if (is_external && op->param.defined()) {
                int host_alignment = op->param.host_alignment();
                alignment = gcd(alignment, host_alignment);
            }

            // For dense vector stores wider than the native vector
            // width, bust them up into native vectors.
            int store_lanes = value_type.lanes();
            int native_lanes = native_bits / value_type.bits();

            for (int i = 0; i < store_lanes; i += native_lanes) {
                int slice_lanes = std::min(native_lanes, store_lanes - i);
                Expr slice_base = simplify(ramp->base + i);
                Expr slice_stride = make_one(slice_base.type());
                Expr slice_index = slice_lanes == 1 ? slice_base : Ramp::make(slice_base, slice_stride, slice_lanes);
                Value *slice_val = slice_vector(val, i, slice_lanes);
                Value *elt_ptr = codegen_buffer_pointer(op->name, value_type.element_of(), slice_base);
                Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_val->getType()->getPointerTo());
                StoreInst *store = builder->CreateAlignedStore(slice_val, vec_ptr, make_alignment(alignment));
                add_tbaa_metadata(store, op->name, slice_index);
            }
        } else if (ramp) {
            Type ptr_type = value_type.element_of();
            Value *ptr = codegen_buffer_pointer(op->name, ptr_type, ramp->base);
            const IntImm *const_stride = ramp->stride.as<IntImm>();
            Value *stride = codegen(ramp->stride);
            // Scatter without generating the indices as a vector
            for (int i = 0; i < ramp->lanes; i++) {
                Constant *lane = ConstantInt::get(i32_t, i);
                Value *v = builder->CreateExtractElement(val, lane);
                if (const_stride) {
                    // Use a constant offset from the base pointer
                    Value *p =
                        builder->CreateConstInBoundsGEP1_32(
                            llvm_type_of(ptr_type),
                            ptr,
                            const_stride->value * i);
                    StoreInst *store = builder->CreateStore(v, p);
                    add_tbaa_metadata(store, op->name, op->index);
                } else {
                    // Increment the pointer by the stride for each element
                    StoreInst *store = builder->CreateStore(v, ptr);
                    add_tbaa_metadata(store, op->name, op->index);
                    ptr = builder->CreateInBoundsGEP(ptr, stride);
                }
            }
        } else {
            // Scatter
            Value *index = codegen(op->index);
            for (int i = 0; i < value_type.lanes(); i++) {
                Value *lane = ConstantInt::get(i32_t, i);
                Value *idx = builder->CreateExtractElement(index, lane);
                Value *v = builder->CreateExtractElement(val, lane);
                Value *ptr = codegen_buffer_pointer(op->name, value_type.element_of(), idx);
                StoreInst *store = builder->CreateStore(v, ptr);
                add_tbaa_metadata(store, op->name, op->index);
            }
        }
    }
}

void CodeGen_LLVM::codegen_asserts(const vector<const AssertStmt *> &asserts) {
    if (target.has_feature(Target::NoAsserts)) {
        return;
    }

    if (asserts.size() < 4) {
        for (const auto *a : asserts) {
            codegen(Stmt(a));
        }
        return;
    }

    internal_assert(asserts.size() <= 63);

    // Mix all the conditions together into a bitmask

    Expr bitmask = cast<uint64_t>(1) << 63;
    for (size_t i = 0; i < asserts.size(); i++) {
        bitmask = bitmask | (cast<uint64_t>(!asserts[i]->condition) << i);
    }
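    // Example: with four asserts where only asserts[2] fails, bitmask is
    // (1 << 63) | (1 << 2), count_trailing_zeros gives 2, and the switch
    // below jumps to the third failure block. If nothing fails, bitmask is
    // just (1 << 63), the case index is 63, and the switch falls through to
    // its default, no_errors_bb.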
4185
4186 Expr switch_case = count_trailing_zeros(bitmask);
4187
4188 BasicBlock *no_errors_bb = BasicBlock::Create(*context, "no_errors_bb", function);
4189
4190 // Now switch on the bitmask to the correct failure
4191 Expr case_idx = cast<int32_t>(count_trailing_zeros(bitmask));
4192 llvm::SmallVector<uint32_t, 64> weights;
4193 weights.push_back(1 << 30);
4194 for (int i = 0; i < (int)asserts.size(); i++) {
4195 weights.push_back(0);
4196 }
4197 llvm::MDBuilder md_builder(*context);
4198 llvm::MDNode *switch_very_likely_branch = md_builder.createBranchWeights(weights);
4199 auto *switch_inst = builder->CreateSwitch(codegen(case_idx), no_errors_bb, asserts.size(), switch_very_likely_branch);
4200 for (int i = 0; i < (int)asserts.size(); i++) {
4201 BasicBlock *fail_bb = BasicBlock::Create(*context, "assert_failed", function);
4202 switch_inst->addCase(ConstantInt::get(IntegerType::get(*context, 32), i), fail_bb);
4203 builder->SetInsertPoint(fail_bb);
4204 Value *v = codegen(asserts[i]->message);
4205 builder->CreateRet(v);
4206 }
4207 builder->SetInsertPoint(no_errors_bb);
4208 }
4209
visit(const Block * op)4210 void CodeGen_LLVM::visit(const Block *op) {
4211 // Peel blocks of assertions with pure conditions
    const AssertStmt *a = op->first.as<AssertStmt>();
    if (a && is_pure(a->condition)) {
        vector<const AssertStmt *> asserts;
        asserts.push_back(a);
        Stmt s = op->rest;
        while ((op = s.as<Block>()) && (a = op->first.as<AssertStmt>()) && is_pure(a->condition) && asserts.size() < 63) {
            asserts.push_back(a);
            s = op->rest;
        }
        codegen_asserts(asserts);
        codegen(s);
    } else {
        codegen(op->first);
        codegen(op->rest);
    }
}

void CodeGen_LLVM::visit(const Realize *op) {
    internal_error << "Realize encountered during codegen\n";
}

void CodeGen_LLVM::visit(const Provide *op) {
    internal_error << "Provide encountered during codegen\n";
}

void CodeGen_LLVM::visit(const IfThenElse *op) {
    BasicBlock *true_bb = BasicBlock::Create(*context, "true_bb", function);
    BasicBlock *false_bb = BasicBlock::Create(*context, "false_bb", function);
    BasicBlock *after_bb = BasicBlock::Create(*context, "after_bb", function);
    builder->CreateCondBr(codegen(op->condition), true_bb, false_bb);

    builder->SetInsertPoint(true_bb);
    codegen(op->then_case);
    builder->CreateBr(after_bb);

    builder->SetInsertPoint(false_bb);
    if (op->else_case.defined()) {
        codegen(op->else_case);
    }
    builder->CreateBr(after_bb);

    builder->SetInsertPoint(after_bb);
}

void CodeGen_LLVM::visit(const Evaluate *op) {
    codegen(op->value);

    // Discard result
    value = nullptr;
}

void CodeGen_LLVM::visit(const Shuffle *op) {
    vector<Value *> vecs;
    for (const Expr &i : op->vectors) {
        vecs.push_back(codegen(i));
    }

    if (op->is_interleave()) {
        value = interleave_vectors(vecs);
    } else {
        value = concat_vectors(vecs);
        if (op->is_concat()) {
            // If this is just a concat, we're done.
        } else if (op->is_slice() && op->slice_stride() == 1) {
            value = slice_vector(value, op->indices[0], op->indices.size());
        } else {
            value = shuffle_vectors(value, op->indices);
        }
    }

    if (op->type.is_scalar() && value->getType()->isVectorTy()) {
        value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0));
    }
}

void CodeGen_LLVM::visit(const VectorReduce *op) {
    codegen_vector_reduce(op, Expr());
}

void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
    Expr val = op->value;
    const int output_lanes = op->type.lanes();
    const int native_lanes = native_vector_bits() / op->type.bits();
    const int factor = val.type().lanes() / output_lanes;
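    // 'factor' is the number of input lanes that fold into each output
    // lane; 'native_lanes' is how many elements of the result type fit in
    // one native vector register.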

    Expr (*binop)(Expr, Expr) = nullptr;
    switch (op->op) {
    case VectorReduce::Add:
        binop = Add::make;
        break;
    case VectorReduce::Mul:
        binop = Mul::make;
        break;
    case VectorReduce::Min:
        binop = Min::make;
        break;
    case VectorReduce::Max:
        binop = Max::make;
        break;
    case VectorReduce::And:
        binop = And::make;
        break;
    case VectorReduce::Or:
        binop = Or::make;
        break;
    }

    if (op->type.is_bool() && op->op == VectorReduce::Or) {
        // Cast to u8, use max, cast back to bool.
        Expr equiv = cast(op->value.type().with_bits(8), op->value);
        equiv = VectorReduce::make(VectorReduce::Max, equiv, op->type.lanes());
        if (init.defined()) {
            equiv = max(equiv, init);
        }
        equiv = cast(op->type, equiv);
        equiv.accept(this);
        return;
    }

    if (op->type.is_bool() && op->op == VectorReduce::And) {
        // Cast to u8, use min, cast back to bool.
        Expr equiv = cast(op->value.type().with_bits(8), op->value);
        equiv = VectorReduce::make(VectorReduce::Min, equiv, op->type.lanes());
        equiv = cast(op->type, equiv);
        if (init.defined()) {
            equiv = min(equiv, init);
        }
        equiv.accept(this);
        return;
    }

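    // float16 reductions are done by widening to 32-bit float, reducing,
    // then narrowing back, rather than relying on a native f16 horizontal
    // reduction.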
    if (op->type.element_of() == Float(16)) {
        Expr equiv = cast(op->value.type().with_bits(32), op->value);
        equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
        if (init.defined()) {
            equiv = binop(equiv, init);
        }
        equiv = cast(op->type, equiv);
        equiv.accept(this);
        return;
    }

#if LLVM_VERSION >= 90
    if (output_lanes == 1 &&
        (target.arch != Target::ARM || LLVM_VERSION >= 100)) {
        const int input_lanes = val.type().lanes();
        const int input_bytes = input_lanes * val.type().bytes();
        const bool llvm_has_intrinsic =
            // Must be one of these ops
            ((op->op == VectorReduce::Add ||
              op->op == VectorReduce::Mul ||
              op->op == VectorReduce::Min ||
              op->op == VectorReduce::Max) &&
             // Must be a power of two lanes
             (input_lanes >= 2) &&
             ((input_lanes & (input_lanes - 1)) == 0) &&
             // int versions exist up to 1024 bits
             ((!op->type.is_float() && input_bytes <= 1024) ||
              // float versions exist up to 16 lanes
              input_lanes <= 16) &&
             // As of the release of llvm 10, the 64-bit experimental total
             // reductions don't seem to be done yet on arm.
             (val.type().bits() != 64 ||
              target.arch != Target::ARM));

        if (llvm_has_intrinsic) {
            std::stringstream name;
            name << "llvm.experimental.vector.reduce.";
            const int bits = op->type.bits();
            bool takes_initial_value = false;
            Expr initial_value = init;
            if (op->type.is_float()) {
                switch (op->op) {
                case VectorReduce::Add:
                    name << "v2.fadd.f" << bits;
                    takes_initial_value = true;
                    if (!initial_value.defined()) {
                        initial_value = make_zero(op->type);
                    }
                    break;
                case VectorReduce::Mul:
                    name << "v2.fmul.f" << bits;
                    takes_initial_value = true;
                    if (!initial_value.defined()) {
                        initial_value = make_one(op->type);
                    }
                    break;
                case VectorReduce::Min:
                    name << "fmin";
                    break;
                case VectorReduce::Max:
                    name << "fmax";
                    break;
                default:
                    break;
                }
            } else if (op->type.is_int() || op->type.is_uint()) {
                switch (op->op) {
                case VectorReduce::Add:
                    name << "add";
                    break;
                case VectorReduce::Mul:
                    name << "mul";
                    break;
                case VectorReduce::Min:
                    name << (op->type.is_int() ? 's' : 'u') << "min";
                    break;
                case VectorReduce::Max:
                    name << (op->type.is_int() ? 's' : 'u') << "max";
                    break;
                default:
                    break;
                }
            }
            name << ".v" << val.type().lanes() << (op->type.is_float() ? 'f' : 'i') << bits;

            string intrin_name = name.str();

            vector<Expr> args;
            if (takes_initial_value) {
                args.push_back(initial_value);
                initial_value = Expr();
            }
            args.push_back(op->value);

            // Make sure the declaration exists, or the codegen for
            // call will assume that the args should scalarize.
            if (!module->getFunction(intrin_name)) {
                vector<llvm::Type *> arg_types;
                for (const Expr &e : args) {
                    arg_types.push_back(llvm_type_of(e.type()));
                }
                FunctionType *func_t = FunctionType::get(llvm_type_of(op->type), arg_types, false);
                llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get());
            }

            Expr equiv = Call::make(op->type, intrin_name, args, Call::PureExtern);
            if (initial_value.defined()) {
                equiv = binop(initial_value, equiv);
            }
            equiv.accept(this);
            return;
        }
    }
#endif

    if (output_lanes == 1 &&
        factor > native_lanes &&
        factor % native_lanes == 0) {
        // It's a total reduction of multiple native vectors. Start by
        // combining the vectors pairwise with the reduction op.
        Expr equiv;
        for (int i = 0; i < factor / native_lanes; i++) {
            Expr next = Shuffle::make_slice(val, i * native_lanes, 1, native_lanes);
            if (equiv.defined()) {
                equiv = binop(equiv, next);
            } else {
                equiv = next;
            }
        }
        equiv = VectorReduce::make(op->op, equiv, 1);
        if (init.defined()) {
            equiv = binop(equiv, init);
        }
        equiv = common_subexpression_elimination(equiv);
        equiv.accept(this);
        return;
    }

    if (factor > 2 && ((factor & 1) == 0)) {
        // Factor the reduce into multiple stages. If we're going to
        // be widening the type by 4x or more we should also factor the
        // widening into multiple stages.
        Type intermediate_type = op->value.type().with_lanes(op->value.type().lanes() / 2);
        Expr equiv = VectorReduce::make(op->op, op->value, intermediate_type.lanes());
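        // Special case: a wide integer add whose input is losslessly
        // narrowable (e.g. summing u8 data into a u32 result). Doing the
        // first halving stage at twice the narrow width (16 bits in that
        // example) keeps more lanes per native vector before widening the
        // rest of the way.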
        if (op->op == VectorReduce::Add &&
            (op->type.is_int() || op->type.is_uint()) &&
            op->type.bits() >= 32) {
            Type narrower_type = op->value.type().with_bits(op->type.bits() / 4);
            Expr narrower = lossless_cast(narrower_type, op->value);
            if (!narrower.defined() && narrower_type.is_int()) {
                // Maybe we can narrow to an unsigned int instead.
                narrower_type = narrower_type.with_code(Type::UInt);
                narrower = lossless_cast(narrower_type, op->value);
            }
            if (narrower.defined()) {
                // Widen it by 2x before the horizontal add
                narrower = cast(narrower.type().with_bits(narrower.type().bits() * 2), narrower);
                equiv = VectorReduce::make(op->op, narrower, intermediate_type.lanes());
                // Then widen it by 2x again afterwards
                equiv = cast(intermediate_type, equiv);
            }
        }
        equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
        if (init.defined()) {
            equiv = binop(equiv, init);
        }
        equiv = common_subexpression_elimination(equiv);
        codegen(equiv);
        return;
    }

    // Extract each slice and combine
    Expr equiv = init;
    for (int i = 0; i < factor; i++) {
        Expr next = Shuffle::make_slice(val, i, factor, val.type().lanes() / factor);
        if (equiv.defined()) {
            equiv = binop(equiv, next);
        } else {
            equiv = next;
        }
    }
    equiv = common_subexpression_elimination(equiv);
    codegen(equiv);
}

void CodeGen_LLVM::visit(const Atomic *op) {
    if (!op->mutex_name.empty()) {
        internal_assert(!inside_atomic_mutex_node)
            << "Nested atomic mutex locks detected. This might cause a deadlock.\n";
        ScopedValue<bool> old_inside_atomic_mutex_node(inside_atomic_mutex_node, true);
        // Mutex locking & unlocking are handled by function calls generated by previous lowering passes.
        codegen(op->body);
    } else {
        // Issue atomic stores.
        ScopedValue<bool> old_emit_atomic_stores(emit_atomic_stores, true);
        codegen(op->body);
    }
}

Value *CodeGen_LLVM::create_alloca_at_entry(llvm::Type *t, int n, bool zero_initialize, const string &name) {
    IRBuilderBase::InsertPoint here = builder->saveIP();
    BasicBlock *entry = &builder->GetInsertBlock()->getParent()->getEntryBlock();
    if (entry->empty()) {
        builder->SetInsertPoint(entry);
    } else {
        builder->SetInsertPoint(entry, entry->getFirstInsertionPt());
    }
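    // LLVM treats allocas in the entry block as static stack slots;
    // allocas anywhere else would grow the stack every time they execute.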
    Value *size = ConstantInt::get(i32_t, n);
    AllocaInst *ptr = builder->CreateAlloca(t, size, name);
    int align = native_vector_bits() / 8;
    llvm::DataLayout d(module.get());
    int allocated_size = n * (int)d.getTypeAllocSize(t);
    if (t->isVectorTy() || n > 1) {
        ptr->setAlignment(make_alignment(align));
    }
    requested_alloca_total += allocated_size;

    if (zero_initialize) {
        if (n == 1) {
            builder->CreateStore(Constant::getNullValue(t), ptr);
        } else {
            // Memset wants an i8 fill value and a size in bytes.
            builder->CreateMemSet(ptr, Constant::getNullValue(i8_t), allocated_size, make_alignment(align));
        }
    }
    builder->restoreIP(here);
    return ptr;
}

Value *CodeGen_LLVM::get_user_context() const {
    Value *ctx = sym_get("__user_context", false);
    if (!ctx) {
        ctx = ConstantPointerNull::get(i8_t->getPointerTo());  // void*
    }
    return ctx;
}

Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes,
                                 const string &name, vector<Expr> args) {
    vector<Value *> arg_values(args.size());
    for (size_t i = 0; i < args.size(); i++) {
        arg_values[i] = codegen(args[i]);
    }

    llvm::Type *t = llvm_type_of(result_type);

    return call_intrin(t,
                       intrin_lanes,
                       name, arg_values);
}

Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes,
                                 const string &name, vector<Value *> arg_values) {
    int arg_lanes = 1;
    if (result_type->isVectorTy()) {
        arg_lanes = get_vector_num_elements(result_type);
    }

    if (intrin_lanes != arg_lanes) {
        // Cut up each arg into appropriately-sized pieces, call the
        // intrinsic on each, then splice together the results.
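        // e.g. calling a 4-lane intrinsic with 8-lane arguments issues two
        // calls, one on lanes [0, 4) and one on lanes [4, 8), and then
        // concatenates the two results.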
        vector<Value *> results;
        for (int start = 0; start < arg_lanes; start += intrin_lanes) {
            vector<Value *> args;
            for (size_t i = 0; i < arg_values.size(); i++) {
                int arg_i_lanes = 1;
                if (arg_values[i]->getType()->isVectorTy()) {
                    arg_i_lanes = get_vector_num_elements(arg_values[i]->getType());
                }
                if (arg_i_lanes >= arg_lanes) {
                    // Horizontally reducing intrinsics may have
                    // arguments that have more lanes than the
                    // result. Assume that they horizontally reduce
                    // neighboring elements.
                    int reduce = arg_i_lanes / arg_lanes;
                    args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce));
                } else if (arg_i_lanes == 1) {
                    // It's a scalar arg to an intrinsic that returns
                    // a vector. Replicate it over the slices.
                    args.push_back(arg_values[i]);
                } else {
                    internal_error << "Argument in call_intrin has " << arg_i_lanes
                                   << " lanes, but the result type has " << arg_lanes << " lanes\n";
                }
            }

            llvm::Type *result_slice_type =
                get_vector_type(result_type->getScalarType(), intrin_lanes);

            results.push_back(call_intrin(result_slice_type, intrin_lanes, name, args));
        }
        Value *result = concat_vectors(results);
        return slice_vector(result, 0, arg_lanes);
    }

    vector<llvm::Type *> arg_types(arg_values.size());
    for (size_t i = 0; i < arg_values.size(); i++) {
        arg_types[i] = arg_values[i]->getType();
    }

    llvm::Function *fn = module->getFunction(name);

    if (!fn) {
        llvm::Type *intrinsic_result_type = result_type->getScalarType();
        if (intrin_lanes > 1) {
            intrinsic_result_type = get_vector_type(result_type->getScalarType(), intrin_lanes);
        }
        FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false);
        fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get());
        fn->setCallingConv(CallingConv::C);
    }

    CallInst *call = builder->CreateCall(fn, arg_values);

    call->setDoesNotAccessMemory();
    call->setDoesNotThrow();

    return call;
}

Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) {
    // Force the arg to be an actual vector
    if (!vec->getType()->isVectorTy()) {
        vec = create_broadcast(vec, 1);
    }

    int vec_lanes = get_vector_num_elements(vec->getType());

    if (start == 0 && size == vec_lanes) {
        return vec;
    }

    if (size == 1) {
        return builder->CreateExtractElement(vec, (uint64_t)start);
    }

    vector<int> indices(size);
    for (int i = 0; i < size; i++) {
        int idx = start + i;
        if (idx >= 0 && idx < vec_lanes) {
            indices[i] = idx;
        } else {
            indices[i] = -1;
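            // Out-of-range lanes become undef in shuffle_vectors below.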
        }
    }
    return shuffle_vectors(vec, indices);
}

Value *CodeGen_LLVM::concat_vectors(const vector<Value *> &v) {
    if (v.size() == 1) {
        return v[0];
    }

    internal_assert(!v.empty());

    vector<Value *> vecs = v;

    // Force them all to be actual vectors
    for (Value *&val : vecs) {
        if (!val->getType()->isVectorTy()) {
            val = create_broadcast(val, 1);
        }
    }

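    // Merge adjacent pairs until a single vector remains; e.g. three
    // 4-lane vectors {a, b, c} become {ab, c} after one round, then a
    // single 12-lane abc after the second.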
    while (vecs.size() > 1) {
        vector<Value *> new_vecs;

        for (size_t i = 0; i < vecs.size() - 1; i += 2) {
            Value *v1 = vecs[i];
            Value *v2 = vecs[i + 1];

            int w1 = get_vector_num_elements(v1->getType());
            int w2 = get_vector_num_elements(v2->getType());

            // Possibly pad one of the vectors to match widths.
            if (w1 < w2) {
                v1 = slice_vector(v1, 0, w2);
            } else if (w2 < w1) {
                v2 = slice_vector(v2, 0, w1);
            }
            int w_matched = std::max(w1, w2);

            internal_assert(v1->getType() == v2->getType());

            vector<int> indices(w1 + w2);
            for (int i = 0; i < w1; i++) {
                indices[i] = i;
            }
            for (int i = 0; i < w2; i++) {
                indices[w1 + i] = w_matched + i;
            }

            Value *merged = shuffle_vectors(v1, v2, indices);

            new_vecs.push_back(merged);
        }

        // If there were an odd number of them, we need to also push
        // the one that didn't get merged.
        if (vecs.size() & 1) {
            new_vecs.push_back(vecs.back());
        }

        vecs.swap(new_vecs);
    }

    return vecs[0];
}

Value *CodeGen_LLVM::shuffle_vectors(Value *a, Value *b,
                                     const std::vector<int> &indices) {
    internal_assert(a->getType() == b->getType());
    vector<Constant *> llvm_indices(indices.size());
    for (size_t i = 0; i < llvm_indices.size(); i++) {
        if (indices[i] >= 0) {
            internal_assert(indices[i] < get_vector_num_elements(a->getType()) * 2);
            llvm_indices[i] = ConstantInt::get(i32_t, indices[i]);
        } else {
            // Only let -1 be undef.
            internal_assert(indices[i] == -1);
            llvm_indices[i] = UndefValue::get(i32_t);
        }
    }

    return builder->CreateShuffleVector(a, b, ConstantVector::get(llvm_indices));
}

Value *CodeGen_LLVM::shuffle_vectors(Value *a, const std::vector<int> &indices) {
    Value *b = UndefValue::get(a->getType());
    return shuffle_vectors(a, b, indices);
}

std::pair<llvm::Function *, int> CodeGen_LLVM::find_vector_runtime_function(const std::string &name, int lanes) {
    // Check if a vector version of the function already
    // exists at some useful width. We use the naming
    // convention that an N-wide version of a function foo is
    // called fooxN. All of our intrinsics are power-of-two
    // sized, so starting at the first power of two >= the
    // vector width, we'll try all powers of two in decreasing
    // order.
    vector<int> sizes_to_try;
    int l = 1;
    while (l < lanes) {
        l *= 2;
    }
    for (int i = l; i > 1; i /= 2) {
        sizes_to_try.push_back(i);
    }

    // If none of those match, we'll also try doubling
    // the lanes up to the next power of two (this is to catch
    // cases where we're a 64-bit vector and have a 128-bit
    // vector implementation).
    sizes_to_try.push_back(l * 2);
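    // e.g. for name == "pow_f32" and lanes == 6 this tries pow_f32x8,
    // pow_f32x4, pow_f32x2, and finally pow_f32x16.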

    for (size_t i = 0; i < sizes_to_try.size(); i++) {
        int l = sizes_to_try[i];
        llvm::Function *vec_fn = module->getFunction(name + "x" + std::to_string(l));
        if (vec_fn) {
            return {vec_fn, l};
        }
    }

    return {nullptr, 0};
}

bool CodeGen_LLVM::supports_atomic_add(const Type &t) const {
    return t.is_int_or_uint();
}

bool CodeGen_LLVM::use_pic() const {
    return true;
}

} // namespace Internal
} // namespace Halide