1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file codegen_llvm.h
22  * \brief Common base class for generating into LLVM IR
23  */
24 #ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
25 #define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
26 #ifdef TVM_LLVM_VERSION
27 
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"
40 
41 namespace tvm {
42 namespace codegen {
43 
44 using namespace ir;
45 
46 /*!
47  * \brief A base class to generate a LLVM.
48  */
49 class CodeGenLLVM :
50       public ExprFunctor<llvm::Value* (const Expr&)>,
51       public StmtFunctor<void(const Stmt&)> {
52  public:
53   /*!
54    * \brief Create new code generator based on target machine.
55    * \param tm The target machine
56    * \return The created llvm generator.
57    */
58   static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
59   /*!
60    * \brief Initialize the code generator with given context
61    * \param module_name The name of the module.
62    * \param tm Target machine model
63    * \param ctx The context.
64    * \param system_lib Whether to insert system library registration.
65    * \param dynamic_lookup Whether dynamically lookup runtime function
66    *                       or use the runtime function table passed by caller.
67    */
68   virtual void Init(const std::string& module_name,
69                     llvm::TargetMachine* tm,
70                     llvm::LLVMContext* ctx,
71                     bool system_lib,
72                     bool dynamic_lookup);
73   /*!
74    * \brief Compile and add function f to the current module.
75    * \param f The function to be added.
76    */
77   virtual void AddFunction(const LoweredFunc& f);
78   /*!
79    * \brief Add main function as the entry name
80    * \param entry_func_name The name of entry function to be added.
81    */
82   virtual void AddMainFunction(const std::string& entry_func_name);
83   /*!
84    * \brief Finish current pass of codegen, get the module.
85    * \return the created module.
86    */
87   virtual std::unique_ptr<llvm::Module> Finish();
88   /*!
89    * \brief Add mod to be linked with the generated module
90    * \param mod The module to be linked.
91    */
92   void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
93   /*!
94    * \brief Create Value for expression e
95    * \param e The expression to be created value for.
96    * \return created value.
97    */
MakeValue(const Expr & e)98   llvm::Value* MakeValue(const Expr& e) {
99     return VisitExpr(e);
100   }
101   // Short hande code to get a constant int 32
ConstInt32(int64_t value)102   llvm::Constant* ConstInt32(int64_t value) const {
103     return llvm::ConstantInt::getSigned(t_int32_, value);
104   }
105   // override codegen
106   llvm::Value* VisitExpr_(const Variable* op) override;
107   llvm::Value* VisitExpr_(const Cast* op) override;
108   llvm::Value* VisitExpr_(const IntImm* op) override;
109   llvm::Value* VisitExpr_(const UIntImm* op) override;
110   llvm::Value* VisitExpr_(const FloatImm* op) override;
111   llvm::Value* VisitExpr_(const StringImm* op) override;
112   llvm::Value* VisitExpr_(const Add* op) override;
113   llvm::Value* VisitExpr_(const Sub* op) override;
114   llvm::Value* VisitExpr_(const Mul* op) override;
115   llvm::Value* VisitExpr_(const Div* op) override;
116   llvm::Value* VisitExpr_(const Mod* op) override;
117   llvm::Value* VisitExpr_(const Min* op) override;
118   llvm::Value* VisitExpr_(const Max* op) override;
119   llvm::Value* VisitExpr_(const LT* op) override;
120   llvm::Value* VisitExpr_(const LE* op) override;
121   llvm::Value* VisitExpr_(const GT* op) override;
122   llvm::Value* VisitExpr_(const GE* op) override;
123   llvm::Value* VisitExpr_(const EQ* op) override;
124   llvm::Value* VisitExpr_(const NE* op) override;
125   llvm::Value* VisitExpr_(const And* op) override;
126   llvm::Value* VisitExpr_(const Or* op) override;
127   llvm::Value* VisitExpr_(const Not* op) override;
128   llvm::Value* VisitExpr_(const Select* op) override;
129   llvm::Value* VisitExpr_(const Let* op) override;
130   llvm::Value* VisitExpr_(const Load* op) override;
131   llvm::Value* VisitExpr_(const Call* op) override;
132   llvm::Value* VisitExpr_(const Ramp* op) override;
133   llvm::Value* VisitExpr_(const Shuffle* op) override;
134   llvm::Value* VisitExpr_(const Broadcast* op) override;
135   // stmt
136   void VisitStmt_(const Store* op) override;
137   void VisitStmt_(const For* op) override;
138   void VisitStmt_(const IfThenElse* op) override;
139   void VisitStmt_(const Allocate* op) override;
140   void VisitStmt_(const AttrStmt* op) override;
141   void VisitStmt_(const AssertStmt* op) override;
142   void VisitStmt_(const LetStmt* op) override;
143   void VisitStmt_(const Block* op) override;
144   void VisitStmt_(const Evaluate* op) override;
145   void VisitStmt_(const ProducerConsumer* op) override;
146 
147  protected:
148   /*! \brief The storage information */
149   struct StorageInfo {
150     /*! \brief The storage scope */
151     runtime::StorageScope scope;
152     /*! \brief The alignment of allocation */
153     int alignment{0};
154   };
155   /*!
156    * \brief Execute falloca at the beginning of the
157    *  currrent function and obtain its return value.
158    *
159    *  This is a helper function to make sure that
160    *  alloca always happen in the beginning of the function.
161    *
162    * \param falloca The allocation function to be executed.
163    * \tparam F The function to be executed.
164    * \return The result.
165    */
166   template<typename F>
WithFunctionEntry(F falloca)167   inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
168     llvm::BasicBlock* current = builder_->GetInsertBlock();
169     llvm::BasicBlock* entry = &(function_->getEntryBlock());
170     builder_->SetInsertPoint(entry, entry->begin());
171     llvm::AllocaInst* res = falloca();
172     builder_->SetInsertPoint(current);
173     return res;
174   }
175   // create intrinstic given call
176   virtual llvm::Value* CreateIntrinsic(const Call* op);
177   // create extern function call
178   virtual llvm::Value* CreateCallExtern(const Call* op);
179   // Get the corresponding thread index
180   virtual llvm::Value* GetThreadIndex(const IterVar& iv);
181   // Get the corresponding thread index
182   virtual llvm::Value* CreateStorageSync(const Call* op);
183   // apply optimization on the module.
184   virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
185   // Scalarize by iterating elements of e.
186   // f is a callback that takes index and v.
187   virtual void Scalarize(const Expr& e,
188                          std::function<void(int i, llvm::Value* v)> f);
189   // Initialize target
190   virtual void InitTarget(llvm::TargetMachine* tm);
191   // Add module startup function if needed.
AddStartupFunction()192   virtual void AddStartupFunction() {}
193   // apply optimization on the module.
194   virtual void Optimize();
195   // Get the maximim storage align bits of buffer pointer given storage scope.
196   virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
197   // Get correct address space depending on the backend
198   virtual unsigned GetGlobalAddressSpace();
199 
200   void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
201   // Create extern call
202   llvm::CallInst* CreateCallExtern(llvm::Type* ret,
203                                    const std::string& name,
204                                    const std::vector<llvm::Value*>& value);
205   /*!
206    * \param t The original type.
207    * \return LLVM type of t
208    */
209   llvm::Type* LLVMType(const Type& t) const;
210   // initialize the function state.
211   void InitFuncState();
212   // Get alignment given index.
213   void GetAlignment(
214       Type t, const Variable* buf_var, const Expr& index,
215       int* p_alignment, int* p_native_bits);
216   // Get constant string
217   llvm::Value* GetConstString(const std::string& str);
218   // do a scalarize call with f
219   llvm::Value* CreateScalarizedCall(
220       const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
221   // handle module import
222   void HandleImport(const std::string& code);
223   // cast operatpr
224   llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
225   // comparison op
226   llvm::Value* GetVarValue(const Variable* v) const;
227   llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
228   llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
229   llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
230   llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
231   llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
232   llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
233   llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
234   llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
235   llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
236   llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
237   // Vector concatenation.
238   llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
239   llvm::Value* CreateVecFlip(llvm::Value* vec);
240   llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
241   llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
242   // Create serial for
243   void CreateSerialFor(llvm::Value* begin,
244                        llvm::Value* end,
245                        llvm::Value* stride,
246                        const VarExpr& loop_var, const Stmt& body);
247   // add alias information.
248   void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
249   // The IRBuilder.
250   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
251   // The current function
252   llvm::Function* function_;
253   // Internal builder
254   std::unique_ptr<IRBuilder> builder_;
255   // The module to be returned;
256   std::unique_ptr<llvm::Module> module_;
257   std::unique_ptr<llvm::DataLayout> data_layout_;
258   // Internal metabuilder
259   std::unique_ptr<llvm::MDBuilder> md_builder_;
260   // llvm target machine
261   llvm::TargetMachine* target_machine_{nullptr};
262   // llvm context
263   llvm::LLVMContext* ctx_{nullptr};
264   // helpful data types
265   llvm::Type* t_void_{nullptr};
266   llvm::PointerType* t_void_p_{nullptr};
267   llvm::Type* t_int_{nullptr};
268   llvm::Type* t_char_{nullptr};
269   llvm::Type* t_int8_{nullptr};
270   llvm::Type* t_int16_{nullptr};
271   llvm::Type* t_int32_{nullptr};
272   llvm::Type* t_int64_{nullptr};
273   llvm::Type* t_float64_{nullptr};
274   // meta data
275   llvm::MDNode* md_very_likely_branch_{nullptr};
276   llvm::MDNode* md_tbaa_root_{nullptr};
277   llvm::MDNode* md_tbaa_alias_set_{nullptr};
278   // modules to be linked.
279   std::vector<std::unique_ptr<llvm::Module> > link_modules_;
280   /*! \brief native vector bits of current targetx*/
281   int native_vector_bits_{0};
282   /*! \brief the storage scope of allocation */
283   std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
284   // The definition of local variable.
285   std::unordered_map<const Variable*, llvm::Value*> var_map_;
286   // global strings
287   std::unordered_map<std::string, llvm::Constant*> str_map_;
288   // Whether current function is restricted
289   bool is_restricted_{true};
290   // The analyzer information
291   std::unique_ptr<arith::Analyzer> analyzer_;
292   // set of var that are not restricted(can alias)
293   std::unordered_set<const Variable*> alias_var_set_;
294   // set of volatile buffer.
295   std::unordered_set<const Variable*> volatile_buf_;
296   /*! \brief Helper struct for debug infos. */
297   struct DebugInfo {
298     std::unique_ptr<llvm::DIBuilder> di_builder_;
299     llvm::DICompileUnit* compilation_unit_{nullptr};
300     llvm::DIFile* file_{nullptr};
301   };
302   /*!
303    * \brief Create a new DebugInfo struct from the given Module that
304    *  initializes file and compilation_unit_ to TVM defaults.
305    */
306   static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
307 };
308 }  // namespace codegen
309 }  // namespace tvm
310 #endif  // LLVM_VERSION
311 #endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
312