1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file codegen_llvm.cc
22  */
23 #ifdef TVM_LLVM_VERSION
24 // Part of the code are adapted from Halide's CodeGen_LLVM
25 #include <tvm/runtime/device_api.h>
26 #include <tvm/runtime/c_runtime_api.h>
27 
28 #include <algorithm>
29 
30 #include "codegen_llvm.h"
31 #include "codegen_cpu.h"
32 #include "../build_common.h"
33 #include "../../pass/ir_util.h"
34 #include "../../arithmetic/compute_expr.h"
35 
36 namespace tvm {
37 namespace codegen {
38 
Create(llvm::TargetMachine * tm)39 std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
40   std::string target = tm->getTarget().getName();
41   std::string factory_name = "tvm.codegen.llvm.target_" + target;
42   const PackedFunc* f = runtime::Registry::Get(factory_name);
43   if (f != nullptr) {
44     void* handle = (*f)();
45     return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
46   } else {
47     return std::unique_ptr<CodeGenLLVM>(new CodeGenCPU());
48   }
49 }
50 
// Initialize codegen state: create a fresh llvm::Module plus the IR and
// metadata builders, cache commonly used LLVM types, and bind the target.
// NOTE(review): system_lib and dynamic_lookup are unused in this base
// implementation — presumably consumed by subclass overrides; confirm.
void CodeGenLLVM::Init(const std::string& module_name,
                       llvm::TargetMachine* tm,
                       llvm::LLVMContext* ctx,
                       bool system_lib,
                       bool dynamic_lookup) {
  InitializeLLVM();
  ctx_ = ctx;
  builder_.reset(new IRBuilder(*ctx_));
  module_.reset(new llvm::Module(module_name, *ctx_));
  md_builder_.reset(new llvm::MDBuilder(*ctx_));
  // types: cache frequently used LLVM type handles
  t_void_ = llvm::Type::getVoidTy(*ctx_);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
  t_int_ = llvm::Type::getInt32Ty(*ctx_);
  t_char_ = llvm::Type::getInt8Ty(*ctx_);
  t_int8_ = llvm::Type::getInt8Ty(*ctx_);
  t_int16_ = llvm::Type::getInt16Ty(*ctx_);
  t_int32_ = llvm::Type::getInt32Ty(*ctx_);
  t_int64_ = llvm::Type::getInt64Ty(*ctx_);
  t_float64_ = llvm::Type::getDoubleTy(*ctx_);
  // meta data: branch-weight hint (2^20 : 1) and the TBAA hierarchy root
  md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  this->InitTarget(tm);
}
77 
InitTarget(llvm::TargetMachine * tm)78 void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
79   module_->setTargetTriple(tm->getTargetTriple().str());
80   module_->setDataLayout(tm->createDataLayout());
81   data_layout_.reset(new llvm::DataLayout(module_.get()));
82   target_machine_ = tm;
83   if (native_vector_bits_ == 0) {
84     const auto& arch = tm->getTargetTriple().getArch();
85     if (arch == llvm::Triple::x86_64) {
86       // for avx512
87       native_vector_bits_ = 512;
88     } else if (arch == llvm::Triple::x86) {
89       native_vector_bits_ = 256;
90     } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
91       native_vector_bits_ = 128;
92     } else {
93       native_vector_bits_ = 128;
94       std::string arch_name = tm->getTargetTriple().getArchName();
95       LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
96     }
97   }
98 }
99 
// Add a lowered function to the module using the int-returning convention
// (the generated function returns constant 0; see AddFunctionInternal).
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
  this->AddFunctionInternal(f, false);  // ret_void = false
}
103 
InitFuncState()104 void CodeGenLLVM::InitFuncState() {
105   var_map_.clear();
106   alias_var_set_.clear();
107   alloc_storage_info_.clear();
108   volatile_buf_.clear();
109   analyzer_.reset(new arith::Analyzer());
110 }
111 
112 
// Generate the LLVM function for one LoweredFunc.
// ret_void selects the return convention: true emits `ret void`,
// false emits `ret i32 0`.
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
  this->InitFuncState();
  std::vector<llvm::Type*> arg_types;
  is_restricted_ = f->is_restricted;
  for (Var arg : f->args) {
    Type t = arg.type();
    if (t.is_handle()) {
      // Pointer argument: use the recorded element type if one was
      // attached to the handle, otherwise fall back to i8*.
      auto it = f->handle_data_type.find(arg);
      if (it != f->handle_data_type.end()) {
        arg_types.push_back(LLVMType((*it).second.type())
                            ->getPointerTo(GetGlobalAddressSpace()));
      } else {
        arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
      }
      // Without the restrict guarantee, every handle may alias another.
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
    } else {
      arg_types.push_back(LLVMType(arg.type()));
    }
  }
  llvm::FunctionType* ftype = llvm::FunctionType::get(
      ret_void ? t_void_ : t_int_, arg_types, false);
  CHECK(module_->getFunction(f->name) == nullptr)
      << "Function " << f->name << " already exist in module";
  function_ = llvm::Function::Create(
      ftype, llvm::Function::ExternalLinkage,
      f->name, module_.get());
  function_->setCallingConv(llvm::CallingConv::C);
  function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
  // set var map and align information
  auto arg_it = function_->arg_begin();
  for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
    const Var& var = f->args[i];
    var_map_[var.get()] = v;
    if (is_restricted_) {
      if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
        // set non alias.
#if TVM_LLVM_VERSION >= 50
        function_->addParamAttr(i, llvm::Attribute::NoAlias);
#else
        // pre-LLVM-5 API is 1-based (index 0 denotes the return value)
        function_->setDoesNotAlias(i + 1);
#endif
      }
    }
  }
  // Emit the body into a fresh entry block, then the return.
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);
  if (ret_void) {
    builder_->CreateRetVoid();
  } else {
    builder_->CreateRet(ConstInt32(0));
  }
}
169 
170 
Finish()171 std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
172   this->AddStartupFunction();
173   for (size_t i = 0; i < link_modules_.size(); ++i) {
174     CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
175         << "Failed to link modules";
176   }
177   link_modules_.clear();
178   // optimize
179   this->Optimize();
180   return std::move(module_);
181 }
182 
183 
// Import auxiliary LLVM IR into the link set.
// `code` is either a path ending in ".ll"/".bc" (parsed from disk) or a
// string of LLVM IR text. Imported functions are marked always-inline
// with available_externally linkage so they get inlined into generated
// code without emitting standalone symbols.
void CodeGenLLVM::HandleImport(const std::string& code) {
  std::unique_ptr<llvm::Module> mlib;
  llvm::SMDiagnostic err;
  if (code.length() >= 3 &&
      (code.substr(code.length() - 3) == ".ll" ||
       code.substr(code.length() - 3) == ".bc")) {
    mlib = llvm::parseIRFile(code, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
                 << "line " << err.getLineNo() << ":" << msg;
    }
  } else {
    // Treat `code` as in-memory IR text.
    std::unique_ptr<llvm::MemoryBuffer> buf =
        llvm::MemoryBuffer::getMemBuffer(code);
    mlib = llvm::parseIR(*buf, err, *ctx_);
    if (mlib.get() == nullptr) {
      std::string msg = err.getMessage();
      LOG(FATAL) << "Fail to load llvm ir "
                 << "line " << err.getLineNo() << ":" << msg
                 << "\ncontent:\n"  << code;
    }
  }
  // Retarget the imported module so it can link against the main module.
  mlib->setTargetTriple(target_machine_->getTargetTriple().str());
  mlib->setDataLayout(target_machine_->createDataLayout());
  // mark all the functions as force inline
  for (llvm::Function &f : mlib->functions()) {
    f.removeFnAttr(llvm::Attribute::NoInline);
    f.addFnAttr(llvm::Attribute::AlwaysInline);
    f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
  }
  // add to linker libraries.
  this->AddLinkModule(std::move(mlib));
}
218 
// Queue a module to be linked into the main module in Finish().
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.emplace_back(std::move(mod));
}
222 
// Hook for subclasses that support a designated entry function; the base
// class does not implement it.
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
  LOG(FATAL) << "not implemented";
}
226 
// Hook for GPU-style backends to map a thread iteration variable to a
// hardware index; not supported by the base class.
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
  LOG(FATAL) << "not implemented";
  return nullptr;  // unreachable; silences missing-return warnings
}
231 
// Hook for backends with memory-scope barriers; not supported here.
llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
  LOG(FATAL) << "not implemented";
  return nullptr;  // unreachable; silences missing-return warnings
}
236 
// Thin wrapper over llvm::legacy::FunctionPassManager whose `add` is an
// override point (e.g. for logging which passes are scheduled).
class FPassManager : public llvm::legacy::FunctionPassManager {
 public:
  explicit FPassManager(llvm::Module* m)
      : llvm::legacy::FunctionPassManager(m) {}
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::FunctionPassManager::add(p);
  }
};
246 
// Thin wrapper over llvm::legacy::PassManager; see FPassManager.
class MPassManager : public llvm::legacy::PassManager {
 public:
  // override add to allow messaging
  void add(llvm::Pass* p) final {
    llvm::legacy::PassManager::add(p);
  }
};
254 
// Hook allowing subclasses to tweak the PassManagerBuilder (inliner,
// vectorizers, ...) before the pipelines are populated. No-op by default.
void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
}
257 
// Run the -O3-equivalent legacy pass pipeline over the module:
// function passes first, then module passes.
void CodeGenLLVM::Optimize() {
  // pass manager
  FPassManager fpass(module_.get());
  MPassManager mpass;
  // Register target-specific cost/analysis info so passes can query it;
  // fall back to a generic TargetIRAnalysis when no TM is set.
  mpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
                                llvm::TargetIRAnalysis()));
  fpass.add(llvm::createTargetTransformInfoWrapperPass(
              target_machine_ ? target_machine_->getTargetIRAnalysis() :
              llvm::TargetIRAnalysis()));

  // place optimization pass
  llvm::PassManagerBuilder builder;
  builder.OptLevel = 3;

  // Inliner creation signature changed in LLVM 5.
#if TVM_LLVM_VERSION >= 50
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
#else
  builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
#endif
  builder.LoopVectorize = true;
  builder.SLPVectorize = true;
  // Let subclasses customize the pipeline before it is materialized.
  this->InitPassManagerBuilder(&builder);

#if TVM_LLVM_VERSION >= 50
  target_machine_->adjustPassManager(builder);
#endif

  builder.populateFunctionPassManager(fpass);
  builder.populateModulePassManager(mpass);

  // Run the function pipeline over every function, then module passes.
  fpass.doInitialization();
  for (auto it = module_->begin(); it != module_->end(); ++it) {
    fpass.run(*it);
  }
  fpass.doFinalization();
  mpass.run(*module_);
}
296 
// Native vector width in bits. The base class ignores the storage scope;
// scope-aware backends can override this.
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}
300 
// LLVM address space used for global buffers; 0 for CPU targets.
// GPU-style backends override this.
unsigned CodeGenLLVM::GetGlobalAddressSpace() {
  return 0;
}
304 
LLVMType(const Type & t) const305 llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
306   if (t.is_handle()) {
307     CHECK_EQ(t.lanes(), 1);
308     return t_void_p_;
309   }
310   llvm::Type* etype = nullptr;
311   if (t.is_int() || t.is_uint()) {
312     etype = llvm::Type::getIntNTy(*ctx_, t.bits());
313   } else if (t.is_float()) {
314     switch (t.bits()) {
315       case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
316       case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
317       case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
318       default: LOG(FATAL) << "do not support " << t;
319     }
320   }
321   if (t.lanes() != 1) {
322     return llvm::VectorType::get(etype, t.lanes());
323   } else {
324     return etype;
325   }
326 }
327 
328 // Add tbaa alias information for load
329 //
330 // use a binary tree typed system to declare information
331 // and allow alias to be distinguished across nodes.
332 //
333 // This trick comes from Halide's CodeGen_LLVM
334 //
AddAliasInfo(llvm::Instruction * inst,const Variable * buffer,Expr index,Type type)335 void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
336                                const Variable* buffer,
337                                Expr index,
338                                Type type) {
339   if (alias_var_set_.count(buffer) != 0) {
340     // Mark all possibly aliased pointer as same type.
341     llvm::MDNode* meta = md_tbaa_alias_set_;
342     inst->setMetadata(
343         "tbaa",
344         md_builder_->createTBAAStructTagNode(meta, meta, 0));
345     return;
346   }
347   int base = 0, width = 0;
348   // create meta-data for alias analysis
349   // Use a group of binary tree ranges of memory banks.
350   if (index.defined()) {
351     const Ramp* ramp = index.as<Ramp>();
352     if (ramp) {
353       int base, stride;
354       if (arith::GetConstInt(ramp->base, &base) &&
355           arith::GetConstInt(ramp->stride, &stride)) {
356         int xwith = ramp->lanes * stride;
357         width = 1;
358         while (width < xwith) {
359           width *= 2;
360         }
361         while (base % width) {
362           base -= base % width;
363           width *= 2;
364         }
365       }
366     } else {
367       if (arith::GetConstInt(index, &base)) width = 1;
368     }
369   }
370   llvm::MDNode* meta = md_tbaa_root_;
371   std::ostringstream buffer_addr, buffer_type;
372   buffer_addr << buffer;
373   meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
374   buffer_type << type.element_of();
375   meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
376   // create a tree-shape access structure.
377   if (width != 0) {
378     for (int w = 1024; w >= width; w /= 2) {
379       int b = (base / w) * w;
380       std::stringstream os;
381       os << buffer << ".w" << w << ".b" << b;
382       meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
383     }
384   }
385   inst->setMetadata(
386       "tbaa",
387       md_builder_->createTBAAStructTagNode(meta, meta, 0));
388 }
389 
GetAlignment(Type t,const Variable * buf_var,const Expr & index,int * p_alignment,int * p_native_bits)390 void CodeGenLLVM::GetAlignment(Type t,
391                                const Variable* buf_var,
392                                const Expr& index,
393                                int* p_alignment,
394                                int* p_native_bits) {
395   int max_align_bits = t.bits();
396   auto it = alloc_storage_info_.find(buf_var);
397   if (it != alloc_storage_info_.end()) {
398     const StorageInfo& info = it->second;
399     *p_native_bits = NativeVectorBits(info.scope);
400     max_align_bits = info.alignment * 8;
401   } else {
402     *p_native_bits = native_vector_bits_;
403   }
404 
405   arith::ModularSet me = analyzer_->modular_set(index);
406   int64_t base = me->base;
407   int64_t coeff = me->coeff;
408 
409   int align_bits = t.bits();
410   while (align_bits < max_align_bits &&
411          base % 2  == 0 &&
412          coeff % 2 == 0) {
413     base =  base / 2;
414     coeff =  coeff / 2;
415     align_bits *= 2;
416   }
417   if (align_bits < 8) {
418     align_bits = 8;
419   }
420   *p_alignment = align_bits / 8;
421 }
422 
// Create a DIBuilder plus file/compile-unit scaffolding for emitting
// DWARF debug info. Uses a placeholder file name since source locations
// are not threaded through yet (see TODO below).
std::unique_ptr<CodeGenLLVM::DebugInfo>
CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
  // llvm::make_unique was removed in LLVM 10 in favor of std::make_unique.
  auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module);
#else
  auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
  debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
  // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
  debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
  debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
      llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
      llvm::DICompileUnit::DebugEmissionKind::FullDebug,
      /* SplitDebugInlining */ true,
      /* DebugInfoForProfiling */ true);
  return debug_info;
}
441 
// Broadcast a scalar into a <lanes x T> vector: insert the value into
// lane 0 of an undef vector, then splat it with an all-zero shuffle mask.
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
  llvm::Constant* undef = llvm::UndefValue::get(
      llvm::VectorType::get(value->getType(), lanes));
  llvm::Constant* zero = ConstInt32(0);
  value = builder_->CreateInsertElement(undef, value, zero);
  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
  return builder_->CreateShuffleVector(value, undef, mask);
}
450 
CreateVecSlice(llvm::Value * vec,int begin,int extent)451 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
452   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
453   if (extent == num_elems && begin == 0) return vec;
454   CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
455   std::vector<llvm::Constant*> indices;
456   indices.reserve(extent);
457   for (int i = 0; i < extent; ++i) {
458     if (begin + i >= 0 && begin + i < num_elems) {
459       indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
460     } else {
461       indices.push_back(llvm::UndefValue::get(t_int32_));
462     }
463   }
464   return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
465 }
466 
CreateVecFlip(llvm::Value * vec)467 llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
468   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
469   std::vector<unsigned> indices;
470   for (int i = 0; i < num_elems; ++i) {
471     indices.push_back(num_elems - i - 1);
472   }
473   return builder_->CreateShuffleVector(vec, vec, indices);
474 }
475 
CreateVecPad(llvm::Value * vec,int target_lanes)476 llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
477   llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
478   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
479   if (num_elems == target_lanes) return vec;
480   CHECK_LT(num_elems, target_lanes);
481   for (int i = 0; i < num_elems; ++i) {
482     mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i));
483   }
484   return builder_->CreateShuffleVector(vec, vec, mask);
485 }
486 
// Concatenate a list of vectors into one, pairing them up in a balanced
// (tree-shaped) reduction. Unequal pair widths are padded with undef
// lanes; the final result is sliced back to the true total lane count.
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
  // concat vector, tree shape reduction
  int total_lanes = 0;

  for (llvm::Value* v : vecs) {
    total_lanes += static_cast<int>(
        v->getType()->getVectorNumElements());
  }
  while (vecs.size() > 1) {
    std::vector<llvm::Value*> new_vecs;
    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
      llvm::Value* lhs = vecs[i];
      llvm::Value* rhs = vecs[i + 1];
      const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
      const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
      // Equalize widths so the pair can be shuffled together.
      if (lhs_lanes < rhs_lanes) {
        lhs = CreateVecPad(lhs, rhs_lanes);
      } else if (rhs_lanes < lhs_lanes) {
        rhs = CreateVecPad(rhs, lhs_lanes);
      }
      const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes);
      // Mask: all lhs lanes followed by all rhs lanes (skipping padding).
      // NOTE: this inner `i` intentionally shadows the pair index above.
      std::vector<unsigned> mask;
      for (size_t i = 0; i < lhs_lanes; ++i) {
        mask.push_back(i);
      }
      for (size_t i = 0; i < rhs_lanes; ++i) {
        mask.push_back(shared_lanes + i);
      }
      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
    }
    // An odd element carries over to the next round unchanged.
    if (vecs.size() % 2 != 0) {
      new_vecs.push_back(vecs.back());
    }
    vecs.swap(new_vecs);
  }
  // Drop any padding lanes introduced along the way.
  return CreateVecSlice(vecs[0], 0, total_lanes);
}
524 
525 
// Emit a serial `for (v = begin; v < end; v += stride)` loop.
// Uses a PHI for the induction variable with two incoming edges:
// `begin` from the predecessor block and `loop_next` from the back edge.
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                  llvm::Value* end,
                                  llvm::Value* stride,
                                  const VarExpr& loop_var,
                                  const Stmt& body) {
  using llvm::BasicBlock;
  BasicBlock* pre_block = builder_->GetInsertBlock();
  BasicBlock* for_begin = BasicBlock::Create(
      *ctx_, "for_begin", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_begin);
  llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2);
  loop_value->addIncoming(begin, pre_block);
  // Bind the loop variable for the duration of the body only.
  CHECK(!var_map_.count(loop_var.get()));
  var_map_[loop_var.get()] = loop_value;
  // Loop-continue branch is hinted as very likely taken.
  builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
                         for_body, for_end, md_very_likely_branch_);
  builder_->SetInsertPoint(for_body);
  this->VisitStmt(body);
  var_map_.erase(loop_var.get());
  // Back edge: increment comes from wherever the body left the builder,
  // which may not be for_body if the body created blocks of its own.
  llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
  loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
  builder_->CreateBr(for_begin);
  builder_->SetInsertPoint(for_end);
}
555 
// cast operator
// Emit a value conversion from TVM type `from` to `to`, selecting the
// matching LLVM cast instruction for each type-pair combination.
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
  llvm::Type * target = LLVMType(to);
  if (value->getType() == target) return value;
  if (to.is_handle()) {
    return builder_->CreateBitCast(value, target);
  } else if (to.is_uint() && to.bits() == 1) {
    // Cast to bool: compare against zero.
    if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
      return builder_->CreateFCmpONE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
    }
  } else if (!from.is_float() && !to.is_float()) {
    // int<->int: sign-extend iff the source is signed.
    return builder_->CreateIntCast(value, target, from.is_int());
  } else if (from.is_float() && to.is_int()) {
    return builder_->CreateFPToSI(value, target);
  } else if (from.is_float() && to.is_uint()) {
    if (to.bits() < 8) {
      // FPToUI to sub-byte widths is done via u8 then truncated.
      value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
      return builder_->CreateIntCast(value, target, false);
    } else {
      return builder_->CreateFPToUI(value, target);
    }
  } else if (from.is_int() && to.is_float()) {
    return builder_->CreateSIToFP(value, target);
  } else if (from.is_uint() && to.is_float()) {
    return builder_->CreateUIToFP(value, target);
  } else {
    // Remaining case: float -> float of a different width.
    CHECK(from.is_float() && to.is_float());
    return builder_->CreateFPCast(value, target);
  }
}
590 
// Return an i8* to a NUL-terminated global string constant for `str`,
// de-duplicated through str_map_.
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
  auto it = str_map_.find(str);
  if (it != str_map_.end()) return it->second;
  // +1 for the trailing NUL added by ConstantDataArray::getString.
  llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
  llvm::GlobalVariable *global = new llvm::GlobalVariable(
      *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
  // setAlignment signature changed to take llvm::Align in LLVM 10.
#if TVM_LLVM_VERSION >= 100
  global->setAlignment(llvm::Align(1));
#else
  global->setAlignment(1);
#endif
  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
  // GEP [0, 0] yields a pointer to the first character.
  llvm::Constant* zero = ConstInt32(0);
  llvm::Constant* indices[] = {zero, zero};
  llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(
      type, global, indices);
  str_map_[str] = ptr;
  return ptr;
}
610 
// Compute &buffer[index] for a scalar element type t, bitcasting the
// buffer pointer to t* (preserving its address space) when needed.
llvm::Value* CodeGenLLVM::CreateBufferPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_EQ(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }

  return builder_->CreateInBoundsGEP(buffer, index);
}
623 
// Vector-typed counterpart of CreateBufferPtr: &buffer[index] where each
// element is a <lanes x T> vector (t.lanes() > 1 required).
llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
    Type t, llvm::Value* buffer, llvm::Value* index) {
  CHECK_GT(t.lanes(), 1);
  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
  CHECK(btype != nullptr);
  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
  if (btype != ptype) {
    buffer = builder_->CreatePointerCast(buffer, ptype);
  }
  return builder_->CreateInBoundsGEP(buffer, index);
}
635 
GetVarValue(const Variable * v) const636 llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
637   auto it = var_map_.find(v);
638   CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
639   return it->second;
640 }
641 
CreateCallExtern(const Call * op)642 llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
643   std::vector<llvm::Value*> arg_value;
644   std::vector<llvm::Type*> arg_type;
645   for (size_t i = 0; i < op->args.size(); ++i) {
646     arg_value.push_back(MakeValue(op->args[i]));
647     arg_type.push_back(arg_value.back()->getType());
648   }
649   llvm::FunctionType* ftype = llvm::FunctionType::get(
650       LLVMType(op->type), arg_type, false);
651   llvm::Function* f = module_->getFunction(op->name);
652   if (f == nullptr) {
653     f = llvm::Function::Create(
654         ftype, llvm::Function::ExternalLinkage,
655         op->name, module_.get());
656   }
657   llvm::CallInst* call = builder_->CreateCall(f, arg_value);
658   return call;
659 }
660 
// Lower a TVM intrinsic Call to LLVM IR. Dispatches on the intrinsic
// name; unknown intrinsics are a fatal error.
llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
  if (op->is_intrinsic("llvm_intrin")) {
    // Direct LLVM intrinsic call:
    //   args[0] = intrinsic id, args[1] = number of signature args,
    //   args[2:] = actual call arguments.
    CHECK_GE(op->args.size(), 2U);
    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
        op->args[0].as<UIntImm>()->value);
    const uint64_t *num_signature = as_const_uint(op->args[1]);
    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
                         << "but " << op->args[1] << " got!\n";
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> sig_type;
    for (size_t i = 2; i < op->args.size(); ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      // Only the first num_signature arg types participate in overload
      // resolution of the intrinsic.
      if (i - 2 < *num_signature) {
        sig_type.push_back(arg_value.back()->getType());
      }
    }
    // Prepend the return type when it differs from the first arg type.
    llvm::Type *return_type = LLVMType(op->type);
    if (sig_type.size() > 0 && return_type != sig_type[0]) {
      sig_type.insert(sig_type.begin(), return_type);
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(
        module_.get(), id, sig_type);
    return builder_->CreateCall(f, arg_value);
  } else if (op->is_intrinsic(Call::bitwise_and)) {
    return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_or)) {
    return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::bitwise_not)) {
    return builder_->CreateNot(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(Call::bitwise_xor)) {
    return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_left)) {
    return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
  } else if (op->is_intrinsic(Call::shift_right)) {
    // Arithmetic shift for signed operands, logical otherwise.
    if (op->args[0].type().is_int()) {
      return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    } else {
      return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
    }
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    return CreateStorageSync(op);
  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
    // Take the address of a Load: scalar index -> element pointer,
    // Ramp index -> vector-element pointer; result is cast to i8*.
    const Load *l = op->args[0].as<Load>();
    CHECK(op->args.size() == 1 && l);
    const Ramp *r = l->index.as<Ramp>();
    llvm::Value* ptr;
    unsigned addrspace;
    if (!r) {
        ptr = CreateBufferPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(l->index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    } else {
        Expr index = r->base / make_const(Int(32), r->lanes);
        ptr = CreateBufferVecPtr(
          l->type, MakeValue(l->buffer_var), MakeValue(index));
        addrspace = llvm::dyn_cast<llvm::PointerType>(
          ptr->getType())->getAddressSpace();
    }
    return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
    // reinterpret(0) is simply a null pointer.
    return llvm::Constant::getNullValue(t_void_p_);
  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
    return builder_->CreateIsNull(MakeValue(op->args[0]));
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    // Short-circuit select: emit real control flow with a PHI so only
    // the taken branch's expression is evaluated.
    CHECK_EQ(op->args[0].type().lanes(), 1)
        << "if_then_else can only take scalar condition";
    using llvm::BasicBlock;
    BasicBlock* then_block = BasicBlock::Create(
        *ctx_, "if_then", function_);
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    BasicBlock* end_block = BasicBlock::Create(
        *ctx_, "if_end", function_);
    builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
    builder_->SetInsertPoint(then_block);
    llvm::Value* then_value = MakeValue(op->args[1]);
    // The branch value may be produced in a later block than then_block.
    BasicBlock* then_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    llvm::Value* else_value = MakeValue(op->args[2]);
    BasicBlock* else_value_block = builder_->GetInsertBlock();
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(end_block);
    llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
    value->addIncoming(then_value, then_value_block);
    value->addIncoming(else_value, else_value_block);
    return value;
  } else if (op->is_intrinsic(Call::reinterpret)) {
    llvm::Type * target = LLVMType(op->type);
    return builder_->CreateBitCast(MakeValue(op->args[0]), target);
  } else if (op->is_intrinsic(Call::isnan)) {
    // TODO(hgt312): set fast math flag
    // NaN is the only value unordered with itself.
    llvm::Value* a = MakeValue(op->args[0]);
    return builder_->CreateFCmpUNO(a, a);
  } else if (op->is_intrinsic("vectorlow")) {
    // Lower half of the vector lanes.
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, 0, l/2);
  } else if (op->is_intrinsic("vectorhigh")) {
    // Upper half of the vector lanes.
    llvm::Value *v = MakeValue(op->args[0]);
    int l = v->getType()->getVectorNumElements();
    return CreateVecSlice(v, l/2, l/2);
  } else if (op->is_intrinsic("vectorcombine")) {
    // Concatenate two equal-width vectors into one of double the lanes.
    llvm::Value *v0 = MakeValue(op->args[0]);
    llvm::Value *v1 = MakeValue(op->args[1]);
    int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
    std::vector<unsigned> indices;
    for (int i = 0; i < num_elems; ++i) {
      indices.push_back(i);
    }
    return builder_->CreateShuffleVector(v0, v1, indices);
  } else {
    LOG(FATAL) << "unknown intrinsic " << op->name;
    return nullptr;
  }
}
778 
Scalarize(const Expr & e,std::function<void (int i,llvm::Value * v)> f)779 void CodeGenLLVM::Scalarize(const Expr& e,
780                             std::function<void(int i, llvm::Value* v)> f) {
781   if (const Ramp* ramp = e.as<Ramp>()) {
782     for (int i = 0; i < ramp->type.lanes(); ++i) {
783       Expr offset = ramp->base + (ramp->stride * i);
784       f(i, MakeValue(offset));
785     }
786   } else {
787     llvm::Value* value = MakeValue(e);
788     for (int i = 0; i < e.type().lanes(); ++i) {
789       f(i, builder_->CreateExtractElement(value, i));
790     }
791   }
792 }
793 
794 
795 // Visitors
// Resolve a variable reference to the LLVM value previously bound to it
// (e.g. via LetStmt/Let/Allocate or function parameters).
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
  return GetVarValue(op);
}
799 
// Cast between TVM types: delegate to CreateCast with the source type,
// destination type, and the generated source value.
llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
  return CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
// Signed integer constant; getSigned sign-extends negative values into the
// target bit width correctly.
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
  return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}
806 
// Unsigned integer constant; the value bits are used as-is (zero-extended).
llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
  return llvm::ConstantInt::get(LLVMType(op->type), op->value);
}
810 
// Floating-point constant of the expression's LLVM type.
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
  return llvm::ConstantFP::get(LLVMType(op->type), op->value);
}
814 
// String constant: delegate to GetConstString, which returns an LLVM value
// for the string data (presumably cached/deduplicated — see its definition).
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
  return GetConstString(op->value);
}
818 
// DEFINE_CODEGEN_BINARY_OP(Op) expands to CodeGenLLVM::Create##Op plus the
// matching VisitExpr_ visitor for the Add/Sub/Mul IR nodes.
// Instruction selection per operand type:
//  - signed int, >= 32 bits: NSW (no-signed-wrap) variant, letting LLVM
//    assume signed overflow cannot happen and optimize accordingly;
//  - unsigned int, >= 32 bits: NUW (no-unsigned-wrap) variant;
//  - narrower integers: the plain wrapping instruction;
//  - otherwise the type must be float, so use the FP instruction.
// (No comments inside the macro body: a '//' line would swallow its
// trailing line-continuation backslash.)
#define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value *b) {                         \
    if (t.is_int()) {                                                   \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNSW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else if (t.is_uint()) {                                           \
      if (t.bits() >= 32) {                                             \
        return builder_->CreateNUW ## Op (a, b);                        \
      } else {                                                          \
        return builder_->Create ## Op (a, b);                           \
      }                                                                 \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateF ## Op (a, b);                            \
    }                                                                   \
  }                                                                     \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b));  \
  }

DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
846 
// DEFINE_CODEGEN_CMP_OP(Op) expands to CodeGenLLVM::Create##Op plus the
// matching VisitExpr_ visitor for the LT/LE/GT/GE comparison nodes.
// Signed ints use the ICmpS* predicates, unsigned ints the ICmpU* ones, and
// floats the ordered FCmpO* predicates (false whenever an operand is NaN).
// The visitor dispatches on the *operand* type (op->a.type()) because the
// comparison node's own type is boolean.
#define DEFINE_CODEGEN_CMP_OP(Op)                                       \
  llvm::Value* CodeGenLLVM::Create ## Op(                               \
      Type t, llvm::Value* a, llvm::Value* b) {                         \
    if (t.is_int()) {                                                   \
      return builder_->CreateICmpS ## Op (a, b);                        \
    } else if (t.is_uint()) {                                           \
      return builder_->CreateICmpU ## Op (a, b);                        \
    } else {                                                            \
      CHECK(t.is_float());                                              \
      return builder_->CreateFCmpO ## Op (a, b);                        \
    }                                                                   \
}                                                                       \
  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
    return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
  }

DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
867 
VisitExpr_(const Div * op)868 llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
869   llvm::Value* a = MakeValue(op->a);
870   llvm::Value* b = MakeValue(op->b);
871   if (op->type.is_int()) {
872     return builder_->CreateSDiv(a, b);
873   } else if (op->type.is_uint()) {
874     return builder_->CreateUDiv(a, b);
875   } else {
876     CHECK(op->type.is_float());
877     return builder_->CreateFDiv(a, b);
878   }
879 }
880 
VisitExpr_(const Mod * op)881 llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
882   llvm::Value* a = MakeValue(op->a);
883   llvm::Value* b = MakeValue(op->b);
884   if (op->type.is_int()) {
885     return builder_->CreateSRem(a, b);
886   } else if (op->type.is_uint()) {
887     return builder_->CreateURem(a, b);
888   } else {
889     CHECK(op->type.is_float());
890     return builder_->CreateFRem(a, b);
891   }
892 }
893 
VisitExpr_(const Min * op)894 llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
895   llvm::Value* a = MakeValue(op->a);
896   llvm::Value* b = MakeValue(op->b);
897   return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
898 }
899 
VisitExpr_(const Max * op)900 llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
901   llvm::Value* a = MakeValue(op->a);
902   llvm::Value* b = MakeValue(op->b);
903   return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
904 }
905 
VisitExpr_(const EQ * op)906 llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
907   llvm::Value* a = MakeValue(op->a);
908   llvm::Value* b = MakeValue(op->b);
909   if (op->a.type().is_int() || op->a.type().is_uint()) {
910     return builder_->CreateICmpEQ(a, b);
911   } else {
912     return builder_->CreateFCmpOEQ(a, b);
913   }
914 }
915 
VisitExpr_(const NE * op)916 llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
917   llvm::Value* a = MakeValue(op->a);
918   llvm::Value* b = MakeValue(op->b);
919   if (op->a.type().is_int() || op->a.type().is_uint()) {
920     return builder_->CreateICmpNE(a, b);
921   } else {
922     return builder_->CreateFCmpONE(a, b);
923   }
924 }
925 
VisitExpr_(const And * op)926 llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
927   return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
928 }
929 
VisitExpr_(const Or * op)930 llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
931   return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
932 }
933 
VisitExpr_(const Not * op)934 llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
935   return builder_->CreateNot(MakeValue(op->a));
936 }
937 
VisitExpr_(const Select * op)938 llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
939   return builder_->CreateSelect(
940       MakeValue(op->condition),
941       MakeValue(op->true_value),
942       MakeValue(op->false_value));
943 }
944 
// Let expression: evaluate the bound value once, expose it to the body via
// var_map_, and return the body's value. The binding is also handed to the
// arithmetic analyzer so later bound/alignment queries can use it.
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
  CHECK(!var_map_.count(op->var.get()));  // single-assignment invariant
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}
951 
// Generate code for a Load. Three strategies, tried in order:
//  1. scalar load when t.lanes() == 1;
//  2. one wide vector load when the index is a unit-stride Ramp
//     (lanes are contiguous in memory);
//  3. scalarized per-lane loads as the general fallback.
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
  Type t = op->type;
  // Buffers marked with volatile_scope must use volatile memory accesses.
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
    AddAliasInfo(load, op->buffer_var.get(), op->index, t);
    return load;
  } else {
    // vector load
    // Keep the buffer's address space when casting to a vector pointer.
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
      buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        // Contiguous lanes: load the whole vector with one instruction,
        // using the alignment derived from the ramp base.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
        AddAliasInfo(load, op->buffer_var.get(), op->index, t);
        return load;
      }
    }
  }
  // scalarized load.
  // Load each lane separately and assemble the result vector. Only the
  // element size is assumed for alignment, since per-lane offsets are not
  // analyzed; alias info gets an empty index Expr for the same reason.
  int basic_align = t.bits() / 8;
  llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::LoadInst* load = builder_->CreateAlignedLoad(
        ptr, basic_align, is_volatile);
    ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
    AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
  };
  this->Scalarize(op->index, f);
  return ret;
}
996 
VisitExpr_(const Call * op)997 llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
998   if (op->call_type == Call::Intrinsic ||
999       op->call_type == Call::PureIntrinsic) {
1000     return CreateIntrinsic(op);
1001   } else if (op->call_type == Call::Extern ||
1002              op->call_type == Call::PureExtern) {
1003     return CreateCallExtern(op);
1004   } else {
1005     LOG(FATAL) << "Unknown call type " <<
1006       "name= " << op->name <<
1007       " call_type= " << op->call_type;
1008     return nullptr;
1009   }
1010 }
1011 
VisitExpr_(const Ramp * op)1012 llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
1013   llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
1014   for (int i = 0; i < op->lanes; ++i) {
1015     vec = builder_->CreateInsertElement(
1016         vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
1017         ConstInt32(i));
1018   }
1019   return vec;
1020 }
1021 
VisitExpr_(const Shuffle * op)1022 llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
1023   std::vector<llvm::Value *> vecs(op->vectors.size());
1024   int total_lanes = 0;
1025   for (int i = 0, e = op->vectors.size(); i < e; ++i) {
1026     vecs[i] = VisitExpr(op->vectors[i]);
1027     total_lanes += op->vectors[i].type().lanes();
1028   }
1029   llvm::Value* v0 = CreateVecConcat(vecs);
1030   std::vector<uint32_t> idx(op->indices.size());
1031   for (int i = 0, e = op->indices.size(); i < e; ++i) {
1032     const int64_t *val = as_const_int(op->indices[i]);
1033     CHECK(val && *val >= 0 && *val  < total_lanes) << "Shuffled indeces are suppose to be int, "
1034       << "but get " << op->indices[i] << "\n";
1035     idx[i] = *val;
1036   }
1037   llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
1038   auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
1039   return res;
1040 }
1041 
// Broadcast: splat the scalar value across op->lanes vector lanes.
llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
  return CreateBroadcast(MakeValue(op->value), op->lanes);
}
1045 
// Generate code for a Store. Mirrors the Load visitor: scalar store,
// single wide vector store for a unit-stride Ramp index, or scalarized
// per-lane stores as the fallback.
void CodeGenLLVM::VisitStmt_(const Store* op) {
  // Predicated stores are not supported by this backend.
  CHECK(is_one(op->predicate));
  Type t = op->value.type();
  // Buffers marked with volatile_scope must use volatile memory accesses.
  bool is_volatile = volatile_buf_.count(op->buffer_var.get());
  llvm::Value* buffer = MakeValue(op->buffer_var);
  llvm::Value* index = MakeValue(op->index);
  llvm::Value* value = MakeValue(op->value);

  if (t.lanes() == 1) {
    int alignment, native_bits;
    GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
    llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
    return;
  } else {
    // vector store
    // Keep the buffer's address space when casting to a vector pointer.
    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
        buffer->getType())->getAddressSpace();
    if (const Ramp* ramp = op->index.as<Ramp>()) {
      if (is_one(ramp->stride)) {
        // Contiguous lanes: store the whole vector with one instruction,
        // using the alignment derived from the ramp base.
        int alignment, native_bits;
        GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
        CHECK_EQ(ramp->lanes, t.lanes());
        llvm::Value* ptr = CreateBufferPtr(
            t.element_of(), buffer, MakeValue(ramp->base));
        ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
        llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
        AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
        return;
      }
    }
  }
  // Sub-byte element types cannot be scalarized byte-wise.
  CHECK_GE(t.bits(), 8);
  // scalarized store.
  // Extract and store each lane separately; only the element size is
  // assumed for alignment, and alias info gets an empty index Expr since
  // per-lane offsets are not analyzed.
  int basic_align = t.bits() / 8;
  auto f = [&](int i, llvm::Value* index) {
    llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
    llvm::StoreInst* store = builder_->CreateAlignedStore(
        builder_->CreateExtractElement(value, i),
        ptr, basic_align, is_volatile);
    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
  };
  this->Scalarize(op->index, f);
}
1091 
VisitStmt_(const For * op)1092 void CodeGenLLVM::VisitStmt_(const For* op) {
1093   CHECK(is_zero(op->min));
1094   analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
1095   if (op->for_type == ForType::Unrolled) {
1096     LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
1097                  << " consider set unroll_explicit=True";
1098   } else {
1099     CHECK(op->for_type == ForType::Serial);
1100   }
1101   CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
1102                   ConstInt32(1), op->loop_var, op->body);
1103 }
1104 
1105 
// Generate an if/else statement using explicit basic blocks:
//   cond ? if_then : (if_else | if_end), both arms branching to if_end.
// The builder is left positioned at if_end so following statements continue
// after the conditional.
void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  using llvm::BasicBlock;
  llvm::Value* cond = MakeValue(op->condition);
  BasicBlock* then_block = BasicBlock::Create(
      *ctx_, "if_then", function_);
  BasicBlock* end_block = BasicBlock::Create(
      *ctx_, "if_end", function_);
  if (op->else_case.defined()) {
    BasicBlock* else_block = BasicBlock::Create(
        *ctx_, "if_else", function_);
    builder_->CreateCondBr(cond, then_block, else_block);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
    builder_->SetInsertPoint(else_block);
    this->VisitStmt(op->else_case);
    builder_->CreateBr(end_block);
  } else {
    // No else arm: a false condition jumps straight to if_end. The branch
    // carries "very likely" metadata favoring the then arm.
    builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
    builder_->SetInsertPoint(then_block);
    this->VisitStmt(op->then_case);
    builder_->CreateBr(end_block);
  }
  builder_->SetInsertPoint(end_block);
}
1131 
1132 
// Generate an Allocate: either take the buffer from a user-supplied
// new_expr, or emit a constant-size alloca in the function entry block.
// The resulting pointer is bound to the buffer variable for the body.
void CodeGenLLVM::VisitStmt_(const Allocate* op) {
  // A statically-false condition would make the allocation dead; reject it.
  CHECK(!is_zero(op->condition));
  llvm::Value* buf = nullptr;
  if (op->new_expr.defined()) {
    // Externally provided buffer; a custom free function is not supported.
    CHECK_EQ(op->free_function, "nop");
    buf = MakeValue(op->new_expr);
  } else {
    // Stack allocation: only compile-time constant sizes are handled.
    int32_t constant_size = op->constant_allocation_size();
    CHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation";
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    // Derive an alignment when none was requested and the element count is
    // a multiple of 4.
    if (constant_size % 4 == 0 && info.alignment == 0) {
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    // maximum necessary alignment in the NV devices
    if (info.alignment > 16) {
      info.alignment = 16;
    }
    // Allocas are placed in the function entry block so LLVM treats them
    // as static stack slots.
    llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
        return builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
      });
    // Raise the alloca's alignment if the default is weaker than requested.
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
      // LLVM 10+ takes the alignment as an llvm::Align wrapper.
      alloca->setAlignment(llvm::Align(info.alignment));
#else
      alloca->setAlignment(info.alignment);
#endif
    }
    // Record the final alignment actually applied.
    info.alignment = alloca->getAlignment();
    buf = alloca;
  }
  // Cast to the element pointer type expected by later loads/stores,
  // preserving the buffer's address space.
  buf = builder_->CreatePointerCast(
      buf, LLVMType(op->type)->getPointerTo(
          buf->getType()->getPointerAddressSpace()));
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}
1172 
// Interpret codegen-relevant attributes, then always recurse into the body.
// Unrecognized attribute keys are silently skipped (annotation only).
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
  if (op->attr_key == attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag.length() != 0) {
      // Bind the thread index variable only once; nested extents with the
      // same variable reuse the first binding.
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv);
        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
      }
    }
  } else if (op->attr_key == ir::attr::storage_scope) {
    // Record the buffer's storage scope for later Allocate handling.
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].scope =
        runtime::StorageScope::make(op->value.as<StringImm>()->value);
  } else if (op->attr_key == ir::attr::storage_alignment) {
    // Explicit alignment request consumed by the Allocate visitor.
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    alloc_storage_info_[v].alignment =
        static_cast<int>(op->value.as<IntImm>()->value);
  } else if (op->attr_key == ir::attr::volatile_scope) {
    // Mark the buffer so its loads/stores are emitted as volatile.
    const Variable* v = op->node.as<Variable>();
    CHECK(v);
    volatile_buf_.insert(v);
  }
  this->VisitStmt(op->body);
}
1199 
// In this base class an assert emits no runtime check: the condition is
// only installed as an analyzer constraint (scoped via RAII) while the
// body is generated, enabling tighter bound/alignment reasoning.
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
1204 
// Let statement: bind the variable's generated value for the body, and
// register the binding with the arithmetic analyzer.
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
  CHECK(!var_map_.count(op->var.get()));  // single-assignment invariant
  if (op->var.type().is_handle()) {
    // In a non-restricted function, handle variables may alias other
    // buffers; remember them so alias metadata is suppressed for their
    // accesses.
    if (!is_restricted_) {
      alias_var_set_.insert(op->var.get());
    }
  }
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  this->VisitStmt(op->body);
}
1216 
// Statement sequence: emit the first statement, then the rest of the chain
// (the tail may be undefined at the end of a block chain).
void CodeGenLLVM::VisitStmt_(const Block* op) {
  this->VisitStmt(op->first);
  if (op->rest.defined()) {
    this->VisitStmt(op->rest);
  }
}
1223 
// Evaluate: generate the expression for its side effects; the resulting
// value is discarded.
void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
  MakeValue(op->value);
}
1227 
// Producer/consumer annotation carries no codegen semantics here; just
// emit the wrapped body.
void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
  this->VisitStmt(op->body);
}
1231 }  // namespace codegen
1232 }  // namespace tvm
1233 #endif  // TVM_LLVM_VERSION
1234