/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_llvm.cc
 */
#ifdef TVM_LLVM_VERSION
// Part of the code are adapted from Halide's CodeGen_LLVM
#include "codegen_llvm.h"

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "../../arith/pattern_match.h"
#include "../build_common.h"
#include "codegen_cpu.h"
36 namespace tvm {
37 namespace codegen {
38
Create(llvm::TargetMachine * tm)39 std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine* tm) {
40 std::string target = tm->getTarget().getName();
41 std::string factory_name = "tvm.codegen.llvm.target_" + target;
42 const PackedFunc* f = runtime::Registry::Get(factory_name);
43 if (f != nullptr) {
44 void* handle = (*f)();
45 return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
46 } else {
47 return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
48 }
49 }
50
Init(const std::string & module_name,llvm::TargetMachine * tm,llvm::LLVMContext * ctx,bool system_lib,bool dynamic_lookup,bool target_c_runtime)51 void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
52 llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
53 bool target_c_runtime) {
54 InitializeLLVM();
55 ctx_ = ctx;
56 builder_.reset(new IRBuilder(*ctx_));
57 module_.reset(new llvm::Module(module_name, *ctx_));
58 md_builder_.reset(new llvm::MDBuilder(*ctx_));
59 // types
60 t_void_ = llvm::Type::getVoidTy(*ctx_);
61 t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
62 t_int_ = llvm::Type::getInt32Ty(*ctx_);
63 t_char_ = llvm::Type::getInt8Ty(*ctx_);
64 t_int8_ = llvm::Type::getInt8Ty(*ctx_);
65 t_int16_ = llvm::Type::getInt16Ty(*ctx_);
66 t_int32_ = llvm::Type::getInt32Ty(*ctx_);
67 t_int64_ = llvm::Type::getInt64Ty(*ctx_);
68 t_float64_ = llvm::Type::getDoubleTy(*ctx_);
69 // meta data
70 md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
71 md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
72 md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
73 this->InitTarget(tm);
74 }
75
InitTarget(llvm::TargetMachine * tm)76 void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
77 module_->setTargetTriple(tm->getTargetTriple().str());
78 module_->setDataLayout(tm->createDataLayout());
79 data_layout_.reset(new llvm::DataLayout(module_.get()));
80 target_machine_ = tm;
81 if (native_vector_bits_ == 0) {
82 const auto& arch = tm->getTargetTriple().getArch();
83 if (arch == llvm::Triple::x86_64) {
84 // for avx512
85 native_vector_bits_ = 512;
86 } else if (arch == llvm::Triple::x86) {
87 native_vector_bits_ = 256;
88 } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
89 native_vector_bits_ = 128;
90 } else {
91 native_vector_bits_ = 128;
92 std::string arch_name = std::string(tm->getTargetTriple().getArchName());
93 LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
94 }
95 }
96 }
97
AddFunction(const PrimFunc & f)98 void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
99
InitFuncState()100 void CodeGenLLVM::InitFuncState() {
101 var_map_.clear();
102 alias_var_set_.clear();
103 alloc_storage_info_.clear();
104 volatile_buf_.clear();
105 analyzer_.reset(new arith::Analyzer());
106 }
107
AddFunctionInternal(const PrimFunc & f,bool ret_void)108 void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
109 this->InitFuncState();
110
111 CHECK_EQ(f->buffer_map.size(), 0U)
112 << "Cannot codegen function with buffer_map, please lower them first";
113
114 std::vector<llvm::Type*> param_types;
115 is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
116 for (Var param : f->params) {
117 param_types.push_back(GetLLVMType(param));
118 if (!is_restricted_ && param.dtype().is_handle()) {
119 alias_var_set_.insert(param.get());
120 }
121 }
122 // TODO(tvm-team):
123 // Update the function type to respect the ret_type field of f.
124 // Once we allow more flexibility in the PrimFunc.
125 llvm::FunctionType* ftype =
126 llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
127
128 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
129 CHECK(global_symbol.defined())
130 << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
131 CHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
132 << "Function " << global_symbol << " already exist in module";
133
134 function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
135 global_symbol.value().operator std::string(), module_.get());
136 function_->setCallingConv(llvm::CallingConv::C);
137 function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
138
139 // set var map and align information
140 auto arg_it = function_->arg_begin();
141 for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
142 llvm::Argument* v = &(*arg_it);
143 const Var& var = f->params[i];
144 var_map_[var.get()] = v;
145 if (is_restricted_) {
146 if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
147 // set non alias.
148 #if TVM_LLVM_VERSION >= 50
149 function_->addParamAttr(i, llvm::Attribute::NoAlias);
150 #else
151 function_->setDoesNotAlias(i + 1);
152 #endif
153 }
154 }
155 }
156 llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
157 builder_->SetInsertPoint(entry);
158 this->VisitStmt(f->body);
159
160 // Add alignment attribute if needed.
161 #if TVM_LLVM_VERSION >= 50
162 for (size_t i = 0; i < f->params.size(); ++i) {
163 const Var& var = f->params[i];
164 auto f = alloc_storage_info_.find(var.get());
165 if (f != alloc_storage_info_.end()) {
166 unsigned align = f->second.alignment;
167 if (align > 1) {
168 auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
169 function_->addParamAttr(i, attr);
170 }
171 }
172 }
173 #endif
174
175 if (ret_void) {
176 builder_->CreateRetVoid();
177 } else {
178 builder_->CreateRet(ConstInt32(0));
179 }
180 }
181
Finish()182 std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
183 this->AddStartupFunction();
184 for (size_t i = 0; i < link_modules_.size(); ++i) {
185 CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
186 << "Failed to link modules";
187 }
188 link_modules_.clear();
189 // optimize
190 this->Optimize();
191 return std::move(module_);
192 }
193
HandleImport(const std::string & code)194 void CodeGenLLVM::HandleImport(const std::string& code) {
195 std::unique_ptr<llvm::Module> mlib;
196 llvm::SMDiagnostic err;
197 if (code.length() >= 3 &&
198 (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) {
199 mlib = llvm::parseIRFile(code, err, *ctx_);
200 if (mlib.get() == nullptr) {
201 std::string msg = std::string(err.getMessage());
202 LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
203 << "line " << err.getLineNo() << ":" << msg;
204 }
205 } else {
206 std::unique_ptr<llvm::MemoryBuffer> buf = llvm::MemoryBuffer::getMemBuffer(code);
207 mlib = llvm::parseIR(*buf, err, *ctx_);
208 if (mlib.get() == nullptr) {
209 std::string msg = std::string(err.getMessage());
210 LOG(FATAL) << "Fail to load llvm ir "
211 << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n"
212 << code;
213 }
214 }
215 mlib->setTargetTriple(target_machine_->getTargetTriple().str());
216 mlib->setDataLayout(target_machine_->createDataLayout());
217 // mark all the functions as force inline
218 for (llvm::Function& f : mlib->functions()) {
219 f.removeFnAttr(llvm::Attribute::NoInline);
220 f.addFnAttr(llvm::Attribute::AlwaysInline);
221 f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
222 }
223 // add to linker libraries.
224 this->AddLinkModule(std::move(mlib));
225 }
226
AddLinkModule(std::unique_ptr<llvm::Module> && mod)227 void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
228 link_modules_.emplace_back(std::move(mod));
229 }
230
AddMainFunction(const std::string & entry_func_name)231 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
232 LOG(FATAL) << "not implemented";
233 }
234
GetThreadIndex(const IterVar & iv)235 llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
236 LOG(FATAL) << "not implemented";
237 return nullptr;
238 }
239
CreateStorageSync(const CallNode * op)240 llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) {
241 LOG(FATAL) << "not implemented";
242 return nullptr;
243 }
244
245 class FPassManager : public llvm::legacy::FunctionPassManager {
246 public:
FPassManager(llvm::Module * m)247 explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
248 // override add to allow messaging
add(llvm::Pass * p)249 void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
250 };
251
252 class MPassManager : public llvm::legacy::PassManager {
253 public:
254 // override add to allow messaging
add(llvm::Pass * p)255 void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
256 };
257
InitPassManagerBuilder(llvm::PassManagerBuilder * builder)258 void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {}
259
Optimize()260 void CodeGenLLVM::Optimize() {
261 // pass manager
262 FPassManager fpass(module_.get());
263 MPassManager mpass;
264 mpass.add(llvm::createTargetTransformInfoWrapperPass(
265 target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
266 fpass.add(llvm::createTargetTransformInfoWrapperPass(
267 target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
268
269 // place optimization pass
270 llvm::PassManagerBuilder builder;
271 builder.OptLevel = 3;
272
273 #if TVM_LLVM_VERSION >= 50
274 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
275 #else
276 builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
277 #endif
278 builder.LoopVectorize = true;
279 builder.SLPVectorize = true;
280 this->InitPassManagerBuilder(&builder);
281
282 #if TVM_LLVM_VERSION >= 50
283 target_machine_->adjustPassManager(builder);
284 #endif
285
286 builder.populateFunctionPassManager(fpass);
287 builder.populateModulePassManager(mpass);
288
289 fpass.doInitialization();
290 for (auto it = module_->begin(); it != module_->end(); ++it) {
291 fpass.run(*it);
292 }
293 fpass.doFinalization();
294 mpass.run(*module_);
295 }
296
NativeVectorBits(const runtime::StorageScope & storage_scope) const297 int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
298 return native_vector_bits_;
299 }
300
GetGlobalAddressSpace() const301 unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; }
302
DTypeToLLVMType(const DataType & dtype) const303 llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
304 if (dtype.is_handle()) {
305 CHECK_EQ(dtype.lanes(), 1);
306 return t_void_p_;
307 }
308 if (dtype.is_void()) {
309 return t_void_;
310 }
311 llvm::Type* etype = nullptr;
312 if (dtype.is_int() || dtype.is_uint()) {
313 etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
314 } else if (dtype.is_float()) {
315 switch (dtype.bits()) {
316 case 16:
317 etype = llvm::Type::getHalfTy(*ctx_);
318 break;
319 case 32:
320 etype = llvm::Type::getFloatTy(*ctx_);
321 break;
322 case 64:
323 etype = llvm::Type::getDoubleTy(*ctx_);
324 break;
325 default:
326 LOG(FATAL) << "do not support " << dtype;
327 }
328 }
329 if (dtype.lanes() != 1) {
330 #if TVM_LLVM_VERSION >= 110
331 return llvm::FixedVectorType::get(etype, dtype.lanes());
332 #else
333 return llvm::VectorType::get(etype, dtype.lanes());
334 #endif
335 } else {
336 return etype;
337 }
338 }
339
GetLLVMType(const Type & type) const340 llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
341 if (auto* ptr = type.as<PrimTypeNode>()) {
342 return DTypeToLLVMType(ptr->dtype);
343 } else if (auto* ptr = type.as<PointerTypeNode>()) {
344 // TODO(tvm-team) consider put storage scope into the pointer type.
345 return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace());
346 } else if (IsVoidType(type)) {
347 return t_void_;
348 } else {
349 LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type";
350 return t_void_;
351 }
352 }
353
GetLLVMType(const PrimExpr & expr) const354 llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
355 return GetLLVMType(GetType(expr));
356 }
357
358 // Add tbaa alias information for load
359 //
360 // use a binary tree typed system to declare information
361 // and allow alias to be distinguished across nodes.
362 //
363 // This trick comes from Halide's CodeGen_LLVM
364 //
AddAliasInfo(llvm::Instruction * inst,const VarNode * buffer,PrimExpr index)365 void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index) {
366 if (alias_var_set_.count(buffer) != 0) {
367 // Mark all possibly aliased pointer as same type.
368 llvm::MDNode* meta = md_tbaa_alias_set_;
369 inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
370 return;
371 }
372
373 int64_t base = 0, width = 0;
374 arith::PVar<IntImm> pbase, pstride;
375 arith::PVar<int> planes;
376 // create meta-data for alias analysis
377 // Use a group of binary tree ranges of memory banks.
378 if (index.defined()) {
379 if (arith::ramp(pbase, pstride, planes).Match(index)) {
380 base = pbase.Eval()->value;
381 int64_t xwith = planes.Eval() * pstride.Eval()->value;
382 width = 1;
383 while (width < xwith) {
384 width *= 2;
385 }
386 while (base % width) {
387 base -= base % width;
388 width *= 2;
389 }
390 } else if (auto* ptr = index.as<tir::IntImmNode>()) {
391 width = 1;
392 base = ptr->value;
393 }
394 }
395 llvm::MDNode* meta = md_tbaa_root_;
396 std::ostringstream buffer_addr;
397 buffer_addr << buffer;
398 meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
399
400 // Extract the underlying type of the allocated buffer.
401 llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType();
402 if (buf_type->isPointerTy()) {
403 buf_type = buf_type->getPointerElementType();
404 }
405
406 std::string tmp;
407 llvm::raw_string_ostream buffer_type(tmp);
408 buffer_type << *buf_type;
409 meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
410
411 // create a tree-shape access structure.
412 if (width != 0) {
413 for (int64_t w = 1024; w >= width; w /= 2) {
414 int64_t b = (base / w) * w;
415 std::stringstream os;
416 os << buffer << ".w" << w << ".b" << b;
417 meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
418 }
419 }
420 inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
421 }
422
GetAlignment(DataType t,const VarNode * buf_var,const PrimExpr & index,int * p_alignment,int * p_native_bits)423 void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
424 int* p_alignment, int* p_native_bits) {
425 int max_align_bits = t.bits();
426 auto it = alloc_storage_info_.find(buf_var);
427 if (it != alloc_storage_info_.end()) {
428 const StorageInfo& info = it->second;
429 *p_native_bits = NativeVectorBits(info.scope);
430 max_align_bits = info.alignment * 8;
431 } else {
432 *p_native_bits = native_vector_bits_;
433 }
434
435 arith::ModularSet me = analyzer_->modular_set(index);
436 int64_t base = me->base;
437 int64_t coeff = me->coeff;
438
439 int align_bits = t.bits();
440 while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
441 base = base / 2;
442 coeff = coeff / 2;
443 align_bits *= 2;
444 }
445 if (align_bits < 8) {
446 align_bits = 8;
447 }
448 *p_alignment = align_bits / 8;
449 }
450
CreateDebugInfo(llvm::Module * module)451 std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
452 #if TVM_LLVM_VERSION >= 100
453 auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
454 debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
455 #else
456 auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
457 debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
458 #endif
459 // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
460 debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
461 debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
462 llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
463 llvm::DICompileUnit::DebugEmissionKind::FullDebug,
464 /* SplitDebugInlining */ true,
465 /* DebugInfoForProfiling */ true);
466 return debug_info;
467 }
468
CreateBroadcast(llvm::Value * value,int lanes)469 llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
470 #if TVM_LLVM_VERSION >= 110
471 llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
472 #else
473 llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
474 #endif
475 llvm::Constant* undef = llvm::UndefValue::get(type);
476 llvm::Constant* zero = ConstInt32(0);
477 value = builder_->CreateInsertElement(undef, value, zero);
478 #if TVM_LLVM_VERSION >= 110
479 llvm::Constant* mask =
480 llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
481 #else
482 llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
483 #endif
484 return builder_->CreateShuffleVector(value, undef, mask);
485 }
486
CreateVecSlice(llvm::Value * vec,int begin,int extent)487 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
488 int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
489 if (extent == num_elems && begin == 0) return vec;
490 CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
491 std::vector<llvm::Constant*> indices;
492 indices.reserve(extent);
493 for (int i = 0; i < extent; ++i) {
494 if (begin + i >= 0 && begin + i < num_elems) {
495 indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
496 } else {
497 indices.push_back(llvm::UndefValue::get(t_int32_));
498 }
499 }
500 return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
501 }
502
CreateVecFlip(llvm::Value * vec)503 llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
504 int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
505 #if TVM_LLVM_VERSION >= 110
506 std::vector<int> indices;
507 #else
508 std::vector<unsigned> indices;
509 #endif
510 for (int i = 0; i < num_elems; ++i) {
511 indices.push_back(num_elems - i - 1);
512 }
513 return builder_->CreateShuffleVector(vec, vec, indices);
514 }
515
CreateVecPad(llvm::Value * vec,int target_lanes)516 llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
517 llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
518 int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
519 if (num_elems == target_lanes) return vec;
520 CHECK_LT(num_elems, target_lanes);
521 for (int i = 0; i < num_elems; ++i) {
522 mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
523 }
524 return builder_->CreateShuffleVector(vec, vec, mask);
525 }
526
CreateVecConcat(std::vector<llvm::Value * > vecs)527 llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
528 // concat vector, tree shape reduction
529 int total_lanes = 0;
530
531 for (llvm::Value* v : vecs) {
532 total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
533 }
534 while (vecs.size() > 1) {
535 std::vector<llvm::Value*> new_vecs;
536 for (size_t i = 0; i < vecs.size() - 1; i += 2) {
537 llvm::Value* lhs = vecs[i];
538 llvm::Value* rhs = vecs[i + 1];
539 const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
540 const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
541 if (lhs_lanes < rhs_lanes) {
542 lhs = CreateVecPad(lhs, rhs_lanes);
543 } else if (rhs_lanes < lhs_lanes) {
544 rhs = CreateVecPad(rhs, lhs_lanes);
545 }
546 const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
547 #if TVM_LLVM_VERSION >= 110
548 std::vector<int> mask;
549 #else
550 std::vector<unsigned> mask;
551 #endif
552 for (size_t i = 0; i < lhs_lanes; ++i) {
553 mask.push_back(i);
554 }
555 for (size_t i = 0; i < rhs_lanes; ++i) {
556 mask.push_back(shared_lanes + i);
557 }
558 new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
559 }
560 if (vecs.size() % 2 != 0) {
561 new_vecs.push_back(vecs.back());
562 }
563 vecs.swap(new_vecs);
564 }
565 return CreateVecSlice(vecs[0], 0, total_lanes);
566 }
567
CreateSerialFor(llvm::Value * begin,llvm::Value * end,llvm::Value * stride,const Var & loop_var,const Stmt & body)568 void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
569 const Var& loop_var, const Stmt& body) {
570 using llvm::BasicBlock;
571 BasicBlock* pre_block = builder_->GetInsertBlock();
572 BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_);
573 BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_);
574 BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_);
575 builder_->CreateBr(for_begin);
576 builder_->SetInsertPoint(for_begin);
577 llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
578 loop_value->addIncoming(begin, pre_block);
579 CHECK(!var_map_.count(loop_var.get()));
580 var_map_[loop_var.get()] = loop_value;
581 builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
582 md_very_likely_branch_);
583 builder_->SetInsertPoint(for_body);
584 this->VisitStmt(body);
585 var_map_.erase(loop_var.get());
586 llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
587 loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
588 builder_->CreateBr(for_begin);
589 builder_->SetInsertPoint(for_end);
590 }
591
592 // cast operatpr
CreateCast(DataType from,DataType to,llvm::Value * value)593 llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
594 llvm::Type* target = DTypeToLLVMType(to);
595 if (value->getType() == target) return value;
596 if (to.is_handle()) {
597 return builder_->CreateBitCast(value, target);
598 } else if (to.is_uint() && to.bits() == 1) {
599 if (from.is_float()) {
600 llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
601 return builder_->CreateFCmpONE(value, zero);
602 } else {
603 llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
604 return builder_->CreateICmpNE(value, zero);
605 }
606 } else if (!from.is_float() && !to.is_float()) {
607 return builder_->CreateIntCast(value, target, from.is_int());
608 } else if (from.is_float() && to.is_int()) {
609 return builder_->CreateFPToSI(value, target);
610 } else if (from.is_float() && to.is_uint()) {
611 if (to.bits() < 8) {
612 value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
613 return builder_->CreateIntCast(value, target, false);
614 } else {
615 return builder_->CreateFPToUI(value, target);
616 }
617 } else if (from.is_int() && to.is_float()) {
618 return builder_->CreateSIToFP(value, target);
619 } else if (from.is_uint() && to.is_float()) {
620 return builder_->CreateUIToFP(value, target);
621 } else {
622 CHECK(from.is_float() && to.is_float());
623 return builder_->CreateFPCast(value, target);
624 }
625 }
626
GetConstString(const std::string & str)627 llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
628 auto it = str_map_.find(str);
629 if (it != str_map_.end()) return it->second;
630 llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
631 llvm::GlobalVariable* global = new llvm::GlobalVariable(
632 *module_, type, true, llvm::GlobalValue::PrivateLinkage, nullptr, ".str");
633 #if TVM_LLVM_VERSION >= 100
634 global->setAlignment(llvm::Align(1));
635 #else
636 global->setAlignment(1);
637 #endif
638 global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
639 llvm::Constant* zero = ConstInt32(0);
640 llvm::Constant* indices[] = {zero, zero};
641 llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices);
642 str_map_[str] = ptr;
643 return ptr;
644 }
645
CreateBufferPtr(DataType t,llvm::Value * buffer,llvm::Value * index)646 llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) {
647 llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
648 CHECK(btype != nullptr);
649 llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace());
650 if (btype != ptype) {
651 buffer = builder_->CreatePointerCast(buffer, ptype);
652 }
653 return builder_->CreateInBoundsGEP(buffer, index);
654 }
655
GetVarValue(const VarNode * v) const656 llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
657 auto it = var_map_.find(v);
658 CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
659 return it->second;
660 }
661
CreateCallExtern(Type ret_type,String global_symbol,const Array<PrimExpr> & args,bool skip_first_arg)662 llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol,
663 const Array<PrimExpr>& args, bool skip_first_arg) {
664 std::vector<llvm::Value*> arg_value;
665 std::vector<llvm::Type*> arg_type;
666 for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
667 arg_value.push_back(MakeValue(args[i]));
668 arg_type.push_back(arg_value.back()->getType());
669 }
670 llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
671 llvm::Function* f = module_->getFunction(global_symbol);
672 if (f == nullptr) {
673 f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
674 global_symbol.operator llvm::StringRef(), module_.get());
675 }
676 llvm::CallInst* call = builder_->CreateCall(f, arg_value);
677 return call;
678 }
679
GetIntrinsicDecl(llvm::Intrinsic::ID id,llvm::Type * ret_type,llvm::ArrayRef<llvm::Type * > arg_types)680 llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
681 llvm::ArrayRef<llvm::Type*> arg_types) {
682 llvm::Module* module = module_.get();
683
684 if (!llvm::Intrinsic::isOverloaded(id)) {
685 return llvm::Intrinsic::getDeclaration(module, id, {});
686 }
687
688 llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
689 llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
690 llvm::SmallVector<llvm::Type*, 4> overload_types;
691
692 #if TVM_LLVM_VERSION >= 90
693 auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
694 overload_types.clear();
695 llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
696 auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
697 if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
698 bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
699 if (error) {
700 return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
701 }
702 }
703 return match;
704 };
705
706 // First, try matching the signature assuming non-vararg case.
707 auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
708 switch (try_match(fn_ty, false)) {
709 case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
710 // The return type doesn't match, there is nothing else to do.
711 return nullptr;
712 case llvm::Intrinsic::MatchIntrinsicTypes_Match:
713 return llvm::Intrinsic::getDeclaration(module, id, overload_types);
714 case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
715 break;
716 }
717
718 // Keep adding one type at a time (starting from empty list), and
719 // try matching the vararg signature.
720 llvm::SmallVector<llvm::Type*, 4> var_types;
721 for (int i = 0, e = arg_types.size(); i <= e; ++i) {
722 if (i > 0) var_types.push_back(arg_types[i - 1]);
723 auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
724 if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
725 return llvm::Intrinsic::getDeclaration(module, id, overload_types);
726 }
727 }
728 // Failed to identify the type.
729 return nullptr;
730
731 #else // TVM_LLVM_VERSION
732 llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
733 // matchIntrinsicType returns true on error.
734 if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
735 return nullptr;
736 }
737 for (llvm::Type* t : arg_types) {
738 if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
739 return nullptr;
740 }
741 }
742 return llvm::Intrinsic::getDeclaration(module, id, overload_types);
743 #endif // TVM_LLVM_VERSION
744 }
745
CreateIntrinsic(const CallNode * op)746 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
747 if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
748 CHECK_GE(op->args.size(), 2U);
749 llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
750 int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
751 std::vector<llvm::Value*> arg_value;
752 std::vector<llvm::Type*> arg_type;
753 for (size_t i = 2; i < op->args.size(); ++i) {
754 arg_value.push_back(MakeValue(op->args[i]));
755 if (i - 2 < static_cast<size_t>(num_signature)) {
756 arg_type.push_back(arg_value.back()->getType());
757 }
758 }
759 // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
760 // returns int32. This causes problems because prefetch is one of
761 // those intrinsics that is generated automatically via the
762 // tvm.intrin.rule mechanism. Any other intrinsic with a type
763 // mismatch will have to be treated specially here.
764 // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
765 // type as LLVM.
766 llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op))
767 : llvm::Type::getVoidTy(*ctx_);
768 llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
769 CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
770 << llvm::Intrinsic::getName(id, {});
771 return builder_->CreateCall(f, arg_value);
772 } else if (op->op.same_as(builtin::bitwise_and())) {
773 return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
774 } else if (op->op.same_as(builtin::bitwise_or())) {
775 return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
776 } else if (op->op.same_as(builtin::bitwise_not())) {
777 return builder_->CreateNot(MakeValue(op->args[0]));
778 } else if (op->op.same_as(builtin::bitwise_xor())) {
779 return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
780 } else if (op->op.same_as(builtin::shift_left())) {
781 return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
782 } else if (op->op.same_as(builtin::shift_right())) {
783 if (op->args[0].dtype().is_int()) {
784 return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
785 } else {
786 return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
787 }
788 } else if (op->op.same_as(builtin::tvm_storage_sync())) {
789 return CreateStorageSync(op);
790 } else if (op->op.same_as(builtin::address_of())) {
791 const LoadNode* l = op->args[0].as<LoadNode>();
792 CHECK(op->args.size() == 1 && l);
793 const RampNode* r = l->index.as<RampNode>();
794 llvm::Value* ptr;
795 unsigned addrspace;
796 if (!r) {
797 ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
798 addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
799 } else {
800 PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
801 ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index));
802 addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
803 }
804 return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
805 } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
806 return llvm::Constant::getNullValue(t_void_p_);
807 } else if (op->op.same_as(builtin::isnullptr())) {
808 return builder_->CreateIsNull(MakeValue(op->args[0]));
809 } else if (op->op.same_as(builtin::large_uint_imm())) {
810 CHECK_EQ(op->args.size(), 2U);
811 uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
812 uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
813 uint64_t val = (high << 32U) | low;
814 return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
815 } else if (op->op.same_as(builtin::if_then_else())) {
816 CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition";
817 using llvm::BasicBlock;
818 BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
819 BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
820 BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
821 builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
822 builder_->SetInsertPoint(then_block);
823 llvm::Value* then_value = MakeValue(op->args[1]);
824 BasicBlock* then_value_block = builder_->GetInsertBlock();
825 builder_->CreateBr(end_block);
826 builder_->SetInsertPoint(else_block);
827 llvm::Value* else_value = MakeValue(op->args[2]);
828 BasicBlock* else_value_block = builder_->GetInsertBlock();
829 builder_->CreateBr(end_block);
830 builder_->SetInsertPoint(end_block);
831 llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
832 value->addIncoming(then_value, then_value_block);
833 value->addIncoming(else_value, else_value_block);
834 return value;
835 } else if (op->op.same_as(builtin::reinterpret())) {
836 llvm::Type* target = DTypeToLLVMType(op->dtype);
837 return builder_->CreateBitCast(MakeValue(op->args[0]), target);
838 } else if (op->op.same_as(builtin::isnan())) {
839 // TODO(hgt312): set fast math flag
840 llvm::Value* a = MakeValue(op->args[0]);
841 return builder_->CreateFCmpUNO(a, a);
842 } else if (op->op.same_as(builtin::vectorlow())) {
843 llvm::Value* v = MakeValue(op->args[0]);
844 int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
845 return CreateVecSlice(v, 0, l / 2);
846 } else if (op->op.same_as(builtin::vectorhigh())) {
847 llvm::Value* v = MakeValue(op->args[0]);
848 int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
849 return CreateVecSlice(v, l / 2, l / 2);
850 } else if (op->op.same_as(builtin::vectorcombine())) {
851 llvm::Value* v0 = MakeValue(op->args[0]);
852 llvm::Value* v1 = MakeValue(op->args[1]);
853 int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
854 #if TVM_LLVM_VERSION >= 110
855 std::vector<int> indices;
856 #else
857 std::vector<unsigned> indices;
858 #endif
859 for (int i = 0; i < num_elems; ++i) {
860 indices.push_back(i);
861 }
862 return builder_->CreateShuffleVector(v0, v1, indices);
863 } else {
864 LOG(FATAL) << "unknown intrinsic " << op->op;
865 return nullptr;
866 }
867 }
868
Scalarize(const PrimExpr & e,std::function<void (int i,llvm::Value * v)> f)869 void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
870 if (const RampNode* ramp = e.as<RampNode>()) {
871 for (int i = 0; i < ramp->dtype.lanes(); ++i) {
872 PrimExpr offset = ramp->base + (ramp->stride * i);
873 f(i, MakeValue(offset));
874 }
875 } else {
876 llvm::Value* value = MakeValue(e);
877 for (int i = 0; i < e.dtype().lanes(); ++i) {
878 f(i, builder_->CreateExtractElement(value, i));
879 }
880 }
881 }
882
883 // Visitors
VisitExpr_(const VarNode * op)884 llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
885
VisitExpr_(const CastNode * op)886 llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
887 return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
888 }
VisitExpr_(const IntImmNode * op)889 llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
890 return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
891 }
892
VisitExpr_(const FloatImmNode * op)893 llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
894 return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
895 }
896
VisitExpr_(const StringImmNode * op)897 llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
898
// Defines CodeGenLLVM::Create<Op> plus the corresponding VisitExpr_ overload
// for a binary arithmetic node (Add/Sub/Mul):
// - signed ints of >= 32 bits use the no-signed-wrap (NSW) instruction form;
// - unsigned ints of >= 32 bits use the no-unsigned-wrap (NUW) form;
// - narrower integers use the plain (wrapping) form;
// - anything else must be float and uses the FP instruction.
// (Comments cannot go inside the macro body: a // before a trailing
// backslash would comment out the line continuation.)
#define DEFINE_CODEGEN_BINARY_OP(Op)                                                 \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNSW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else if (t.is_uint()) {                                                        \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNUW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else {                                                                         \
      CHECK(t.is_float());                                                           \
      return builder_->CreateF##Op(a, b);                                            \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b));                \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
925
// Defines CodeGenLLVM::Create<Op> plus the corresponding VisitExpr_ overload
// for a comparison node (LT/LE/GT/GE): signed icmp for ints, unsigned icmp
// for uints, and the ordered float compare (false when either side is NaN)
// otherwise. Note the dispatch dtype is the operand type op->a.dtype(), not
// the boolean result type of the node.
#define DEFINE_CODEGEN_CMP_OP(Op)                                                    \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      return builder_->CreateICmpS##Op(a, b);                                        \
    } else if (t.is_uint()) {                                                        \
      return builder_->CreateICmpU##Op(a, b);                                        \
    } else {                                                                         \
      CHECK(t.is_float());                                                           \
      return builder_->CreateFCmpO##Op(a, b);                                        \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b));            \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
945
VisitExpr_(const DivNode * op)946 llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
947 llvm::Value* a = MakeValue(op->a);
948 llvm::Value* b = MakeValue(op->b);
949 if (op->dtype.is_int()) {
950 return builder_->CreateSDiv(a, b);
951 } else if (op->dtype.is_uint()) {
952 return builder_->CreateUDiv(a, b);
953 } else {
954 CHECK(op->dtype.is_float());
955 return builder_->CreateFDiv(a, b);
956 }
957 }
958
VisitExpr_(const ModNode * op)959 llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
960 llvm::Value* a = MakeValue(op->a);
961 llvm::Value* b = MakeValue(op->b);
962 if (op->dtype.is_int()) {
963 return builder_->CreateSRem(a, b);
964 } else if (op->dtype.is_uint()) {
965 return builder_->CreateURem(a, b);
966 } else {
967 CHECK(op->dtype.is_float());
968 return builder_->CreateFRem(a, b);
969 }
970 }
971
VisitExpr_(const MinNode * op)972 llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
973 llvm::Value* a = MakeValue(op->a);
974 llvm::Value* b = MakeValue(op->b);
975 return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
976 }
977
VisitExpr_(const MaxNode * op)978 llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
979 llvm::Value* a = MakeValue(op->a);
980 llvm::Value* b = MakeValue(op->b);
981 return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
982 }
983
VisitExpr_(const EQNode * op)984 llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
985 llvm::Value* a = MakeValue(op->a);
986 llvm::Value* b = MakeValue(op->b);
987 if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
988 return builder_->CreateICmpEQ(a, b);
989 } else {
990 return builder_->CreateFCmpOEQ(a, b);
991 }
992 }
993
VisitExpr_(const NENode * op)994 llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
995 llvm::Value* a = MakeValue(op->a);
996 llvm::Value* b = MakeValue(op->b);
997 if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
998 return builder_->CreateICmpNE(a, b);
999 } else {
1000 return builder_->CreateFCmpONE(a, b);
1001 }
1002 }
1003
VisitExpr_(const AndNode * op)1004 llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
1005 return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
1006 }
1007
VisitExpr_(const OrNode * op)1008 llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
1009 return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
1010 }
1011
VisitExpr_(const NotNode * op)1012 llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
1013 return builder_->CreateNot(MakeValue(op->a));
1014 }
1015
VisitExpr_(const SelectNode * op)1016 llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
1017 return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
1018 MakeValue(op->false_value));
1019 }
1020
VisitExpr_(const LetNode * op)1021 llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
1022 auto it = let_binding_.find(op->var);
1023 if (it != let_binding_.end()) {
1024 CHECK(deep_equal_(it->second->value, op->value))
1025 << "Let cannot bind the same var to two different values";
1026 } else {
1027 let_binding_[op->var] = op;
1028 }
1029 var_map_[op->var.get()] = MakeValue(op->value);
1030 analyzer_->Bind(op->var, op->value);
1031 return MakeValue(op->body);
1032 }
1033
llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
  DataType t = op->dtype;
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);

  if (t.lanes() == 1) {
    // Scalar load: compute the element pointer and emit one aligned load.
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
    AddAliasInfo(load, op->buffer_var.get(), op->index);
    return load;
  } else {
    // vector load
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
    if (const RampNode* ramp = op->index.as<RampNode>()) {
      if (is_one(ramp->stride)) {
        // Contiguous (stride-1) ramp index: emit a single wide vector load.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
        // Reinterpret the element pointer as a pointer to the full vector type,
        // preserving the buffer's address space.
        ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
        llvm::LoadInst* load =
            builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
        AddAliasInfo(load, op->buffer_var.get(), op->index);
        return load;
      }
    }
  }
  // scalarized load.
  // Fallback for non-contiguous vector indices: load each lane separately
  // (with per-element alignment) and assemble the result via insertelement.
  int basic_align = t.bits() / 8;
  llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile);
#else
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile);
#endif
    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
    // No single PrimExpr describes the per-lane index, so pass an empty one.
    AddAliasInfo(load, op->buffer_var.get(), PrimExpr());
  };
  this->Scalarize(op->index, f);
  return ret;
}
1088
VisitExpr_(const CallNode * op)1089 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
1090 if (auto* ptr_op = op->op.as<OpNode>()) {
1091 auto call_op = GetRef<Op>(ptr_op);
1092 if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
1093 // call extern intrinsic
1094 CHECK_GE(op->args.size(), 1U);
1095 auto global_symbol = Downcast<StringImm>(op->args[0]);
1096 return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), global_symbol->value, op->args,
1097 true);
1098 } else if (op_attr_global_symbol_.count(call_op)) {
1099 // call extern if the op itself have a global symbol.
1100 return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
1101 op->args, false);
1102 } else {
1103 return CreateIntrinsic(op);
1104 }
1105 } else {
1106 CHECK(op->op.as<GlobalVarNode>());
1107 LOG(FATAL) << "Do not yet support cross function call";
1108 return nullptr;
1109 }
1110 }
1111
VisitExpr_(const RampNode * op)1112 llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
1113 llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
1114 for (int i = 0; i < op->lanes; ++i) {
1115 vec = builder_->CreateInsertElement(
1116 vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i));
1117 }
1118 return vec;
1119 }
1120
VisitExpr_(const ShuffleNode * op)1121 llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
1122 std::vector<llvm::Value*> vecs(op->vectors.size());
1123 int total_lanes = 0;
1124 for (int i = 0, e = op->vectors.size(); i < e; ++i) {
1125 vecs[i] = VisitExpr(op->vectors[i]);
1126 total_lanes += op->vectors[i].dtype().lanes();
1127 }
1128 llvm::Value* v0 = CreateVecConcat(vecs);
1129 std::vector<uint32_t> idx(op->indices.size());
1130 for (int i = 0, e = op->indices.size(); i < e; ++i) {
1131 const int64_t* val = as_const_int(op->indices[i]);
1132 CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
1133 << "but get " << op->indices[i] << "\n";
1134 idx[i] = *val;
1135 }
1136 llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
1137 auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
1138 // If the output is a single-element vector, convert it back to a scalar.
1139 if (idx.size() == 1) {
1140 res = builder_->CreateExtractElement(res, ConstInt32(0));
1141 }
1142 return res;
1143 }
1144
VisitExpr_(const BroadcastNode * op)1145 llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
1146 return CreateBroadcast(MakeValue(op->value), op->lanes);
1147 }
1148
void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
  // Only unpredicated stores are supported here.
  CHECK(is_one(op->predicate));
  DataType t = op->value.dtype();
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    // Scalar store: compute the element pointer and emit one aligned store.
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::StoreInst* store =
        builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
    AddAliasInfo(store, op->buffer_var.get(), op->index);
    return;
  } else {
    // vector store
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
    if (const RampNode* ramp = op->index.as<RampNode>()) {
      if (is_one(ramp->stride)) {
        // Contiguous (stride-1) ramp index: emit a single wide vector store.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
        // Reinterpret the element pointer as a pointer to the full vector type,
        // preserving the buffer's address space.
        ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
        llvm::StoreInst* store =
            builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
        AddAliasInfo(store, op->buffer_var.get(), op->index);
        return;
      }
    }
  }
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  // Fallback for non-contiguous vector indices: extract and store each lane
  // separately with per-element alignment.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile);
#else
    llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i),
                                                          ptr, basic_align, is_volatile);
#endif
    // No single PrimExpr describes the per-lane index, so pass an empty one.
    AddAliasInfo(store, op->buffer_var.get(), PrimExpr());
  };
  this->Scalarize(op->index, f);
}
1206
VisitStmt_(const ForNode * op)1207 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
1208 CHECK(is_zero(op->min));
1209 analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
1210 if (op->for_type == ForType::Unrolled) {
1211 LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
1212 << " consider set unroll_explicit=True";
1213 } else {
1214 CHECK(op->for_type == ForType::Serial);
1215 }
1216 CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
1217 llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
1218 }
1219
VisitStmt_(const IfThenElseNode * op)1220 void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
1221 using llvm::BasicBlock;
1222 llvm::Value* cond = MakeValue(op->condition);
1223 BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
1224 BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
1225 if (op->else_case.defined()) {
1226 BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
1227 builder_->CreateCondBr(cond, then_block, else_block);
1228 builder_->SetInsertPoint(then_block);
1229 this->VisitStmt(op->then_case);
1230 builder_->CreateBr(end_block);
1231 builder_->SetInsertPoint(else_block);
1232 this->VisitStmt(op->else_case);
1233 builder_->CreateBr(end_block);
1234 } else {
1235 builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
1236 builder_->SetInsertPoint(then_block);
1237 this->VisitStmt(op->then_case);
1238 builder_->CreateBr(end_block);
1239 }
1240 builder_->SetInsertPoint(end_block);
1241 }
1242
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;

  // Only fixed-size allocations can be lowered to a stack alloca.
  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation";
  StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
  if (constant_size % 4 == 0 && info.alignment == 0) {
    // No alignment requested yet: derive one from the dtype and size.
    info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
  }
  // maximum necessary alignment in the NV devices
  if (info.alignment > 16) {
    info.alignment = 16;
  }
  // Place the alloca in the function entry block so it is executed once.
  llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
    return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
  });
  // Raise the alloca's alignment if it is below the requested one.
  if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
    alloca->setAlignment(llvm::Align(info.alignment));
#else
    alloca->setAlignment(info.alignment);
#endif
  }
  // Record the alignment actually granted by LLVM.
  info.alignment = alloca->getAlignment();
  buf = alloca;

  // Cast to a typed pointer in the alloca's address space before binding.
  buf = builder_->CreatePointerCast(
      buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1276
VisitStmt_(const AttrStmtNode * op)1277 void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
1278 if (op->attr_key == tir::attr::thread_extent) {
1279 IterVar iv = Downcast<IterVar>(op->node);
1280 if (iv->thread_tag.length() != 0) {
1281 if (!var_map_.count(iv->var.get())) {
1282 var_map_[iv->var.get()] = GetThreadIndex(iv);
1283 analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
1284 }
1285 }
1286 } else if (op->attr_key == tir::attr::storage_scope) {
1287 const VarNode* v = op->node.as<VarNode>();
1288 CHECK(v);
1289 alloc_storage_info_[v].scope =
1290 runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
1291 } else if (op->attr_key == tir::attr::storage_alignment) {
1292 const VarNode* v = op->node.as<VarNode>();
1293 CHECK(v);
1294 alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
1295 if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
1296 builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1297 alloc_storage_info_[v].alignment);
1298 }
1299 } else if (op->attr_key == tir::attr::volatile_scope) {
1300 const VarNode* v = op->node.as<VarNode>();
1301 CHECK(v);
1302 volatile_buf_.insert(v);
1303 }
1304 this->VisitStmt(op->body);
1305 }
1306
void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
  // Register the asserted condition as an analyzer constraint for the
  // duration of the body; the RAII guard pops it when cctx goes out of scope.
  // NOTE(review): no runtime check is emitted here — presumably asserts are
  // lowered elsewhere; confirm against the lowering pipeline.
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
1311
VisitStmt_(const LetStmtNode * op)1312 void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
1313 const VarNode* v = op->var.get();
1314 CHECK(!var_map_.count(v));
1315 if (v->dtype.is_handle()) {
1316 if (!is_restricted_) {
1317 alias_var_set_.insert(v);
1318 }
1319 }
1320 var_map_[v] = MakeValue(op->value);
1321 analyzer_->Bind(op->var, op->value);
1322 if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
1323 builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1324 alloc_storage_info_[v].alignment);
1325 }
1326 this->VisitStmt(op->body);
1327 }
1328
VisitStmt_(const SeqStmtNode * op)1329 void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
1330 for (Stmt stmt : op->seq) {
1331 this->VisitStmt(stmt);
1332 }
1333 }
1334
VisitStmt_(const EvaluateNode * op)1335 void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
1336
1337 } // namespace codegen
1338 } // namespace tvm
1339 #endif // TVM_LLVM_VERSION
1340