/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 
/*!
 * \file tvm/build_module.h
 * \brief Functions for compiling ops.
 */
24 #ifndef TVM_BUILD_MODULE_H_
25 #define TVM_BUILD_MODULE_H_
26 
27 #include <string>
28 #include <vector>
29 #include <utility>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include "runtime/packed_func.h"
33 #include "schedule_pass.h"
34 #include "lowered_func.h"
35 
36 namespace tvm {
37 
38 /*!
39 * \brief Container for target device information.
40 *   Use target::llvm, target::cuda etc functions instead of constructing directly.
41 */
42 class TargetNode : public Node {
43  public:
44   /*! \brief The name of the target device */
45   std::string target_name;
46   /*! \brief The name of the target device */
47   std::string device_name;
48   /*! \brief The type of the target device */
49   int device_type;
50   /*! \brief The maximum threads that a schedule should use for this device */
51   int max_num_threads = 1;
52   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
53   int thread_warp_size = 1;
54   /*! \brief Keys for this target */
55   Array<Expr> keys_array;
56   /*! \brief Options for this target */
57   Array<Expr> options_array;
58   /*! \brief Collection of imported libs */
59   Array<Expr> libs_array;
60 
61   /*! \return the full device string to pass to codegen::Build */
62   TVM_DLL const std::string& str() const;
63 
VisitAttrs(AttrVisitor * v)64   void VisitAttrs(AttrVisitor* v) {
65     v->Visit("target_name", &target_name);
66     v->Visit("device_name", &device_name);
67     v->Visit("device_type", &device_type);
68     v->Visit("max_num_threads", &max_num_threads);
69     v->Visit("thread_warp_size", &thread_warp_size);
70     v->Visit("keys_array", &keys_array);
71     v->Visit("options_array", &options_array);
72     v->Visit("libs_array", &libs_array);
73   }
74 
75   /*! \brief Get the keys for this target as a vector of string */
76   TVM_DLL std::vector<std::string> keys() const;
77 
78   /*! \brief Get the options for this target as a vector of string */
79   TVM_DLL std::vector<std::string> options() const;
80 
81   /*! \brief Get the keys for this target as an unordered_set of string */
82   TVM_DLL std::unordered_set<std::string> libs() const;
83 
84   static constexpr const char* _type_key = "Target";
85   TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
86 
87  private:
88   /*! \brief Internal string repr. */
89   mutable std::string str_repr_;
90 };
91 
92 /*! \brief reference cpass to the target. */
93 class Target : public NodeRef {
94  public:
Target()95   Target() {}
Target(ObjectPtr<Object> n)96   explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
97   /*!
98   * \brief Create a Target given a string
99   * \param target_str the string to parse
100   */
101   TVM_DLL static Target Create(const std::string& target_str);
102   /*!
103    * \brief Get the current target context from thread local storage.
104    * \param allow_not_defined If the context stack is empty and this is set to true, an
105    *   undefined Target will be returned. Otherwise, an empty context stack will cause a
106    *   runtime error.
107    * \return The target that is the current context. The target may not be defined if
108    * allow_not_defined is true.
109    */
110   TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
111 
112   const TargetNode* operator->() const {
113       return static_cast<const TargetNode*>(get());
114   }
115 
116   using ContainerType = TargetNode;
117   class Internal;
118  private:
119   // enable with syntax.
120   friend class Internal;
121   friend class With<Target>;
122   /*!
123    * \brief Push a new target context onto the thread local stack.
124    *  The Target on top of the stack is used to determine which
125    *  specialization to use when invoking a GenericFunc.
126    */
127   TVM_DLL void EnterWithScope();
128   /*!
129    * \brief Pop a target off the thread local context stack,
130    *  restoring the previous target as the current context.
131    */
132   TVM_DLL void ExitWithScope();
133 };
134 
135 /*! \brief This namespace provides functions to construct Target instances */
136 namespace target {
137 /*! \return A target for LLVM */
138 TVM_DLL Target llvm(const std::vector<std::string>& options =
139                    std::vector<std::string>());
140 
141 /*! \return A target for CUDA */
142 TVM_DLL Target cuda(const std::vector<std::string>& options =
143                    std::vector<std::string>());
144 
145 /*! \return A target for ROCm */
146 TVM_DLL Target rocm(const std::vector<std::string>& options =
147                    std::vector<std::string>());
148 
149 /*! \return A target for OpenCL */
150 TVM_DLL Target opencl(const std::vector<std::string>& options =
151                      std::vector<std::string>());
152 
153 /*! \return A target for Metal */
154 TVM_DLL Target metal(const std::vector<std::string>& options =
155                     std::vector<std::string>());
156 
157 /*! \return A target for rasp */
158 TVM_DLL Target rasp(const std::vector<std::string>& options =
159                    std::vector<std::string>());
160 
161 /*! \return A target for Mali */
162 TVM_DLL Target mali(const std::vector<std::string>& options =
163                    std::vector<std::string>());
164 
165 /*! \return A target for Intel Graphics */
166 TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
167                              std::vector<std::string>());
168 
169 /*! \return A target for stackvm */
170 TVM_DLL Target stackvm(const std::vector<std::string>& options =
171                       std::vector<std::string>());
172 
173 }  // namespace target
174 
175 /*!
176  * \brief Container for build configuration options
177  */
178 class BuildConfigNode : public Node {
179  public:
180   /*!
181    * \brief The data alignment to use when constructing buffers. If this is set to
182    * -1, then TVM's internal default will be used
183    */
184   int data_alignment = -1;
185   /*!
186    * \brief The offset factor to use when constructing buffers. If this is set to
187    * 0, then the offset field is not used.
188    */
189   int offset_factor = 0;
190 
191   /*!
192    * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
193    * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
194    */
195   int double_buffer_split_loop = 1;
196   /*! \brief Threshold of number of steps in the loop to be automatically unrolled */
197   int auto_unroll_max_step = 0;
198   /*! \brief The maximum nested level of loops that can be automatically unrolled */
199   int auto_unroll_max_depth = 8;
200   /*! \brief The maximum extent of loop that will be unrolled */
201   int auto_unroll_max_extent = 0;
202   /*!
203    * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
204    * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
205    */
206   bool unroll_explicit = true;
207 
208   /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
209   bool restricted_func = true;
210 
211   /*! \brief Whether to detect global barrier */
212   bool detect_global_barrier = false;
213 
214   /*! \brief Whether to partition const loop */
215   bool partition_const_loop = false;
216 
217   /*! \brief Whether to dump the IR of each pass (only when building from python) */
218   std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
219 
220   /*! \brief Whether to dump the IR of each pass (only when building from python) */
221   bool dump_pass_ir = false;
222 
223   /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
224   bool instrument_bound_checkers = false;
225 
226   /*! \brief Whether to disable select rewriting. */
227   bool disable_select_rewriting = false;
228 
229   /*! \brief Whether to disable loop vectorization. */
230   bool disable_vectorize = false;
231 
232   /*! \brief Whether to disable assert stmt generation. */
233   bool disable_assert = false;
234 
VisitAttrs(AttrVisitor * v)235   void VisitAttrs(AttrVisitor* v) {
236     v->Visit("data_alignment", &data_alignment);
237     v->Visit("offset_factor", &offset_factor);
238     v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
239     v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
240     v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
241     v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
242     v->Visit("unroll_explicit", &unroll_explicit);
243     v->Visit("restricted_func", &restricted_func);
244     v->Visit("detect_global_barrier", &detect_global_barrier);
245     v->Visit("partition_const_loop", &partition_const_loop);
246     v->Visit("dump_pass_ir", &dump_pass_ir);
247     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
248     v->Visit("disable_select_rewriting", &disable_select_rewriting);
249     v->Visit("disable_vectorize", &disable_vectorize);
250     v->Visit("disable_assert", &disable_assert);
251   }
252 
253   static constexpr const char* _type_key = "BuildConfig";
254   TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
255 };
256 
257 /*!
258  * \brief Build configuration for compilations.
259  */
260 class BuildConfig : public ::tvm::NodeRef {
261  public:
BuildConfig()262   BuildConfig() {}
BuildConfig(ObjectPtr<Object> n)263   explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
264   const BuildConfigNode* operator->() const {
265     return static_cast<const BuildConfigNode*>(get());
266   }
267   BuildConfigNode* operator->() {
268     return static_cast<BuildConfigNode*>(get_mutable());
269   }
270   /*!
271    * \brief Construct a BuildConfig containing a empty build config node.
272    * \return The new BuildConfig
273    */
274   TVM_DLL static BuildConfig Create();
275   /*!
276    * \brief Get the current BuildConfig context from thread local storage, or a default
277    * configuration if a BuildConfig scope has not been entered.
278    * \return The configuration that is the current context.
279    */
280   TVM_DLL static BuildConfig Current();
281 
282   using ContainerType = BuildConfigNode;
283   class Internal;
284 
285  private:
286   // Enable with syntax.
287   friend class With<BuildConfig>;
288   /*!
289    * \brief Push a new BuildConfig context onto the thread local stack.
290    */
291   TVM_DLL void EnterWithScope();
292 
293   /*!
294    * \brief Pop a build config off the thread local context stack,
295    * restoring the previous configuration as the current context.
296    */
297   TVM_DLL void ExitWithScope();
298 };
299 
300 /*!
301 * \brief Build a LoweredFunc given a schedule, args and binds
302 * \param sch The schedule to lower.
303 * \param args The arguments to the function.
304 * \param name The name of the lowered function.
305 * \param binds Buffer assignments.
306 * \param config The build configuration.
307 * \return The lowered function.
308 */
309 TVM_DLL Array<LoweredFunc> lower(Schedule sch,
310                                  const Array<Tensor>& args,
311                                  const std::string& name,
312                                  const std::unordered_map<Tensor, Buffer>& binds,
313                                  const BuildConfig& config);
314 /*!
315 * \brief Split host/device function and running necessary pass before build
316 * \param funcs The functions to be built.
317 * \param target The target device to build for.
318 * \param target_host The target for building host code. To use the default, pass Target()
319 * \param config The build configuration.
320 * \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
321           second is device function array
322 */
323 TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
324                                                         const Target& target,
325                                                         const Target& target_host,
326                                                         const BuildConfig& config);
327 
328 /*!
329 * \brief Build a device and host module for a specific target from an array of lowered functions.
330 * \param funcs The functions to be built.
331 * \param target The target device to build for.
332 * \param target_host The target for building host code. To use the default, pass Target()
333 * \param config The build configuration.
334 * \return The built module.
335 */
336 TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
337                               const Target& target,
338                               const Target& target_host,
339                               const BuildConfig& config);
340 
341 /*!
342  * \brief Build a device and host module for a specific target from a map
343  * contains target to a list of lowered functions pairs. This function is used
344  * for heterogeneous build.
345  * \param input The map contains target to a list of lowered functions pairs.
346  * \param target_host The target for building host code. To use the default,
347  *        pass Target().
348  * \param config The build configuration.
349  * \return The built module that contains code for different processors.
350  */
351 TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
352                               const Target& target_host,
353                               const BuildConfig& config);
354 
355 /*!
356  * \brief Build a device and host module for a specific target from a map
357  * contains target to a list of lowered functions pairs. This function is used
358  * for heterogeneous build.
359  * \param input The map contains target string to a list of lowered functions
360  *        pairs.
361  * \param target_host The target for building host code. To use the default,
362  *        pass Target().
363  * \param config The build configuration.
364  * \return The built module that contains code for different processors.
365  */
366 TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input,
367                               const Target& target_host,
368                               const BuildConfig& config);
369 
370 class GenericFuncNode;
371 
372 /*!
373  * \brief Generic function that can be specialized on a per-target basis.
374  */
375 class GenericFunc : public NodeRef {
376  public:
GenericFunc()377   GenericFunc() {}
GenericFunc(ObjectPtr<Object> n)378   explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}
379 
380   /*!
381    * \brief Set the default function implementaiton.
382    * \param value The default function
383    * \param allow_override If true, this call may override a previously registered function. If
384    * false, an error will be logged if the call would override a previously registered function.
385    * \return reference to self.
386    */
387   TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
388                                    bool allow_override = false);
389   /*!
390    * \brief Register a specialized function
391    * \param tags The tags for this specialization
392    * \param value The specialized function
393    * \param allow_override If true, this call may override previously registered tags. If false,
394    * an error will be logged if the call would override previously registered tags.
395    * \return reference to self.
396    */
397   TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
398                                      const runtime::PackedFunc value,
399                                      bool allow_override = false);
400   /*!
401    * \brief Call generic function by directly passing in unpacked format.
402    * \param args Arguments to be passed.
403    * \tparam Args arguments to be passed.
404    *
405    * \code
406    *   // Example code on how to call generic function
407    *   void CallGeneirc(GenericFunc f) {
408    *     // call like normal functions by pass in arguments
409    *     // return value is automatically converted back
410    *     int rvalue = f(1, 2.0);
411    *   }
412    * \endcode
413    */
414   template<typename... Args>
415   inline runtime::TVMRetValue operator()(Args&& ...args) const;
416   /*!
417    * \brief Invoke the relevant function for the current target context, set by set_target_context.
418    * Arguments are passed in packed format.
419    * \param args The arguments to pass to the function.
420    * \param ret The return value
421    */
422   TVM_DLL void CallPacked(runtime::TVMArgs args,
423                           runtime::TVMRetValue* ret) const;
424 
425   /*!
426    * \brief Find or register the GenericFunc instance corresponding to the give name
427    * \param name The name of the registered GenericFunc
428    * \return The GenericFunc instance
429    */
430   TVM_DLL static GenericFunc Get(const std::string& name);
431 
432   /*!
433    * \brief Add a GenericFunc instance to the registry
434    * \param func The GenericFunc instance
435    * \param name The name of the registered GenericFunc
436    */
437   TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
438 
439   /*!
440    * \brief access the internal node container
441    * \return the pointer to the internal node container
442    */
443   inline GenericFuncNode* operator->();
444 
445   // declare container type
446   using ContainerType = GenericFuncNode;
447 
448   // Internal class.
449   struct Manager;
450 
451  private:
452   friend struct Manager;
453 };
454 
455 template<typename... Args>
operator()456 inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
457   const int kNumArgs = sizeof...(Args);
458   const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
459   TVMValue values[kArraySize];
460   int type_codes[kArraySize];
461   runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
462     std::forward<Args>(args)...);
463   runtime::TVMRetValue rv;
464   CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
465   return rv;
466 }
467 
468 /*!
469  * \brief Represents a generic function that can be specialized on a per-target basis.
470  */
471 class GenericFuncNode : public Node {
472  public:
473   /*! \brief name of the function */
474   std::string name_;
475   /* \brief the generic builder */
476   runtime::PackedFunc generic_func_;
477   /* \brief map from keys to registered functions */
478   std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
479 
VisitAttrs(AttrVisitor * v)480   void VisitAttrs(AttrVisitor* v) {}
481 
482   static constexpr const char* _type_key = "GenericFunc";
483   TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
484 };
485 
486 inline GenericFuncNode* GenericFunc::operator->() {
487   return static_cast<GenericFuncNode*>(get_mutable());
488 }
489 
/*!
 * \def TVM_GENERIC_FUNC_REG_VAR_DEF
 * \brief Helper for TVM_REGISTER_GENERIC_FUNC: the declaration prefix
 *  (storage class, type and name stem) of the static GenericFunc reference
 *  that a registration statement defines.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF                            \
  static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 * of the corresponding function.
 *
 * The macro expands to the start of a static variable definition whose
 * initializer begins with ::tvm::GenericFunc::Get(#name); __COUNTER__ is
 * concatenated onto the variable name to keep each expansion distinct.
 *
 * The variable is a non-const GenericFunc reference while Get() returns by
 * value, so the macro has to be followed by a chained call that returns
 * GenericFunc& (set_default or register_func).
 *
 * NOTE(review): that reference is bound to a member-call result on a
 * temporary, so it does not extend the temporary's lifetime; the variable
 * is marked unused and should not be read.
 *
 * \param name The name of the function
 */
#define TVM_REGISTER_GENERIC_FUNC(name)                           \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) =     \
      ::tvm::GenericFunc::Get(#name)
503 
504 
505 }  // namespace tvm
506 
507 #endif  // TVM_BUILD_MODULE_H_
508