/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_llvm.h
 * \brief Common base class for generating into LLVM IR
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

using namespace ir;

/*!
 * \brief A base class for generating LLVM IR from TVM expressions and statements.
 *
 * Expressions are lowered through ExprFunctor (each visit yields an
 * llvm::Value*), statements through StmtFunctor (each visit emits
 * instructions via builder_).
 */
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
 public:
  /*!
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
   * \param tm Target machine model
   * \param ctx The context.
   * \param system_lib Whether to insert system library registration.
   * \param dynamic_lookup Whether to look up runtime functions dynamically
   *        or use the runtime function table passed by the caller.
   */
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
  virtual void AddFunction(const LoweredFunc& f);
  /*!
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
  virtual void AddMainFunction(const std::string& entry_func_name);
  /*!
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
  virtual std::unique_ptr<llvm::Module> Finish();
  /*!
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
    return VisitExpr(e);
  }
  /*!
   * \brief Shorthand to get a signed 32-bit integer constant.
   * \param value The constant value (interpreted as signed).
   * \return The llvm constant of type t_int32_.
   */
  llvm::Constant* ConstInt32(int64_t value) const {
    return llvm::ConstantInt::getSigned(t_int32_, value);
  }
  // override codegen
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
  llvm::Value* VisitExpr_(const Shuffle* op) override;
  llvm::Value* VisitExpr_(const Broadcast* op) override;
  // stmt
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;

 protected:
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
    runtime::StorageScope scope;
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
  /*!
   * \brief Execute falloca at the beginning of the
   *  current function and obtain its return value.
   *
   *  This is a helper function to make sure that
   *  alloca always happens at the beginning of the function.
   *  The builder's insertion point is restored afterwards.
   *
   * \param falloca The allocation function to be executed.
   * \tparam F The function to be executed.
   * \return The result.
   */
  template<typename F>
  inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
    // Save the exact insertion point (block + position) rather than only the
    // block: re-positioning with SetInsertPoint(block) would jump to the end
    // of the block, which is wrong if the builder was mid-block.
    llvm::IRBuilderBase::InsertPoint saved = builder_->saveIP();
    llvm::BasicBlock* entry = &(function_->getEntryBlock());
    builder_->SetInsertPoint(entry, entry->begin());
    llvm::AllocaInst* res = falloca();
    builder_->restoreIP(saved);
    return res;
  }
  // create intrinsic given call
  virtual llvm::Value* CreateIntrinsic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Create the storage synchronization for the given call
  virtual llvm::Value* CreateStorageSync(const Call* op);
  // Hook to initialize/customize the pass manager builder used for optimization.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximum storage alignment bits of a buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

  // Internal implementation for adding function f to the module.
  // NOTE(review): ret_void presumably selects a void return type — confirm in the .cc.
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
  /*!
   * \brief Get the LLVM type corresponding to a TVM type.
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index; results are written to p_alignment and p_native_bits.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
  // handle module import
  void HandleImport(const std::string& code);
  // cast operator
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
  // Get the value associated with variable v.
  llvm::Value* GetVarValue(const Variable* v) const;
  // comparison op
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  // arithmetic op
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  // buffer address computation
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
  // Vector manipulation: slice, flip, concatenation and padding.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_{nullptr};
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
  // The data layout
  std::unique_ptr<llvm::DataLayout> data_layout_;
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
  llvm::PointerType* t_void_p_{nullptr};
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
  // meta data
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
  /*! \brief native vector bits of current target */
  int native_vector_bits_{0};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
  // Whether current function is restricted
  bool is_restricted_{true};
  // The analyzer information
  std::unique_ptr<arith::Analyzer> analyzer_;
  // set of var that are not restricted(can alias)
  std::unordered_set<const Variable*> alias_var_set_;
  // set of volatile buffer.
  std::unordered_set<const Variable*> volatile_buf_;
  /*! \brief Helper struct for debug infos. */
  struct DebugInfo {
    std::unique_ptr<llvm::DIBuilder> di_builder_;
    llvm::DICompileUnit* compilation_unit_{nullptr};
    llvm::DIFile* file_{nullptr};
  };
  /*!
   * \brief Create a new DebugInfo struct from the given Module that
   *  initializes file and compilation_unit_ to TVM defaults.
   */
  static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_