1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file codegen_llvm.cc
22  */
23 #ifdef TVM_LLVM_VERSION
// Parts of the code are adapted from Halide's CodeGen_LLVM
25 #include "codegen_llvm.h"
26 
27 #include <tvm/runtime/c_runtime_api.h>
28 #include <tvm/runtime/device_api.h>
29 #include <tvm/tir/op.h>
30 
31 #include <algorithm>
32 
33 #include "../../arith/pattern_match.h"
34 #include "../build_common.h"
35 #include "codegen_cpu.h"
36 namespace tvm {
37 namespace codegen {
38 
Create(llvm::TargetMachine * tm)39 std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine* tm) {
40   std::string target = tm->getTarget().getName();
41   std::string factory_name = "tvm.codegen.llvm.target_" + target;
42   const PackedFunc* f = runtime::Registry::Get(factory_name);
43   if (f != nullptr) {
44     void* handle = (*f)();
45     return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
46   } else {
47     return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
48   }
49 }
50 
// Initialize per-module codegen state: the LLVM context, IR builder, module,
// metadata builder, cached llvm::Type handles and TBAA metadata roots.
// NOTE(review): system_lib / dynamic_lookup / target_c_runtime are unused in
// this base implementation — presumably consumed by subclass overrides;
// confirm before removing.
void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
                       llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
                       bool target_c_runtime) {
  InitializeLLVM();
  ctx_ = ctx;
  builder_.reset(new IRBuilder(*ctx_));
  module_.reset(new llvm::Module(module_name, *ctx_));
  md_builder_.reset(new llvm::MDBuilder(*ctx_));
  // Cache frequently used LLVM types.
  t_void_ = llvm::Type::getVoidTy(*ctx_);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
  t_int_ = llvm::Type::getInt32Ty(*ctx_);
  t_char_ = llvm::Type::getInt8Ty(*ctx_);
  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
  // Metadata: strongly-biased branch weights for loop conditions, and the
  // TBAA root/alias-set nodes used by AddAliasInfo below.
  md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  this->InitTarget(tm);
}
75 
// Bind the module to the target machine (triple + data layout) and choose the
// native vector width reported by NativeVectorBits().
void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  target_machine_ = tm;
  // Only set a default if a subclass has not already configured the width.
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      // Conservative default for architectures not listed above.
      native_vector_bits_ = 128;
      std::string arch_name = std::string(tm->getTargetTriple().getArchName());
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }
}
97 
AddFunction(const PrimFunc & f)98 void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
99 
InitFuncState()100 void CodeGenLLVM::InitFuncState() {
101   var_map_.clear();
102   alias_var_set_.clear();
103   alloc_storage_info_.clear();
104   volatile_buf_.clear();
105   analyzer_.reset(new arith::Analyzer());
106 }
107 
AddFunctionInternal(const PrimFunc & f,bool ret_void)108 void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
109   this->InitFuncState();
110 
111   CHECK_EQ(f->buffer_map.size(), 0U)
112       << "Cannot codegen function with buffer_map, please lower them first";
113 
114   std::vector<llvm::Type*> param_types;
115   is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
116   for (Var param : f->params) {
117     param_types.push_back(GetLLVMType(param));
118     if (!is_restricted_ && param.dtype().is_handle()) {
119       alias_var_set_.insert(param.get());
120     }
121   }
122   // TODO(tvm-team):
123   // Update the function type to respect the ret_type field of f.
124   // Once we allow more flexibility in the PrimFunc.
125   llvm::FunctionType* ftype =
126       llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
127 
128   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
129   CHECK(global_symbol.defined())
130       << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
131   CHECK(module_->getFunction(static_cast<std::string>(global_symbol.value())) == nullptr)
132       << "Function " << global_symbol << " already exist in module";
133 
134   function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
135                                      global_symbol.value().operator std::string(), module_.get());
136   function_->setCallingConv(llvm::CallingConv::C);
137   function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
138 
139   // set var map and align information
140   auto arg_it = function_->arg_begin();
141   for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
142     llvm::Argument* v = &(*arg_it);
143     const Var& var = f->params[i];
144     var_map_[var.get()] = v;
145     if (is_restricted_) {
146       if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
147         // set non alias.
148 #if TVM_LLVM_VERSION >= 50
149         function_->addParamAttr(i, llvm::Attribute::NoAlias);
150 #else
151         function_->setDoesNotAlias(i + 1);
152 #endif
153       }
154     }
155   }
156   llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
157   builder_->SetInsertPoint(entry);
158   this->VisitStmt(f->body);
159 
160   // Add alignment attribute if needed.
161 #if TVM_LLVM_VERSION >= 50
162   for (size_t i = 0; i < f->params.size(); ++i) {
163     const Var& var = f->params[i];
164     auto f = alloc_storage_info_.find(var.get());
165     if (f != alloc_storage_info_.end()) {
166       unsigned align = f->second.alignment;
167       if (align > 1) {
168         auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
169         function_->addParamAttr(i, attr);
170       }
171     }
172   }
173 #endif
174 
175   if (ret_void) {
176     builder_->CreateRetVoid();
177   } else {
178     builder_->CreateRet(ConstInt32(0));
179   }
180 }
181 
// Finalize codegen: emit startup hooks, link in all imported modules, run the
// optimization pipeline, and release ownership of the module to the caller.
// After this call module_ is moved-from and this object should not be reused
// without re-initialization.
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
  this->AddStartupFunction();
  for (size_t i = 0; i < link_modules_.size(); ++i) {
    // linkModules returns true on error.
    CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
        << "Failed to link modules";
  }
  link_modules_.clear();
  // optimize
  this->Optimize();
  return std::move(module_);
}
193 
// Import an auxiliary LLVM module to be linked into the output.
// `code` is either a path ending in ".ll"/".bc" (parsed from disk) or a
// string containing LLVM IR text (parsed in memory). All imported functions
// are marked always-inline so they disappear after optimization.
void CodeGenLLVM::HandleImport(const std::string& code) {
  std::unique_ptr<llvm::Module> mlib;
  llvm::SMDiagnostic err;
  if (code.length() >= 3 &&
      (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) {
    // Looks like a file path: parse bitcode / IR from disk.
    mlib = llvm::parseIRFile(code, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = std::string(err.getMessage());
      LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
  } else {
    // Otherwise treat the string itself as LLVM IR text.
    std::unique_ptr<llvm::MemoryBuffer> buf = llvm::MemoryBuffer::getMemBuffer(code);
    mlib = llvm::parseIR(*buf, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = std::string(err.getMessage());
      LOG(FATAL) << "Fail to load llvm ir "
                 << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n"
                 << code;
    }
  }
  // Rebind the imported module to our target so linking succeeds.
  mlib->setTargetTriple(target_machine_->getTargetTriple().str());
  mlib->setDataLayout(target_machine_->createDataLayout());
  // mark all the functions as force inline
  for (llvm::Function& f : mlib->functions()) {
    f.removeFnAttr(llvm::Attribute::NoInline);
    f.addFnAttr(llvm::Attribute::AlwaysInline);
    // AvailableExternally: bodies are usable for inlining but not emitted
    // as strong definitions.
    f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
  }
  // add to linker libraries.
  this->AddLinkModule(std::move(mlib));
}
226 
AddLinkModule(std::unique_ptr<llvm::Module> && mod)227 void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
228   link_modules_.emplace_back(std::move(mod));
229 }
230 
AddMainFunction(const std::string & entry_func_name)231 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
232   LOG(FATAL) << "not implemented";
233 }
234 
GetThreadIndex(const IterVar & iv)235 llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
236   LOG(FATAL) << "not implemented";
237   return nullptr;
238 }
239 
CreateStorageSync(const CallNode * op)240 llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) {
241   LOG(FATAL) << "not implemented";
242   return nullptr;
243 }
244 
// Thin wrapper over the legacy FunctionPassManager whose `add` is sealed
// (final) so pass registration goes through a single interception point.
class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); }
};
251 
// Module-level counterpart of FPassManager; see above.
class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); }
};
257 
InitPassManagerBuilder(llvm::PassManagerBuilder * builder)258 void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {}
259 
// Run the legacy -O3 pipeline (with loop/SLP vectorization) over module_.
void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  // Register target-specific cost models; fall back to defaults when no
  // target machine was configured.
  mpass.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;
  builder.OptLevel = 3;

#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  // Let subclasses adjust the pipeline before population.
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  target_machine_->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  // Function passes first, then module passes over the whole module.
  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}
296 
NativeVectorBits(const runtime::StorageScope & storage_scope) const297 int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
298   return native_vector_bits_;
299 }
300 
GetGlobalAddressSpace() const301 unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; }
302 
DTypeToLLVMType(const DataType & dtype) const303 llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
304   if (dtype.is_handle()) {
305     CHECK_EQ(dtype.lanes(), 1);
306     return t_void_p_;
307   }
308   if (dtype.is_void()) {
309     return t_void_;
310   }
311   llvm::Type* etype = nullptr;
312   if (dtype.is_int() || dtype.is_uint()) {
313     etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
314   } else if (dtype.is_float()) {
315     switch (dtype.bits()) {
316       case 16:
317         etype = llvm::Type::getHalfTy(*ctx_);
318         break;
319       case 32:
320         etype = llvm::Type::getFloatTy(*ctx_);
321         break;
322       case 64:
323         etype = llvm::Type::getDoubleTy(*ctx_);
324         break;
325       default:
326         LOG(FATAL) << "do not support " << dtype;
327     }
328   }
329   if (dtype.lanes() != 1) {
330 #if TVM_LLVM_VERSION >= 110
331     return llvm::FixedVectorType::get(etype, dtype.lanes());
332 #else
333     return llvm::VectorType::get(etype, dtype.lanes());
334 #endif
335   } else {
336     return etype;
337   }
338 }
339 
// Map a TIR Type to an llvm::Type: prim types via DTypeToLLVMType, pointer
// types to pointers into the global address space, void to void.
llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
  if (auto* ptr = type.as<PrimTypeNode>()) {
    return DTypeToLLVMType(ptr->dtype);
  } else if (auto* ptr = type.as<PointerTypeNode>()) {
    // TODO(tvm-team) consider put storage scope into the pointer type.
    return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace());
  } else if (IsVoidType(type)) {
    return t_void_;
  } else {
    LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type";
    return t_void_;  // unreachable; keeps the compiler happy
  }
}
353 
GetLLVMType(const PrimExpr & expr) const354 llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
355   return GetLLVMType(GetType(expr));
356 }
357 
358 // Add tbaa alias information for load
359 //
360 // use a binary tree typed system to declare information
361 // and allow alias to be distinguished across nodes.
362 //
363 // This trick comes from Halide's CodeGen_LLVM
364 //
// Attach TBAA metadata to a load/store instruction for `buffer` at `index`.
// Buffers that may alias all share one TBAA node; otherwise a hierarchy of
// power-of-two [base, base+width) ranges lets accesses to disjoint regions of
// the same buffer be proven non-aliasing. (Scheme adapted from Halide.)
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index) {
  if (alias_var_set_.count(buffer) != 0) {
    // Mark all possibly aliased pointer as same type.
    llvm::MDNode* meta = md_tbaa_alias_set_;
    inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
    return;
  }

  int64_t base = 0, width = 0;
  arith::PVar<IntImm> pbase, pstride;
  arith::PVar<int> planes;
  // create meta-data for alias analysis
  // Use a group of binary tree ranges of memory banks.
  if (index.defined()) {
    if (arith::ramp(pbase, pstride, planes).Match(index)) {
      base = pbase.Eval()->value;
      // Extent of the ramp access in elements.
      int64_t xwith = planes.Eval() * pstride.Eval()->value;
      // Round width up to a power of two covering the extent.
      width = 1;
      while (width < xwith) {
        width *= 2;
      }
      // Grow width until base is width-aligned.
      while (base % width) {
        base -= base % width;
        width *= 2;
      }
    } else if (auto* ptr = index.as<tir::IntImmNode>()) {
      // Scalar constant index: a single-element access.
      width = 1;
      base = ptr->value;
    }
  }
  // Root -> per-buffer node (keyed by the VarNode's address).
  llvm::MDNode* meta = md_tbaa_root_;
  std::ostringstream buffer_addr;
  buffer_addr << buffer;
  meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);

  // Extract the underlying type of the allocated buffer.
  llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType();
  if (buf_type->isPointerTy()) {
    buf_type = buf_type->getPointerElementType();
  }

  // Refine by the element type of the buffer.
  std::string tmp;
  llvm::raw_string_ostream buffer_type(tmp);
  buffer_type << *buf_type;
  meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);

  // create a tree-shape access structure.
  // Descend from width 1024 down to the access width, one node per level,
  // so accesses in disjoint subtrees are known not to alias.
  if (width != 0) {
    for (int64_t w = 1024; w >= width; w /= 2) {
      int64_t b = (base / w) * w;
      std::stringstream os;
      os << buffer << ".w" << w << ".b" << b;
      meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
    }
  }
  inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
422 
// Compute the provable alignment (bytes) of an access to buf_var[index] of
// type `t`, plus the native vector width for the buffer's storage scope.
// Uses modular-set analysis: if index ≡ base (mod coeff), the access is
// aligned to the largest power of two dividing both, capped by the buffer's
// declared alignment.
void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
                               int* p_alignment, int* p_native_bits) {
  int max_align_bits = t.bits();
  auto it = alloc_storage_info_.find(buf_var);
  if (it != alloc_storage_info_.end()) {
    // Known allocation: use its storage scope and declared alignment cap.
    const StorageInfo& info = it->second;
    *p_native_bits = NativeVectorBits(info.scope);
    max_align_bits = info.alignment * 8;
  } else {
    *p_native_bits = native_vector_bits_;
  }

  arith::ModularSet me = analyzer_->modular_set(index);
  int64_t base = me->base;
  int64_t coeff = me->coeff;

  // Double the alignment while both base and coeff stay even.
  int align_bits = t.bits();
  while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) {
    base = base / 2;
    coeff = coeff / 2;
    align_bits *= 2;
  }
  // Never report sub-byte alignment.
  if (align_bits < 8) {
    align_bits = 8;
  }
  *p_alignment = align_bits / 8;
}
450 
// Create a DebugInfo bundle (DIBuilder + file + compile unit) for `module`,
// using a fixed placeholder file name since TIR carries no source locations.
std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
  auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
#else
  // llvm::make_unique was removed in LLVM 10 in favor of std::make_unique.
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
  // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
  debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
      llvm::DICompileUnit::DebugEmissionKind::FullDebug,
      /* SplitDebugInlining */ true,
      /* DebugInfoForProfiling */ true);
  return debug_info;
}
468 
// Broadcast a scalar `value` into a vector of `lanes` identical elements:
// insert the scalar into lane 0 of an undef vector, then shuffle with an
// all-zero mask so every lane reads lane 0.
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
#if TVM_LLVM_VERSION >= 110
  llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
#else
  llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
#endif
  llvm::Constant* undef = llvm::UndefValue::get(type);
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
#if TVM_LLVM_VERSION >= 110
  // LLVM 11+ uses ElementCount for (possibly scalable) vector splat sizes.
  llvm::Constant* mask =
      llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
#else
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
#endif
  return builder_->CreateShuffleVector(value, undef, mask);
}
486 
// Extract lanes [begin, begin+extent) of `vec` via a shuffle. Lanes that
// would fall outside the source vector are filled with undef.
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
  int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
  // Identity slice: nothing to do.
  if (extent == num_elems && begin == 0) return vec;
  CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
  std::vector<llvm::Constant*> indices;
  indices.reserve(extent);
  for (int i = 0; i < extent; ++i) {
    if (begin + i >= 0 && begin + i < num_elems) {
      indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
    } else {
      // Out-of-range lane: leave undefined rather than reading garbage.
      indices.push_back(llvm::UndefValue::get(t_int32_));
    }
  }
  return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
}
502 
CreateVecFlip(llvm::Value * vec)503 llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
504   int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
505 #if TVM_LLVM_VERSION >= 110
506   std::vector<int> indices;
507 #else
508   std::vector<unsigned> indices;
509 #endif
510   for (int i = 0; i < num_elems; ++i) {
511     indices.push_back(num_elems - i - 1);
512   }
513   return builder_->CreateShuffleVector(vec, vec, indices);
514 }
515 
// Widen `vec` to `target_lanes` lanes: the first num_elems lanes keep their
// values, the remaining mask entries stay undef so the padded lanes are
// undefined.
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
  llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
  int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
  if (num_elems == target_lanes) return vec;
  CHECK_LT(num_elems, target_lanes);
  // Fill mask[i] = i for the lanes that carry data.
  for (int i = 0; i < num_elems; ++i) {
    mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
  }
  return builder_->CreateShuffleVector(vec, vec, mask);
}
526 
// Concatenate several vectors into one, pairwise (tree-shaped) so shuffle
// widths stay balanced. Odd-length pairs are padded to equal width first;
// the final slice trims any padding back down to the true total lane count.
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // concat vector, tree shape reduction
  int total_lanes = 0;

  for (llvm::Value* v : vecs) {
    total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
  }
  while (vecs.size() > 1) {
    std::vector<llvm::Value*> new_vecs;
    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
      const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
      // Pad the narrower operand so both shuffle inputs have equal width.
      if (lhs_lanes < rhs_lanes) {
        lhs = CreateVecPad(lhs, rhs_lanes);
      } else if (rhs_lanes < lhs_lanes) {
        rhs = CreateVecPad(rhs, lhs_lanes);
      }
      const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
#if TVM_LLVM_VERSION >= 110
      std::vector<int> mask;
#else
      std::vector<unsigned> mask;
#endif
      // Select lhs's real lanes, then rhs's real lanes (which start at
      // shared_lanes in the concatenated shuffle input).
      for (size_t i = 0; i < lhs_lanes; ++i) {
        mask.push_back(i);
      }
      for (size_t i = 0; i < rhs_lanes; ++i) {
        mask.push_back(shared_lanes + i);
      }
      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
    }
    // Odd count: carry the last vector to the next round unchanged.
    if (vecs.size() % 2 != 0) {
      new_vecs.push_back(vecs.back());
    }
    vecs.swap(new_vecs);
  }
  return CreateVecSlice(vecs[0], 0, total_lanes);
}
567 
// Emit a serial loop: for (i = begin; i < end; i += stride) { body }.
// Layout: pre_block -> for_begin (PHI + condition) -> for_body -> back edge
// to for_begin; exits to for_end.
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
                                  const Var& loop_var, const Stmt& body) {
  using llvm::BasicBlock;
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_);
  BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  // Two incoming edges: the preheader (added here) and the back edge below.
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->addIncoming(begin, pre_block);
  CHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  // The loop-continue branch is marked very likely.
  builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
                         md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
  // Back edge comes from the *current* block: VisitStmt may have moved the
  // insert point past for_body.
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}
591 
// cast operator
// Emit the LLVM cast from dtype `from` to dtype `to` for `value`, choosing
// the appropriate instruction (bitcast, cmp, trunc/ext, fp<->int, fp resize).
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
  llvm::Type* target = DTypeToLLVMType(to);
  // Same representation: no instruction needed.
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_uint() && to.bits() == 1) {
    // Cast to bool: compare against zero.
    if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
      return builder_->CreateFCmpONE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
    }
  } else if (!from.is_float() && !to.is_float()) {
    // int <-> int: sign-extend when the source is signed.
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    if (to.bits() < 8) {
      // Convert via u8 first; FPToUI to sub-byte integers is not supported.
      value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    // Remaining case: float -> float of a different width.
    CHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}
626 
// Return a pointer to a NUL-terminated global string constant for `str`,
// creating and caching the global on first use (deduplicated via str_map_).
llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
  auto it = str_map_.find(str);
  if (it != str_map_.end()) return it->second;
  // +1 for the trailing NUL added by ConstantDataArray::getString.
  llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
  llvm::GlobalVariable* global = new llvm::GlobalVariable(
      *module_, type, true, llvm::GlobalValue::PrivateLinkage, nullptr, ".str");
#if TVM_LLVM_VERSION >= 100
  global->setAlignment(llvm::Align(1));
#else
  global->setAlignment(1);
#endif
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
  // GEP [0, 0]: decay the [N x i8] array to an i8* at its first element.
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* indices[] = {zero, zero};
  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices);
  str_map_[str] = ptr;
  return ptr;
}
645 
// Compute &buffer[index] as a pointer to element type `t`, casting the buffer
// pointer first if its pointee type differs (address space is preserved).
llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) {
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }
  return builder_->CreateInBoundsGEP(buffer, index);
}
655 
GetVarValue(const VarNode * v) const656 llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
657   auto it = var_map_.find(v);
658   CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
659   return it->second;
660 }
661 
// Emit a call to an external function `global_symbol`, declaring it in the
// module on first use. When skip_first_arg is true args[0] (the symbol name
// argument of the TIR call) is not passed through.
llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol,
                                           const Array<PrimExpr>& args, bool skip_first_arg) {
  std::vector<llvm::Value*> arg_value;
  std::vector<llvm::Type*> arg_type;
  for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
    arg_value.push_back(MakeValue(args[i]));
    arg_type.push_back(arg_value.back()->getType());
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
  llvm::Function* f = module_->getFunction(global_symbol);
  if (f == nullptr) {
    // Not yet declared: add an external declaration to the module.
    f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
                               global_symbol.operator llvm::StringRef(), module_.get());
  }
  llvm::CallInst* call = builder_->CreateCall(f, arg_value);
  return call;
}
679 
// Get (or declare) the llvm::Function for intrinsic `id` with the given
// return/argument types. For overloaded intrinsics the concrete overload
// types must be recovered by matching the desired signature against the
// intrinsic's type descriptors; returns nullptr when no overload matches.
llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
                                              llvm::ArrayRef<llvm::Type*> arg_types) {
  llvm::Module* module = module_.get();

  // Non-overloaded intrinsics need no overload-type list.
  if (!llvm::Intrinsic::isOverloaded(id)) {
    return llvm::Intrinsic::getDeclaration(module, id, {});
  }

  llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
  llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
  llvm::SmallVector<llvm::Type*, 4> overload_types;

#if TVM_LLVM_VERSION >= 90
  // Try to match a candidate FunctionType against the intrinsic signature,
  // collecting the deduced overload types on success.
  auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
    overload_types.clear();
    llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
    auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
    if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      // matchIntrinsicVarArg returns true on a vararg mismatch.
      bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
      if (error) {
        return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
      }
    }
    return match;
  };

  // First, try matching the signature assuming non-vararg case.
  auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
  switch (try_match(fn_ty, false)) {
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
      // The return type doesn't match, there is nothing else to do.
      return nullptr;
    case llvm::Intrinsic::MatchIntrinsicTypes_Match:
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
      break;
  }

  // Keep adding one type at a time (starting from empty list), and
  // try matching the vararg signature.
  llvm::SmallVector<llvm::Type*, 4> var_types;
  for (int i = 0, e = arg_types.size(); i <= e; ++i) {
    if (i > 0) var_types.push_back(arg_types[i - 1]);
    auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
    if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
    }
  }
  // Failed to identify the type.
  return nullptr;

#else   // TVM_LLVM_VERSION
  // Pre-LLVM-9 API: matchIntrinsicType consumes descriptors one type at a
  // time, return type first, and returns true on error.
  llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
  // matchIntrinsicType returns true on error.
  if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
    return nullptr;
  }
  for (llvm::Type* t : arg_types) {
    if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
      return nullptr;
    }
  }
  return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif  // TVM_LLVM_VERSION
}
745 
// Lower a builtin/intrinsic CallNode to LLVM IR.  Dispatches on the op:
// direct LLVM intrinsic calls, bitwise/shift ops, pointer helpers,
// short-circuit conditionals, and vector lane manipulation.
// Unknown intrinsics are a fatal error.
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
  if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
    // Direct LLVM intrinsic call: args[0] is the intrinsic id, args[1] the
    // number of leading operands whose types participate in overload
    // resolution, and the remaining args are the actual call operands.
    CHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
    int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      // Only the first num_signature operand types are used to resolve the
      // intrinsic's overloaded signature.
      if (i - 2 < static_cast<size_t>(num_signature)) {
        arg_type.push_back(arg_value.back()->getType());
      }
    }
    // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
    // returns int32. This causes problems because prefetch is one of
    // those intrinsics that is generated automatically via the
    // tvm.intrin.rule mechanism. Any other intrinsic with a type
    // mismatch will have to be treated specially here.
    // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
    // type as LLVM.
    llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op))
                                                                : llvm::Type::getVoidTy(*ctx_);
    llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
    CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
             << llvm::Intrinsic::getName(id, {});
    return builder_->CreateCall(f, arg_value);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_or())) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::bitwise_not())) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_left())) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->op.same_as(builtin::shift_right())) {
    // Arithmetic (sign-extending) shift for signed operands, logical
    // (zero-filling) shift otherwise.
    if (op->args[0].dtype().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return CreateStorageSync(op);
  } else if (op->op.same_as(builtin::address_of())) {
    // address_of(load): produce a char* pointing at the addressed element.
    const LoadNode* l = op->args[0].as<LoadNode>();
    CHECK(op->args.size() == 1 && l);
    const RampNode* r = l->index.as<RampNode>();
    llvm::Value* ptr;
    unsigned addrspace;
    if (!r) {
      ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
      addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
    } else {
      // Ramp index: the pointer is computed at vector granularity.
      // NOTE(review): this assumes r->base is divisible by r->lanes so the
      // division lands on a vector boundary — confirm with callers.
      PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
      ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index));
      addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
    }
    return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
  } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
    // reinterpret(0) produces a null handle.
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->op.same_as(builtin::isnullptr())) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    // A 64-bit constant passed as two 32-bit halves: args = (low, high).
    CHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
  } else if (op->op.same_as(builtin::if_then_else())) {
    // Short-circuiting conditional: each arm is evaluated only when taken;
    // the results are merged through a phi node in the end block.
    CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition";
    using llvm::BasicBlock;
    BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
    BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
    BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    // MakeValue may have emitted additional blocks; record the block that
    // actually terminates the arm as the phi's incoming edge.
    BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->op.same_as(builtin::reinterpret())) {
    // Bit-preserving reinterpretation between same-width types.
    llvm::Type* target = DTypeToLLVMType(op->dtype);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->op.same_as(builtin::isnan())) {
    // TODO(hgt312): set fast math flag
    // NaN is the only value unordered with itself, so x != x (unordered
    // compare) detects it.
    llvm::Value* a = MakeValue(op->args[0]);
    return builder_->CreateFCmpUNO(a, a);
  } else if (op->op.same_as(builtin::vectorlow())) {
    // Lower half of the lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
    return CreateVecSlice(v, 0, l / 2);
  } else if (op->op.same_as(builtin::vectorhigh())) {
    // Upper half of the lanes.
    llvm::Value* v = MakeValue(op->args[0]);
    int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
    return CreateVecSlice(v, l / 2, l / 2);
  } else if (op->op.same_as(builtin::vectorcombine())) {
    // Concatenate two vectors with an identity shuffle mask.
    llvm::Value* v0 = MakeValue(op->args[0]);
    llvm::Value* v1 = MakeValue(op->args[1]);
    int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
#if TVM_LLVM_VERSION >= 110
    std::vector<int> indices;
#else
    std::vector<unsigned> indices;
#endif
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->op;
    return nullptr;
  }
}
868 
Scalarize(const PrimExpr & e,std::function<void (int i,llvm::Value * v)> f)869 void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f) {
870   if (const RampNode* ramp = e.as<RampNode>()) {
871     for (int i = 0; i < ramp->dtype.lanes(); ++i) {
872       PrimExpr offset = ramp->base + (ramp->stride * i);
873       f(i, MakeValue(offset));
874     }
875   } else {
876     llvm::Value* value = MakeValue(e);
877     for (int i = 0; i < e.dtype().lanes(); ++i) {
878       f(i, builder_->CreateExtractElement(value, i));
879     }
880   }
881 }
882 
883 // Visitors
VisitExpr_(const VarNode * op)884 llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }
885 
VisitExpr_(const CastNode * op)886 llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
887   return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
888 }
VisitExpr_(const IntImmNode * op)889 llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
890   return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
891 }
892 
VisitExpr_(const FloatImmNode * op)893 llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
894   return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
895 }
896 
VisitExpr_(const StringImmNode * op)897 llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }
898 
// Defines both the Create{Add,Sub,Mul} helper and the matching expression
// visitor.  For int/uint of 32 bits or wider, the NSW/NUW ("no signed /
// unsigned wrap") variants are emitted, allowing LLVM to assume the
// operation does not overflow; narrower integers use the plain wrapping
// op, and floats use the floating-point instruction.
#define DEFINE_CODEGEN_BINARY_OP(Op)                                                 \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNSW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else if (t.is_uint()) {                                                        \
      if (t.bits() >= 32) {                                                          \
        return builder_->CreateNUW##Op(a, b);                                        \
      } else {                                                                       \
        return builder_->Create##Op(a, b);                                           \
      }                                                                              \
    } else {                                                                         \
      CHECK(t.is_float());                                                           \
      return builder_->CreateF##Op(a, b);                                            \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b));                \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
925 
// Defines both the Create{LT,LE,GT,GE} helper and the matching expression
// visitor.  Integer comparisons select the signed or unsigned predicate
// from the dtype; floats use the ordered predicate (false when either
// operand is NaN).  Note the visitor dispatches on op->a.dtype(): the
// node's own dtype is boolean, so the operand type decides the kind of
// comparison.
#define DEFINE_CODEGEN_CMP_OP(Op)                                                    \
  llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
    if (t.is_int()) {                                                                \
      return builder_->CreateICmpS##Op(a, b);                                        \
    } else if (t.is_uint()) {                                                        \
      return builder_->CreateICmpU##Op(a, b);                                        \
    } else {                                                                         \
      CHECK(t.is_float());                                                           \
      return builder_->CreateFCmpO##Op(a, b);                                        \
    }                                                                                \
  }                                                                                  \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) {                         \
    return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b));            \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
945 
VisitExpr_(const DivNode * op)946 llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
947   llvm::Value* a = MakeValue(op->a);
948   llvm::Value* b = MakeValue(op->b);
949   if (op->dtype.is_int()) {
950     return builder_->CreateSDiv(a, b);
951   } else if (op->dtype.is_uint()) {
952     return builder_->CreateUDiv(a, b);
953   } else {
954     CHECK(op->dtype.is_float());
955     return builder_->CreateFDiv(a, b);
956   }
957 }
958 
VisitExpr_(const ModNode * op)959 llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
960   llvm::Value* a = MakeValue(op->a);
961   llvm::Value* b = MakeValue(op->b);
962   if (op->dtype.is_int()) {
963     return builder_->CreateSRem(a, b);
964   } else if (op->dtype.is_uint()) {
965     return builder_->CreateURem(a, b);
966   } else {
967     CHECK(op->dtype.is_float());
968     return builder_->CreateFRem(a, b);
969   }
970 }
971 
VisitExpr_(const MinNode * op)972 llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
973   llvm::Value* a = MakeValue(op->a);
974   llvm::Value* b = MakeValue(op->b);
975   return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
976 }
977 
VisitExpr_(const MaxNode * op)978 llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
979   llvm::Value* a = MakeValue(op->a);
980   llvm::Value* b = MakeValue(op->b);
981   return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
982 }
983 
VisitExpr_(const EQNode * op)984 llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
985   llvm::Value* a = MakeValue(op->a);
986   llvm::Value* b = MakeValue(op->b);
987   if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
988     return builder_->CreateICmpEQ(a, b);
989   } else {
990     return builder_->CreateFCmpOEQ(a, b);
991   }
992 }
993 
VisitExpr_(const NENode * op)994 llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
995   llvm::Value* a = MakeValue(op->a);
996   llvm::Value* b = MakeValue(op->b);
997   if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
998     return builder_->CreateICmpNE(a, b);
999   } else {
1000     return builder_->CreateFCmpONE(a, b);
1001   }
1002 }
1003 
VisitExpr_(const AndNode * op)1004 llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
1005   return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
1006 }
1007 
VisitExpr_(const OrNode * op)1008 llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
1009   return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
1010 }
1011 
VisitExpr_(const NotNode * op)1012 llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
1013   return builder_->CreateNot(MakeValue(op->a));
1014 }
1015 
VisitExpr_(const SelectNode * op)1016 llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
1017   return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
1018                                 MakeValue(op->false_value));
1019 }
1020 
VisitExpr_(const LetNode * op)1021 llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
1022   auto it = let_binding_.find(op->var);
1023   if (it != let_binding_.end()) {
1024     CHECK(deep_equal_(it->second->value, op->value))
1025         << "Let cannot bind the same var to two different values";
1026   } else {
1027     let_binding_[op->var] = op;
1028   }
1029   var_map_[op->var.get()] = MakeValue(op->value);
1030   analyzer_->Bind(op->var, op->value);
1031   return MakeValue(op->body);
1032 }
1033 
// Lower a Load.  Three strategies: a plain aligned load for scalars, one
// wide vector load when the index is a contiguous (stride-1) ramp, and a
// per-lane scalarized load otherwise.
llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
  DataType t = op->dtype;
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);

  if (t.lanes() == 1) {
    // Scalar load: a single aligned load from buffer[index].
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
    AddAliasInfo(load, op->buffer_var.get(), op->index);
    return load;
  } else {
    // vector load
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
    if (const RampNode* ramp = op->index.as<RampNode>()) {
      if (is_one(ramp->stride)) {
        // Contiguous (stride-1) ramp index: emit one wide load by casting
        // the element pointer to a pointer-to-vector in the same address
        // space.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
        llvm::LoadInst* load =
            builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
        AddAliasInfo(load, op->buffer_var.get(), op->index);
        return load;
      }
    }
  }
  // scalarized load.
  // Fallback for strided/gathered vector indices: load each lane
  // separately and assemble the result with insertelement.
  int basic_align = t.bits() / 8;
  llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile);
#else
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile);
#endif
    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
    // No per-lane index expression is available here, so pass an undefined
    // PrimExpr to the alias-info helper.
    AddAliasInfo(load, op->buffer_var.get(), PrimExpr());
  };
  this->Scalarize(op->index, f);
  return ret;
}
1088 
VisitExpr_(const CallNode * op)1089 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
1090   if (auto* ptr_op = op->op.as<OpNode>()) {
1091     auto call_op = GetRef<Op>(ptr_op);
1092     if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
1093       // call extern intrinsic
1094       CHECK_GE(op->args.size(), 1U);
1095       auto global_symbol = Downcast<StringImm>(op->args[0]);
1096       return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), global_symbol->value, op->args,
1097                                     true);
1098     } else if (op_attr_global_symbol_.count(call_op)) {
1099       // call extern if the op itself have a global symbol.
1100       return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
1101                                     op->args, false);
1102     } else {
1103       return CreateIntrinsic(op);
1104     }
1105   } else {
1106     CHECK(op->op.as<GlobalVarNode>());
1107     LOG(FATAL) << "Do not yet support cross function call";
1108     return nullptr;
1109   }
1110 }
1111 
VisitExpr_(const RampNode * op)1112 llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
1113   llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
1114   for (int i = 0; i < op->lanes; ++i) {
1115     vec = builder_->CreateInsertElement(
1116         vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i));
1117   }
1118   return vec;
1119 }
1120 
VisitExpr_(const ShuffleNode * op)1121 llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
1122   std::vector<llvm::Value*> vecs(op->vectors.size());
1123   int total_lanes = 0;
1124   for (int i = 0, e = op->vectors.size(); i < e; ++i) {
1125     vecs[i] = VisitExpr(op->vectors[i]);
1126     total_lanes += op->vectors[i].dtype().lanes();
1127   }
1128   llvm::Value* v0 = CreateVecConcat(vecs);
1129   std::vector<uint32_t> idx(op->indices.size());
1130   for (int i = 0, e = op->indices.size(); i < e; ++i) {
1131     const int64_t* val = as_const_int(op->indices[i]);
1132     CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
1133                                                   << "but get " << op->indices[i] << "\n";
1134     idx[i] = *val;
1135   }
1136   llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
1137   auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
1138   // If the output is a single-element vector, convert it back to a scalar.
1139   if (idx.size() == 1) {
1140     res = builder_->CreateExtractElement(res, ConstInt32(0));
1141   }
1142   return res;
1143 }
1144 
VisitExpr_(const BroadcastNode * op)1145 llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
1146   return CreateBroadcast(MakeValue(op->value), op->lanes);
1147 }
1148 
// Lower a Store.  Mirrors the Load lowering: a plain aligned store for
// scalars, one wide vector store for a contiguous (stride-1) ramp index,
// and a per-lane scalarized store otherwise.  Predicated stores are not
// supported.
void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
  CHECK(is_one(op->predicate));
  DataType t = op->value.dtype();
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    // Scalar store: a single aligned store to buffer[index].
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::StoreInst* store =
        builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
    AddAliasInfo(store, op->buffer_var.get(), op->index);
    return;
  } else {
    // vector store
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
    if (const RampNode* ramp = op->index.as<RampNode>()) {
      if (is_one(ramp->stride)) {
        // Contiguous (stride-1) ramp index: emit one wide store through a
        // pointer-to-vector in the same address space.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
        llvm::StoreInst* store =
            builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
        AddAliasInfo(store, op->buffer_var.get(), op->index);
        return;
      }
    }
  }
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  // Fallback for strided/scattered vector indices: extract and store each
  // lane individually.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile);
#else
    llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i),
                                                          ptr, basic_align, is_volatile);
#endif
    // No per-lane index expression is available here, so pass an undefined
    // PrimExpr to the alias-info helper.
    AddAliasInfo(store, op->buffer_var.get(), PrimExpr());
  };
  this->Scalarize(op->index, f);
}
1206 
VisitStmt_(const ForNode * op)1207 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
1208   CHECK(is_zero(op->min));
1209   analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
1210   if (op->for_type == ForType::Unrolled) {
1211     LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
1212                  << " consider set unroll_explicit=True";
1213   } else {
1214     CHECK(op->for_type == ForType::Serial);
1215   }
1216   CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
1217                   llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
1218 }
1219 
VisitStmt_(const IfThenElseNode * op)1220 void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
1221   using llvm::BasicBlock;
1222   llvm::Value* cond = MakeValue(op->condition);
1223   BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
1224   BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_);
1225   if (op->else_case.defined()) {
1226     BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_);
1227     builder_->CreateCondBr(cond, then_block, else_block);
1228     builder_->SetInsertPoint(then_block);
1229     this->VisitStmt(op->then_case);
1230     builder_->CreateBr(end_block);
1231     builder_->SetInsertPoint(else_block);
1232     this->VisitStmt(op->else_case);
1233     builder_->CreateBr(end_block);
1234   } else {
1235     builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
1236     builder_->SetInsertPoint(then_block);
1237     this->VisitStmt(op->then_case);
1238     builder_->CreateBr(end_block);
1239   }
1240   builder_->SetInsertPoint(end_block);
1241 }
1242 
// Lower an Allocate as a fixed-size stack allocation (alloca) placed in
// the function entry block.  Only constant-size, unconditional allocations
// are supported here.
void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;

  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation";
  StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
  // If no explicit alignment was requested, derive one from the dtype and
  // element count (only when the count is a multiple of 4).
  if (constant_size % 4 == 0 && info.alignment == 0) {
    info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
  }
  // maximum necessary alignment in the NV devices
  if (info.alignment > 16) {
    info.alignment = 16;
  }
  // Allocas must be created in the entry block so LLVM treats them as
  // static stack slots.
  llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
    return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
  });
  // Raise the alloca's alignment if the requested alignment is stricter.
  if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
    alloca->setAlignment(llvm::Align(info.alignment));
#else
    alloca->setAlignment(info.alignment);
#endif
  }
  // Record the alignment that was actually granted.
  info.alignment = alloca->getAlignment();
  buf = alloca;

  buf = builder_->CreatePointerCast(
      buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1276 
VisitStmt_(const AttrStmtNode * op)1277 void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
1278   if (op->attr_key == tir::attr::thread_extent) {
1279     IterVar iv = Downcast<IterVar>(op->node);
1280     if (iv->thread_tag.length() != 0) {
1281       if (!var_map_.count(iv->var.get())) {
1282         var_map_[iv->var.get()] = GetThreadIndex(iv);
1283         analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
1284       }
1285     }
1286   } else if (op->attr_key == tir::attr::storage_scope) {
1287     const VarNode* v = op->node.as<VarNode>();
1288     CHECK(v);
1289     alloc_storage_info_[v].scope =
1290         runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
1291   } else if (op->attr_key == tir::attr::storage_alignment) {
1292     const VarNode* v = op->node.as<VarNode>();
1293     CHECK(v);
1294     alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value);
1295     if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) {
1296       builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1297                                           alloc_storage_info_[v].alignment);
1298     }
1299   } else if (op->attr_key == tir::attr::volatile_scope) {
1300     const VarNode* v = op->node.as<VarNode>();
1301     CHECK(v);
1302     volatile_buf_.insert(v);
1303   }
1304   this->VisitStmt(op->body);
1305 }
1306 
VisitStmt_(const AssertStmtNode * op)1307 void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
1308   With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
1309   this->VisitStmt(op->body);
1310 }
1311 
VisitStmt_(const LetStmtNode * op)1312 void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
1313   const VarNode* v = op->var.get();
1314   CHECK(!var_map_.count(v));
1315   if (v->dtype.is_handle()) {
1316     if (!is_restricted_) {
1317       alias_var_set_.insert(v);
1318     }
1319   }
1320   var_map_[v] = MakeValue(op->value);
1321   analyzer_->Bind(op->var, op->value);
1322   if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
1323     builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
1324                                         alloc_storage_info_[v].alignment);
1325   }
1326   this->VisitStmt(op->body);
1327 }
1328 
VisitStmt_(const SeqStmtNode * op)1329 void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
1330   for (Stmt stmt : op->seq) {
1331     this->VisitStmt(stmt);
1332   }
1333 }
1334 
VisitStmt_(const EvaluateNode * op)1335 void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
1336 
1337 }  // namespace codegen
1338 }  // namespace tvm
1339 #endif  // TVM_LLVM_VERSION
1340