1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file codegen_llvm.cc
22 */
23 #ifdef TVM_LLVM_VERSION
24 // Part of the code are adapted from Halide's CodeGen_LLVM
25 #include <tvm/runtime/device_api.h>
26 #include <tvm/runtime/c_runtime_api.h>
27
28 #include <algorithm>
29
30 #include "codegen_llvm.h"
31 #include "codegen_cpu.h"
32 #include "../build_common.h"
33 #include "../../pass/ir_util.h"
34 #include "../../arithmetic/compute_expr.h"
35
36 namespace tvm {
37 namespace codegen {
38
Create(llvm::TargetMachine * tm)39 std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
40 std::string target = tm->getTarget().getName();
41 std::string factory_name = "tvm.codegen.llvm.target_" + target;
42 const PackedFunc* f = runtime::Registry::Get(factory_name);
43 if (f != nullptr) {
44 void* handle = (*f)();
45 return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
46 } else {
47 return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
48 }
49 }
50
Init(const std::string & module_name,llvm::TargetMachine * tm,llvm::LLVMContext * ctx,bool system_lib,bool dynamic_lookup)51 void CodeGenLLVM::Init(const std::string& module_name,
52 llvm::TargetMachine* tm,
53 llvm::LLVMContext* ctx,
54 bool system_lib,
55 bool dynamic_lookup) {
56 InitializeLLVM();
57 ctx_ = ctx;
58 builder_.reset(new IRBuilder(*ctx_));
59 module_.reset(new llvm::Module(module_name, *ctx_));
60 md_builder_.reset(new llvm::MDBuilder(*ctx_));
61 // types
62 t_void_ = llvm::Type::getVoidTy(*ctx_);
63 t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
64 t_int_ = llvm::Type::getInt32Ty(*ctx_);
65 t_char_ = llvm::Type::getInt8Ty(*ctx_);
66 t_int8_ = llvm::Type::getInt8Ty(*ctx_);
67 t_int16_ = llvm::Type::getInt16Ty(*ctx_);
68 t_int32_ = llvm::Type::getInt32Ty(*ctx_);
69 t_int64_ = llvm::Type::getInt64Ty(*ctx_);
70 t_float64_ = llvm::Type::getDoubleTy(*ctx_);
71 // meta data
72 md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1);
73 md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
74 md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
75 this->InitTarget(tm);
76 }
77
InitTarget(llvm::TargetMachine * tm)78 void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
79 module_->setTargetTriple(tm->getTargetTriple().str());
80 module_->setDataLayout(tm->createDataLayout());
81 data_layout_.reset(new llvm::DataLayout(module_.get()));
82 target_machine_ = tm;
83 if (native_vector_bits_ == 0) {
84 const auto& arch = tm->getTargetTriple().getArch();
85 if (arch == llvm::Triple::x86_64) {
86 // for avx512
87 native_vector_bits_ = 512;
88 } else if (arch == llvm::Triple::x86) {
89 native_vector_bits_ = 256;
90 } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
91 native_vector_bits_ = 128;
92 } else {
93 native_vector_bits_ = 128;
94 std::string arch_name = tm->getTargetTriple().getArchName();
95 LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
96 }
97 }
98 }
99
AddFunction(const LoweredFunc & f)100 void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
101 this->AddFunctionInternal(f, false);
102 }
103
InitFuncState()104 void CodeGenLLVM::InitFuncState() {
105 var_map_.clear();
106 alias_var_set_.clear();
107 alloc_storage_info_.clear();
108 volatile_buf_.clear();
109 analyzer_.reset(new arith::Analyzer());
110 }
111
112
// Translate one LoweredFunc into an LLVM function definition.
// ret_void selects between a void return and the int status-code ABI.
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
  this->InitFuncState();
  std::vector<llvm::Type*> arg_types;
  is_restricted_ = f->is_restricted;
  for (Var arg : f->args) {
    Type t = arg.type();
    if (t.is_handle()) {
      // Handles become typed pointers when the pointee type is known,
      // otherwise a generic i8* in the global address space.
      auto it = f->handle_data_type.find(arg);
      if (it != f->handle_data_type.end()) {
        arg_types.push_back(LLVMType((*it).second.type())
                            ->getPointerTo(GetGlobalAddressSpace()));
      } else {
        arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
      }
      // Without the restrict guarantee, every handle argument may alias.
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
    } else {
      arg_types.push_back(LLVMType(arg.type()));
    }
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      ret_void ? t_void_ : t_int_, arg_types, false);
  CHECK(module_->getFunction(f->name) == nullptr)
      << "Function " << f->name << " already exist in module";
  function_ = llvm::Function::Create(
      ftype, llvm::Function::ExternalLinkage,
      f->name, module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->args[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        // Legacy attribute API: index 0 is the return value, params are 1-based.
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);
  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    // Status-code ABI: returning zero signals success.
    builder_->CreateRet(ConstInt32(0));
  }
}
169
170
Finish()171 std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
172 this->AddStartupFunction();
173 for (size_t i = 0; i < link_modules_.size(); ++i) {
174 CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
175 << "Failed to link modules";
176 }
177 link_modules_.clear();
178 // optimize
179 this->Optimize();
180 return std::move(module_);
181 }
182
183
HandleImport(const std::string & code)184 void CodeGenLLVM::HandleImport(const std::string& code) {
185 std::unique_ptr<llvm::Module> mlib;
186 llvm::SMDiagnostic err;
187 if (code.length() >= 3 &&
188 (code.substr(code.length() - 3) == ".ll" ||
189 code.substr(code.length() - 3) == ".bc")) {
190 mlib = llvm::parseIRFile(code, err, *ctx_);
191 if (mlib.get() == nullptr) {
192 std::string msg = err.getMessage();
193 LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
194 << "line " << err.getLineNo() << ":" << msg;
195 }
196 } else {
197 std::unique_ptr<llvm::MemoryBuffer> buf =
198 llvm::MemoryBuffer::getMemBuffer(code);
199 mlib = llvm::parseIR(*buf, err, *ctx_);
200 if (mlib.get() == nullptr) {
201 std::string msg = err.getMessage();
202 LOG(FATAL) << "Fail to load llvm ir "
203 << "line " << err.getLineNo() << ":" << msg
204 << "\ncontent:\n" << code;
205 }
206 }
207 mlib->setTargetTriple(target_machine_->getTargetTriple().str());
208 mlib->setDataLayout(target_machine_->createDataLayout());
209 // mark all the functions as force inline
210 for (llvm::Function &f : mlib->functions()) {
211 f.removeFnAttr(llvm::Attribute::NoInline);
212 f.addFnAttr(llvm::Attribute::AlwaysInline);
213 f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
214 }
215 // add to linker libraries.
216 this->AddLinkModule(std::move(mlib));
217 }
218
AddLinkModule(std::unique_ptr<llvm::Module> && mod)219 void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
220 link_modules_.emplace_back(std::move(mod));
221 }
222
AddMainFunction(const std::string & entry_func_name)223 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
224 LOG(FATAL) << "not implemented";
225 }
226
GetThreadIndex(const IterVar & iv)227 llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
228 LOG(FATAL) << "not implemented";
229 return nullptr;
230 }
231
CreateStorageSync(const Call * op)232 llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
233 LOG(FATAL) << "not implemented";
234 return nullptr;
235 }
236
237 class FPassManager : public llvm::legacy::FunctionPassManager {
238 public:
FPassManager(llvm::Module * m)239 explicit FPassManager(llvm::Module* m)
240 : llvm::legacy::FunctionPassManager(m) {}
241 // override add to allow messaging
add(llvm::Pass * p)242 void add(llvm::Pass* p) final {
243 llvm::legacy::FunctionPassManager::add(p);
244 }
245 };
246
247 class MPassManager : public llvm::legacy::PassManager {
248 public:
249 // override add to allow messaging
add(llvm::Pass * p)250 void add(llvm::Pass* p) final {
251 llvm::legacy::PassManager::add(p);
252 }
253 };
254
InitPassManagerBuilder(llvm::PassManagerBuilder * builder)255 void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
256 }
257
Optimize()258 void CodeGenLLVM::Optimize() {
259 // pass manager
260 FPassManager fpass(module_.get());
261 MPassManager mpass;
262 mpass.add(llvm::createTargetTransformInfoWrapperPass(
263 target_machine_ ? target_machine_->getTargetIRAnalysis() :
264 llvm::TargetIRAnalysis()));
265 fpass.add(llvm::createTargetTransformInfoWrapperPass(
266 target_machine_ ? target_machine_->getTargetIRAnalysis() :
267 llvm::TargetIRAnalysis()));
268
269 // place optimization pass
270 llvm::PassManagerBuilder builder;
271 builder.OptLevel = 3;
272
273 #if TVM_LLVM_VERSION >= 50
274 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
275 #else
276 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
277 #endif
278 builder.LoopVectorize = true;
279 builder.SLPVectorize = true;
280 this->InitPassManagerBuilder(&builder);
281
282 #if TVM_LLVM_VERSION >= 50
283 target_machine_->adjustPassManager(builder);
284 #endif
285
286 builder.populateFunctionPassManager(fpass);
287 builder.populateModulePassManager(mpass);
288
289 fpass.doInitialization();
290 for (auto it = module_->begin(); it != module_->end(); ++it) {
291 fpass.run(*it);
292 }
293 fpass.doFinalization();
294 mpass.run(*module_);
295 }
296
NativeVectorBits(const runtime::StorageScope & storage_scope) const297 int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
298 return native_vector_bits_;
299 }
300
GetGlobalAddressSpace()301 unsigned CodeGenLLVM::GetGlobalAddressSpace() {
302 return 0;
303 }
304
LLVMType(const Type & t) const305 llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
306 if (t.is_handle()) {
307 CHECK_EQ(t.lanes(), 1);
308 return t_void_p_;
309 }
310 llvm::Type* etype = nullptr;
311 if (t.is_int() || t.is_uint()) {
312 etype = llvm::Type::getIntNTy(*ctx_, t.bits());
313 } else if (t.is_float()) {
314 switch (t.bits()) {
315 case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
316 case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
317 case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
318 default: LOG(FATAL) << "do not support " << t;
319 }
320 }
321 if (t.lanes() != 1) {
322 return llvm::VectorType::get(etype, t.lanes());
323 } else {
324 return etype;
325 }
326 }
327
328 // Add tbaa alias information for load
329 //
330 // use a binary tree typed system to declare information
331 // and allow alias to be distinguished across nodes.
332 //
333 // This trick comes from Halide's CodeGen_LLVM
334 //
AddAliasInfo(llvm::Instruction * inst,const Variable * buffer,Expr index,Type type)335 void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
336 const Variable* buffer,
337 Expr index,
338 Type type) {
339 if (alias_var_set_.count(buffer) != 0) {
340 // Mark all possibly aliased pointer as same type.
341 llvm::MDNode* meta = md_tbaa_alias_set_;
342 inst->setMetadata(
343 "tbaa",
344 md_builder_->createTBAAStructTagNode(meta, meta, 0));
345 return;
346 }
347 int base = 0, width = 0;
348 // create meta-data for alias analysis
349 // Use a group of binary tree ranges of memory banks.
350 if (index.defined()) {
351 const Ramp* ramp = index.as<Ramp>();
352 if (ramp) {
353 int base, stride;
354 if (arith::GetConstInt(ramp->base, &base) &&
355 arith::GetConstInt(ramp->stride, &stride)) {
356 int xwith = ramp->lanes * stride;
357 width = 1;
358 while (width < xwith) {
359 width *= 2;
360 }
361 while (base % width) {
362 base -= base % width;
363 width *= 2;
364 }
365 }
366 } else {
367 if (arith::GetConstInt(index, &base)) width = 1;
368 }
369 }
370 llvm::MDNode* meta = md_tbaa_root_;
371 std::ostringstream buffer_addr, buffer_type;
372 buffer_addr << buffer;
373 meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
374 buffer_type << type.element_of();
375 meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
376 // create a tree-shape access structure.
377 if (width != 0) {
378 for (int w = 1024; w >= width; w /= 2) {
379 int b = (base / w) * w;
380 std::stringstream os;
381 os << buffer << ".w" << w << ".b" << b;
382 meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
383 }
384 }
385 inst->setMetadata(
386 "tbaa",
387 md_builder_->createTBAAStructTagNode(meta, meta, 0));
388 }
389
GetAlignment(Type t,const Variable * buf_var,const Expr & index,int * p_alignment,int * p_native_bits)390 void CodeGenLLVM::GetAlignment(Type t,
391 const Variable* buf_var,
392 const Expr& index,
393 int* p_alignment,
394 int* p_native_bits) {
395 int max_align_bits = t.bits();
396 auto it = alloc_storage_info_.find(buf_var);
397 if (it != alloc_storage_info_.end()) {
398 const StorageInfo& info = it->second;
399 *p_native_bits = NativeVectorBits(info.scope);
400 max_align_bits = info.alignment * 8;
401 } else {
402 *p_native_bits = native_vector_bits_;
403 }
404
405 arith::ModularSet me = analyzer_->modular_set(index);
406 int64_t base = me->base;
407 int64_t coeff = me->coeff;
408
409 int align_bits = t.bits();
410 while (align_bits < max_align_bits &&
411 base % 2 == 0 &&
412 coeff % 2 == 0) {
413 base = base / 2;
414 coeff = coeff / 2;
415 align_bits *= 2;
416 }
417 if (align_bits < 8) {
418 align_bits = 8;
419 }
420 *p_alignment = align_bits / 8;
421 }
422
// Create a DebugInfo container (DIBuilder + placeholder file/compile unit)
// for the given module. Real source locations are not tracked yet.
std::unique_ptr<CodeGenLLVM::DebugInfo>
CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
// llvm::make_unique was removed in LLVM 10 in favor of std::make_unique.
#if TVM_LLVM_VERSION >= 100
  auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
#else
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
  // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
  debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
      llvm::DICompileUnit::DebugEmissionKind::FullDebug,
      /* SplitDebugInlining */ true,
      /* DebugInfoForProfiling */ true);
  return debug_info;
}
441
CreateBroadcast(llvm::Value * value,int lanes)442 llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
443 llvm::Constant* undef = llvm::UndefValue::get(
444 llvm::VectorType::get(value->getType(), lanes));
445 llvm::Constant* zero = ConstInt32(0);
446 value = builder_->CreateInsertElement(undef, value, zero);
447 llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
448 return builder_->CreateShuffleVector(value, undef, mask);
449 }
450
CreateVecSlice(llvm::Value * vec,int begin,int extent)451 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
452 int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
453 if (extent == num_elems && begin == 0) return vec;
454 CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
455 std::vector<llvm::Constant*> indices;
456 indices.reserve(extent);
457 for (int i = 0; i < extent; ++i) {
458 if (begin + i >= 0 && begin + i < num_elems) {
459 indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
460 } else {
461 indices.push_back(llvm::UndefValue::get(t_int32_));
462 }
463 }
464 return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
465 }
466
CreateVecFlip(llvm::Value * vec)467 llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
468 int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
469 std::vector<unsigned> indices;
470 for (int i = 0; i < num_elems; ++i) {
471 indices.push_back(num_elems - i - 1);
472 }
473 return builder_->CreateShuffleVector(vec, vec, indices);
474 }
475
CreateVecPad(llvm::Value * vec,int target_lanes)476 llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
477 llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
478 int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
479 if (num_elems == target_lanes) return vec;
480 CHECK_LT(num_elems, target_lanes);
481 for (int i = 0; i < num_elems; ++i) {
482 mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
483 }
484 return builder_->CreateShuffleVector(vec, vec, mask);
485 }
486
CreateVecConcat(std::vector<llvm::Value * > vecs)487 llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
488 // concat vector, tree shape reduction
489 int total_lanes = 0;
490
491 for (llvm::Value* v : vecs) {
492 total_lanes += static_cast<int>(
493 v->getType()->getVectorNumElements());
494 }
495 while (vecs.size() > 1) {
496 std::vector<llvm::Value*> new_vecs;
497 for (size_t i = 0; i < vecs.size() - 1; i += 2) {
498 llvm::Value* lhs = vecs[i];
499 llvm::Value* rhs = vecs[i + 1];
500 const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
501 const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
502 if (lhs_lanes < rhs_lanes) {
503 lhs = CreateVecPad(lhs, rhs_lanes);
504 } else if (rhs_lanes < lhs_lanes) {
505 rhs = CreateVecPad(rhs, lhs_lanes);
506 }
507 const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
508 std::vector<unsigned> mask;
509 for (size_t i = 0; i < lhs_lanes; ++i) {
510 mask.push_back(i);
511 }
512 for (size_t i = 0; i < rhs_lanes; ++i) {
513 mask.push_back(shared_lanes + i);
514 }
515 new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
516 }
517 if (vecs.size() % 2 != 0) {
518 new_vecs.push_back(vecs.back());
519 }
520 vecs.swap(new_vecs);
521 }
522 return CreateVecSlice(vecs[0], 0, total_lanes);
523 }
524
525
// Emit a classic three-block serial loop:
//   pre -> for_begin (phi + cond) -> for_body -> for_begin ... -> for_end
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                  llvm::Value* end,
                                  llvm::Value* stride,
                                  const VarExpr& loop_var,
                                  const Stmt& body) {
  using llvm::BasicBlock;
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* for_begin = BasicBlock::Create(
      *ctx_, "for_begin", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  // Loop variable is a phi of the initial value and the incremented value.
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->addIncoming(begin, pre_block);
  CHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  // Mark the continue branch as very likely: loops usually iterate.
  builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
                         for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
  // The body may have moved the insertion block; the current block is the
  // real back-edge predecessor for the phi.
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}
555
// cast operator
CreateCast(Type from,Type to,llvm::Value * value)557 llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
558 llvm::Type * target = LLVMType(to);
559 if (value->getType() == target) return value;
560 if (to.is_handle()) {
561 return builder_->CreateBitCast(value, target);
562 } else if (to.is_uint() && to.bits() == 1) {
563 if (from.is_float()) {
564 llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
565 return builder_->CreateFCmpONE(value, zero);
566 } else {
567 llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
568 return builder_->CreateICmpNE(value, zero);
569 }
570 } else if (!from.is_float() && !to.is_float()) {
571 return builder_->CreateIntCast(value, target, from.is_int());
572 } else if (from.is_float() && to.is_int()) {
573 return builder_->CreateFPToSI(value, target);
574 } else if (from.is_float() && to.is_uint()) {
575 if (to.bits() < 8) {
576 value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
577 return builder_->CreateIntCast(value, target, false);
578 } else {
579 return builder_->CreateFPToUI(value, target);
580 }
581 } else if (from.is_int() && to.is_float()) {
582 return builder_->CreateSIToFP(value, target);
583 } else if (from.is_uint() && to.is_float()) {
584 return builder_->CreateUIToFP(value, target);
585 } else {
586 CHECK(from.is_float() && to.is_float());
587 return builder_->CreateFPCast(value, target);
588 }
589 }
590
GetConstString(const std::string & str)591 llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
592 auto it = str_map_.find(str);
593 if (it != str_map_.end()) return it->second;
594 llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
595 llvm::GlobalVariable *global = new llvm::GlobalVariable(
596 *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
597 #if TVM_LLVM_VERSION >= 100
598 global->setAlignment(llvm::Align(1));
599 #else
600 global->setAlignment(1);
601 #endif
602 global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
603 llvm::Constant* zero = ConstInt32(0);
604 llvm::Constant* indices[] = {zero, zero};
605 llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
606 type, global, indices);
607 str_map_[str] = ptr;
608 return ptr;
609 }
610
CreateBufferPtr(Type t,llvm::Value * buffer,llvm::Value * index)611 llvm::Value* CodeGenLLVM::CreateBufferPtr(
612 Type t, llvm::Value* buffer, llvm::Value* index) {
613 CHECK_EQ(t.lanes(), 1);
614 llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
615 CHECK(btype != nullptr);
616 llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
617 if (btype != ptype) {
618 buffer = builder_->CreatePointerCast(buffer, ptype);
619 }
620
621 return builder_->CreateInBoundsGEP(buffer, index);
622 }
623
CreateBufferVecPtr(Type t,llvm::Value * buffer,llvm::Value * index)624 llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
625 Type t, llvm::Value* buffer, llvm::Value* index) {
626 CHECK_GT(t.lanes(), 1);
627 llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
628 CHECK(btype != nullptr);
629 llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
630 if (btype != ptype) {
631 buffer = builder_->CreatePointerCast(buffer, ptype);
632 }
633 return builder_->CreateInBoundsGEP(buffer, index);
634 }
635
GetVarValue(const Variable * v) const636 llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
637 auto it = var_map_.find(v);
638 CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
639 return it->second;
640 }
641
CreateCallExtern(const Call * op)642 llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
643 std::vector<llvm::Value*> arg_value;
644 std::vector<llvm::Type*> arg_type;
645 for (size_t i = 0; i < op->args.size(); ++i) {
646 arg_value.push_back(MakeValue(op->args[i]));
647 arg_type.push_back(arg_value.back()->getType());
648 }
649 llvm::FunctionType* ftype = llvm::FunctionType::get(
650 LLVMType(op->type), arg_type, false);
651 llvm::Function* f = module_->getFunction(op->name);
652 if (f == nullptr) {
653 f = llvm::Function::Create(
654 ftype, llvm::Function::ExternalLinkage,
655 op->name, module_.get());
656 }
657 llvm::CallInst* call = builder_->CreateCall(f, arg_value);
658 return call;
659 }
660
// Lower a TVM intrinsic Call to LLVM IR. Dispatches on the intrinsic name;
// unknown intrinsics are a fatal error.
llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic("llvm_intrin")) {
    // args[0] = LLVM intrinsic id, args[1] = number of signature arguments,
    // args[2:] = the actual call arguments.
    CHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
        op->args[0].as<UIntImm>()->value);
    const uint64_t *num_signature = as_const_uint(op->args[1]);
    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
                         << "but " << op->args[1] << " got!\n";
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> sig_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      // Only the first num_signature argument types participate in
      // overloaded-intrinsic name mangling.
      if (i - 2 < *num_signature) {
        sig_type.push_back(arg_value.back()->getType());
      }
    }
    // If the return type differs from the first signature type, prepend it
    // so getDeclaration resolves the correct overload.
    llvm::Type *return_type = LLVMType(op->type);
    if (sig_type.size() > 0 && return_type != sig_type[0]) {
      sig_type.insert(sig_type.begin(), return_type);
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(
        module_.get(), id, sig_type);
    return builder_->CreateCall(f, arg_value);
  } else if (op->is_intrinsic(Call::bitwise_and)) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_or)) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_not)) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(Call::bitwise_xor)) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_left)) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_right)) {
    // Arithmetic shift for signed operands, logical shift for unsigned.
    if (op->args[0].type().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    return CreateStorageSync(op);
  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    // Take the address of a Load's element (or vector of elements) and
    // return it as char* in the buffer's address space.
    const Load *l = op->args[0].as<Load>();
    CHECK(op->args.size() == 1 && l);
    const Ramp *r = l->index.as<Ramp>();
    llvm::Value* ptr;
    unsigned addrspace;
    if (!r) {
      ptr = CreateBufferPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(l->index));
      addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    } else {
      // Ramp index: address the whole vector at base / lanes.
      Expr index = r->base / make_const(Int(32), r->lanes);
      ptr = CreateBufferVecPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(index));
      addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    }
    return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
    // reinterpret(0) is the canonical null handle.
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    CHECK_EQ(op->args[0].type().lanes(), 1)
        << "if_then_else can only take scalar condition";
    // Lower to a diamond CFG with a phi merging the two branch values.
    using llvm::BasicBlock;
    BasicBlock* then_block = BasicBlock::Create(
        *ctx_, "if_then", function_);
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    BasicBlock* end_block = BasicBlock::Create(
        *ctx_, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    // Branch values may move the insert block; capture the actual
    // predecessor blocks for the phi.
    BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->is_intrinsic(Call::reinterpret)) {
    llvm::Type * target = LLVMType(op->type);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->is_intrinsic(Call::isnan)) {
    // TODO(hgt312): set fast math flag
    // A value is NaN iff it compares unordered with itself.
    llvm::Value* a = MakeValue(op->args[0]);
    return builder_->CreateFCmpUNO(a, a);
  } else if (op->is_intrinsic("vectorlow")) {
    // First half of the vector's lanes.
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, 0, l/2);
  } else if (op->is_intrinsic("vectorhigh")) {
    // Second half of the vector's lanes.
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, l/2, l/2);
  } else if (op->is_intrinsic("vectorcombine")) {
    // Concatenate two equal-width vectors with an identity shuffle mask.
    llvm::Value *v0 = MakeValue(op->args[0]);
    llvm::Value *v1 = MakeValue(op->args[1]);
    int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
    std::vector<unsigned> indices;
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->name;
    return nullptr;
  }
}
778
Scalarize(const Expr & e,std::function<void (int i,llvm::Value * v)> f)779 void CodeGenLLVM::Scalarize(const Expr& e,
780 std::function<void(int i, llvm::Value* v)> f) {
781 if (const Ramp* ramp = e.as<Ramp>()) {
782 for (int i = 0; i < ramp->type.lanes(); ++i) {
783 Expr offset = ramp->base + (ramp->stride * i);
784 f(i, MakeValue(offset));
785 }
786 } else {
787 llvm::Value* value = MakeValue(e);
788 for (int i = 0; i < e.type().lanes(); ++i) {
789 f(i, builder_->CreateExtractElement(value, i));
790 }
791 }
792 }
793
794
795 // Visitors
VisitExpr_(const Variable * op)796 llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
797 return GetVarValue(op);
798 }
799
VisitExpr_(const Cast * op)800 llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
801 return CreateCast(op->value.type(), op->type, MakeValue(op->value));
802 }
VisitExpr_(const IntImm * op)803 llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
804 return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
805 }
806
VisitExpr_(const UIntImm * op)807 llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
808 return llvm::ConstantInt::get(LLVMType(op->type), op->value);
809 }
810
VisitExpr_(const FloatImm * op)811 llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
812 return llvm::ConstantFP::get(LLVMType(op->type), op->value);
813 }
814
VisitExpr_(const StringImm * op)815 llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
816 return GetConstString(op->value);
817 }
818
// Generate CodeGenLLVM::Create{Add,Sub,Mul} plus the matching visitor.
// For ints/uints of >= 32 bits the no-signed-wrap / no-unsigned-wrap
// variants are emitted to enable more aggressive optimization; narrower
// integers use the plain (wrapping) instruction, and floats use the
// floating-point instruction.
#define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value *b) {                         \
    if (t.is_int()) {                                                   \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNSW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else if (t.is_uint()) {                                           \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNUW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateF ## Op (a, b);                            \
    }                                                                   \
  }                                                                     \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b));  \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
846
// Expands into CodeGenLLVM::Create##Op plus the VisitExpr_ overload for
// the matching comparison node (instantiated for LT/LE/GT/GE below).
// Signed ints use the ICmpS* predicate, unsigned ints ICmpU*, and floats
// the ordered FCmpO* predicate (which yields false if either operand is
// NaN). The visitor dispatches on the OPERAND type (op->a.type()), since
// the comparison result itself is always boolean.
#define DEFINE_CODEGEN_CMP_OP(Op)                                       \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value* b) {                         \
    if (t.is_int()) {                                                   \
      return builder_->CreateICmpS ## Op (a, b);                        \
    } else if (t.is_uint()) {                                           \
      return builder_->CreateICmpU ## Op (a, b);                        \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateFCmpO ## Op (a, b);                        \
    }                                                                   \
  }                                                                     \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
867
VisitExpr_(const Div * op)868 llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
869 llvm::Value* a = MakeValue(op->a);
870 llvm::Value* b = MakeValue(op->b);
871 if (op->type.is_int()) {
872 return builder_->CreateSDiv(a, b);
873 } else if (op->type.is_uint()) {
874 return builder_->CreateUDiv(a, b);
875 } else {
876 CHECK(op->type.is_float());
877 return builder_->CreateFDiv(a, b);
878 }
879 }
880
VisitExpr_(const Mod * op)881 llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
882 llvm::Value* a = MakeValue(op->a);
883 llvm::Value* b = MakeValue(op->b);
884 if (op->type.is_int()) {
885 return builder_->CreateSRem(a, b);
886 } else if (op->type.is_uint()) {
887 return builder_->CreateURem(a, b);
888 } else {
889 CHECK(op->type.is_float());
890 return builder_->CreateFRem(a, b);
891 }
892 }
893
VisitExpr_(const Min * op)894 llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
895 llvm::Value* a = MakeValue(op->a);
896 llvm::Value* b = MakeValue(op->b);
897 return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
898 }
899
VisitExpr_(const Max * op)900 llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
901 llvm::Value* a = MakeValue(op->a);
902 llvm::Value* b = MakeValue(op->b);
903 return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
904 }
905
VisitExpr_(const EQ * op)906 llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
907 llvm::Value* a = MakeValue(op->a);
908 llvm::Value* b = MakeValue(op->b);
909 if (op->a.type().is_int() || op->a.type().is_uint()) {
910 return builder_->CreateICmpEQ(a, b);
911 } else {
912 return builder_->CreateFCmpOEQ(a, b);
913 }
914 }
915
VisitExpr_(const NE * op)916 llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
917 llvm::Value* a = MakeValue(op->a);
918 llvm::Value* b = MakeValue(op->b);
919 if (op->a.type().is_int() || op->a.type().is_uint()) {
920 return builder_->CreateICmpNE(a, b);
921 } else {
922 return builder_->CreateFCmpONE(a, b);
923 }
924 }
925
VisitExpr_(const And * op)926 llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
927 return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
928 }
929
VisitExpr_(const Or * op)930 llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
931 return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
932 }
933
VisitExpr_(const Not * op)934 llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
935 return builder_->CreateNot(MakeValue(op->a));
936 }
937
VisitExpr_(const Select * op)938 llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
939 return builder_->CreateSelect(
940 MakeValue(op->condition),
941 MakeValue(op->true_value),
942 MakeValue(op->false_value));
943 }
944
VisitExpr_(const Let * op)945 llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
946 CHECK(!var_map_.count(op->var.get()));
947 var_map_[op->var.get()] = MakeValue(op->value);
948 analyzer_->Bind(op->var, op->value);
949 return MakeValue(op->body);
950 }
951
VisitExpr_(const Load * op)952 llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
953 Type t = op->type;
954 bool is_volatile = volatile_buf_.count(op->buffer_var.get());
955 llvm::Value* buffer = MakeValue(op->buffer_var);
956 llvm::Value* index = MakeValue(op->index);
957
958 if (t.lanes() == 1) {
959 int alignment, native_bits;
960 GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
961 llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
962 llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
963 AddAliasInfo(load, op->buffer_var.get(), op->index, t);
964 return load;
965 } else {
966 // vector load
967 unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
968 buffer->getType())->getAddressSpace();
969 if (const Ramp* ramp = op->index.as<Ramp>()) {
970 if (is_one(ramp->stride)) {
971 int alignment, native_bits;
972 GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
973 CHECK_EQ(ramp->lanes, t.lanes());
974 llvm::Value* ptr = CreateBufferPtr(
975 t.element_of(), buffer, MakeValue(ramp->base));
976 ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
977 llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
978 AddAliasInfo(load, op->buffer_var.get(), op->index, t);
979 return load;
980 }
981 }
982 }
983 // scalarized load.
984 int basic_align = t.bits() / 8;
985 llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
986 auto f = [&](int i, llvm::Value* index) {
987 llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
988 llvm::LoadInst* load = builder_->CreateAlignedLoad(
989 ptr, basic_align, is_volatile);
990 ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
991 AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
992 };
993 this->Scalarize(op->index, f);
994 return ret;
995 }
996
VisitExpr_(const Call * op)997 llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
998 if (op->call_type == Call::Intrinsic ||
999 op->call_type == Call::PureIntrinsic) {
1000 return CreateIntrinsic(op);
1001 } else if (op->call_type == Call::Extern ||
1002 op->call_type == Call::PureExtern) {
1003 return CreateCallExtern(op);
1004 } else {
1005 LOG(FATAL) << "Unknown call type " <<
1006 "name= " << op->name <<
1007 " call_type= " << op->call_type;
1008 return nullptr;
1009 }
1010 }
1011
VisitExpr_(const Ramp * op)1012 llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
1013 llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
1014 for (int i = 0; i < op->lanes; ++i) {
1015 vec = builder_->CreateInsertElement(
1016 vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
1017 ConstInt32(i));
1018 }
1019 return vec;
1020 }
1021
VisitExpr_(const Shuffle * op)1022 llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
1023 std::vector<llvm::Value *> vecs(op->vectors.size());
1024 int total_lanes = 0;
1025 for (int i = 0, e = op->vectors.size(); i < e; ++i) {
1026 vecs[i] = VisitExpr(op->vectors[i]);
1027 total_lanes += op->vectors[i].type().lanes();
1028 }
1029 llvm::Value* v0 = CreateVecConcat(vecs);
1030 std::vector<uint32_t> idx(op->indices.size());
1031 for (int i = 0, e = op->indices.size(); i < e; ++i) {
1032 const int64_t *val = as_const_int(op->indices[i]);
1033 CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
1034 << "but get " << op->indices[i] << "\n";
1035 idx[i] = *val;
1036 }
1037 llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
1038 auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
1039 return res;
1040 }
1041
VisitExpr_(const Broadcast * op)1042 llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
1043 return CreateBroadcast(MakeValue(op->value), op->lanes);
1044 }
1045
// Store to a buffer. Mirrors VisitExpr_(const Load*) with three paths:
//  1. scalar store (lanes == 1),
//  2. single wide store when the index is a unit-stride Ramp (lanes are
//     contiguous in memory),
//  3. fallback: scalarize the index and store lane-by-lane.
void CodeGenLLVM::VisitStmt_(const Store* op) {
  // Predicated stores are not supported by this backend.
  CHECK(is_one(op->predicate));
  Type t = op->value.type();
  // Volatile-marked buffers must not have their accesses optimized away.
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
    // Attach alias metadata so LLVM can disambiguate buffer accesses.
    AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
    return;
  } else {
    // vector store
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
        buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        // Unit-stride ramp: lanes are contiguous, emit one wide store.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        // Element pointer at the ramp base, cast to the vector type while
        // preserving the address space.
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
        AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
        return;
      }
    }
  }
  // Sub-byte element types cannot be addressed lane-by-lane.
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i),
        ptr, basic_align, is_volatile);
    // Per-lane index expression is unknown here, so pass an empty Expr.
    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
  };
  this->Scalarize(op->index, f);
}
1091
VisitStmt_(const For * op)1092 void CodeGenLLVM::VisitStmt_(const For* op) {
1093 CHECK(is_zero(op->min));
1094 analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
1095 if (op->for_type == ForType::Unrolled) {
1096 LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
1097 << " consider set unroll_explicit=True";
1098 } else {
1099 CHECK(op->for_type == ForType::Serial);
1100 }
1101 CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
1102 ConstInt32(1), op->loop_var, op->body);
1103 }
1104
1105
// Lower IfThenElse to explicit basic blocks:
//   cond ? if_then : (if_else | if_end); both arms branch to if_end,
// where code generation continues.
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  using llvm::BasicBlock;
  llvm::Value* cond = MakeValue(op->condition);
  BasicBlock* then_block = BasicBlock::Create(
      *ctx_, "if_then", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "if_end", function_);
  if (op->else_case.defined()) {
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    // Emit then arm, then fall through to the merge block.
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    // Emit else arm, also merging into end_block.
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case);
    builder_->CreateBr(end_block);
  } else {
    // No else arm: branch metadata marks the then arm as very likely.
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  // Subsequent statements are emitted after the if.
  builder_->SetInsertPoint(end_block);
}
1131
1132
// Lower Allocate: either use the user-provided new_expr, or emit a
// constant-size alloca in the function entry block; then record the
// applied alignment and bind the buffer variable for later loads/stores.
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;
  if (op->new_expr.defined()) {
    // Caller supplied the allocation expression; no free function is
    // supported in that case.
    CHECK_EQ(op->free_function, "nop");
    buf = MakeValue(op->new_expr);
  } else {
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation";
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    if (constant_size % 4 == 0 && info.alignment == 0) {
      // No alignment was requested: derive one from the type and size.
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }
    // Place the alloca in the function entry block (WithFunctionEntry
    // temporarily moves the insert point there).
    llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
        return builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
      });
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
      // LLVM 10+ takes a typed llvm::Align instead of a raw integer.
      alloca->setAlignment(llvm::Align(info.alignment));
#else
      alloca->setAlignment(info.alignment);
#endif
    }
    // Record the alignment that was actually applied.
    info.alignment = alloca->getAlignment();
    buf = alloca;
  }
  // Cast to a pointer of the element type, keeping the address space of
  // the underlying allocation.
  buf = builder_->CreatePointerCast(
      buf, LLVMType(op->type)->getPointerTo(
          buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1172
VisitStmt_(const AttrStmt * op)1173 void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
1174 if (op->attr_key == attr::thread_extent) {
1175 IterVar iv = Downcast<IterVar>(op->node);
1176 if (iv->thread_tag.length() != 0) {
1177 if (!var_map_.count(iv->var.get())) {
1178 var_map_[iv->var.get()] = GetThreadIndex(iv);
1179 analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
1180 }
1181 }
1182 } else if (op->attr_key == ir::attr::storage_scope) {
1183 const Variable* v = op->node.as<Variable>();
1184 CHECK(v);
1185 alloc_storage_info_[v].scope =
1186 runtime::StorageScope::make(op->value.as<StringImm>()->value);
1187 } else if (op->attr_key == ir::attr::storage_alignment) {
1188 const Variable* v = op->node.as<Variable>();
1189 CHECK(v);
1190 alloc_storage_info_[v].alignment =
1191 static_cast<int>(op->value.as<IntImm>()->value);
1192 } else if (op->attr_key == ir::attr::volatile_scope) {
1193 const Variable* v = op->node.as<Variable>();
1194 CHECK(v);
1195 volatile_buf_.insert(v);
1196 }
1197 this->VisitStmt(op->body);
1198 }
1199
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
  // Scope the assert condition into the arithmetic analyzer while the
  // body is generated; the constraint is popped when cctx is destroyed.
  // No runtime check is emitted by this base implementation.
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
1204
VisitStmt_(const LetStmt * op)1205 void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
1206 CHECK(!var_map_.count(op->var.get()));
1207 if (op->var.type().is_handle()) {
1208 if (!is_restricted_) {
1209 alias_var_set_.insert(op->var.get());
1210 }
1211 }
1212 var_map_[op->var.get()] = MakeValue(op->value);
1213 analyzer_->Bind(op->var, op->value);
1214 this->VisitStmt(op->body);
1215 }
1216
VisitStmt_(const Block * op)1217 void CodeGenLLVM::VisitStmt_(const Block* op) {
1218 this->VisitStmt(op->first);
1219 if (op->rest.defined()) {
1220 this->VisitStmt(op->rest);
1221 }
1222 }
1223
VisitStmt_(const Evaluate * op)1224 void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
1225 MakeValue(op->value);
1226 }
1227
VisitStmt_(const ProducerConsumer * op)1228 void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
1229 this->VisitStmt(op->body);
1230 }
1231 } // namespace codegen
1232 } // namespace tvm
1233 #endif // TVM_LLVM_VERSION
1234