1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/relay/backend/vm/compiler.h
22  * \brief A compiler from relay::Module to the VM byte code.
23  */
24 
25 #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
26 #define TVM_RELAY_BACKEND_VM_COMPILER_H_
27 
28 #include <tvm/relay/error.h>
29 #include <tvm/relay/expr_functor.h>
30 #include <tvm/relay/interpreter.h>
31 #include <tvm/logging.h>
32 #include <tvm/relay/transform.h>
33 #include <tvm/runtime/vm.h>
34 #include <iostream>
35 #include <memory>
36 #include <string>
37 #include <unordered_map>
38 #include <unordered_set>
39 #include <utility>
40 #include <vector>
41 #include "../../../runtime/vm/profiler/vm.h"
42 #include "../../../runtime/vm/naive_allocator.h"
43 #include "../../backend/compile_engine.h"
44 #include "../../pass/pass_util.h"
45 
46 namespace tvm {
47 namespace relay {
48 namespace vm {
49 
50 using namespace tvm::runtime;
51 using namespace tvm::runtime::vm;
52 using namespace relay::transform;
53 
54 template <typename T, typename U>
55 using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
56 using TagMap = NodeMap<tvm::relay::Constructor, Index>;
57 using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
58 using GlobalMap = NodeMap<GlobalVar, Index>;
59 using ConstMap = NodeMap<Constant, Index>;
60 using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
61 using TargetsMap = Map<tvm::Integer, tvm::Target>;
62 
63 struct VMCompilerContext {
64   // The module context for the compilation
65   Module module;
66   // Error reporter
67   ErrorReporter err_reporter;
68   // Map from a unique integer to ADT constructor tag
69   TagNameMap tag_index_map;
70   // Map from ADT constructor tag to a unique integer
71   TagMap tag_map;
72   // Map from global var to a unique integer
73   GlobalMap global_map;
74   // List of constants
75   std::vector<NDArray> constants;
76   // List of cached functions
77   std::vector<CachedFunc> cached_funcs;
78   // The functions that have been lowered.
79   std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
80 };
81 
82 
83 class VMCompiler : public runtime::ModuleNode {
84  public:
~VMCompiler()85   virtual ~VMCompiler() {}
86 
87   virtual PackedFunc GetFunction(const std::string& name,
88                                  const ObjectPtr<Object>& sptr_to_self);
89 
type_key()90   const char* type_key() const {
91     return "VMCompiler";
92   }
93 
InitVM()94   void InitVM() {
95     exec_ = make_object<Executable>();
96   }
97 
98   /*!
99    * \brief Set the parameters
100    *
101    * \param name name of parameter
102    * \param data_in input DLTensor
103    */
104   void SetParam(const std::string& name, runtime::NDArray data_in);
105 
106   /*!
107    * \brief Compile functions in a Module
108    *
109    * \param mod Relay Module
110    * \param targets For heterogeneous compilation, it is a dictionary indicating context
111                     to target mapping. For homogeneous compilation, it is a build target.
112    * \param target_host Host compilation target, if target is device.
113    */
114   void Compile(Module mod,
115                const TargetsMap& targets,
116                const tvm::Target& target_host);
117 
118  protected:
119   /*!
120    * \brief Bind params to function by using name
121    * \param func Relay function
122    * \param params params dict
123    * \return relay::Function
124    */
125   relay::Function BindParamsByName(
126       relay::Function func,
127       const std::unordered_map<std::string, runtime::NDArray>& params);
128 
129   Module OptimizeModule(const Module& mod, const TargetsMap& targets);
130 
131   void PopulateGlobalMap();
132 
133   void LibraryCodegen();
134 
135  protected:
136   /*! \brief Target devices. */
137   TargetsMap targets_;
138   /*! \brief Target host device. */
139   tvm::Target target_host_;
140   /*! \brief Global shared meta data */
141   VMCompilerContext context_;
142   /*! \brief Compiled executable. */
143   ObjectPtr<Executable> exec_;
144   /*! \brief parameters */
145   std::unordered_map<std::string, runtime::NDArray> params_;
146 };
147 
148 }  // namespace vm
149 }  // namespace relay
150 }  // namespace tvm
151 
152 #endif  // TVM_RELAY_BACKEND_VM_COMPILER_H_
153