1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/build_module.h
22 * \brief Functions for compiling ops.
23 */
24 #ifndef TVM_BUILD_MODULE_H_
25 #define TVM_BUILD_MODULE_H_
26
27 #include <string>
28 #include <vector>
29 #include <utility>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include "runtime/packed_func.h"
33 #include "schedule_pass.h"
34 #include "lowered_func.h"
35
36 namespace tvm {
37
38 /*!
39 * \brief Container for target device information.
40 * Use target::llvm, target::cuda etc functions instead of constructing directly.
41 */
42 class TargetNode : public Node {
43 public:
44 /*! \brief The name of the target device */
45 std::string target_name;
46 /*! \brief The name of the target device */
47 std::string device_name;
48 /*! \brief The type of the target device */
49 int device_type;
50 /*! \brief The maximum threads that a schedule should use for this device */
51 int max_num_threads = 1;
52 /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
53 int thread_warp_size = 1;
54 /*! \brief Keys for this target */
55 Array<Expr> keys_array;
56 /*! \brief Options for this target */
57 Array<Expr> options_array;
58 /*! \brief Collection of imported libs */
59 Array<Expr> libs_array;
60
61 /*! \return the full device string to pass to codegen::Build */
62 TVM_DLL const std::string& str() const;
63
VisitAttrs(AttrVisitor * v)64 void VisitAttrs(AttrVisitor* v) {
65 v->Visit("target_name", &target_name);
66 v->Visit("device_name", &device_name);
67 v->Visit("device_type", &device_type);
68 v->Visit("max_num_threads", &max_num_threads);
69 v->Visit("thread_warp_size", &thread_warp_size);
70 v->Visit("keys_array", &keys_array);
71 v->Visit("options_array", &options_array);
72 v->Visit("libs_array", &libs_array);
73 }
74
75 /*! \brief Get the keys for this target as a vector of string */
76 TVM_DLL std::vector<std::string> keys() const;
77
78 /*! \brief Get the options for this target as a vector of string */
79 TVM_DLL std::vector<std::string> options() const;
80
81 /*! \brief Get the keys for this target as an unordered_set of string */
82 TVM_DLL std::unordered_set<std::string> libs() const;
83
84 static constexpr const char* _type_key = "Target";
85 TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
86
87 private:
88 /*! \brief Internal string repr. */
89 mutable std::string str_repr_;
90 };
91
92 /*! \brief reference cpass to the target. */
93 class Target : public NodeRef {
94 public:
Target()95 Target() {}
Target(ObjectPtr<Object> n)96 explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
97 /*!
98 * \brief Create a Target given a string
99 * \param target_str the string to parse
100 */
101 TVM_DLL static Target Create(const std::string& target_str);
102 /*!
103 * \brief Get the current target context from thread local storage.
104 * \param allow_not_defined If the context stack is empty and this is set to true, an
105 * undefined Target will be returned. Otherwise, an empty context stack will cause a
106 * runtime error.
107 * \return The target that is the current context. The target may not be defined if
108 * allow_not_defined is true.
109 */
110 TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
111
112 const TargetNode* operator->() const {
113 return static_cast<const TargetNode*>(get());
114 }
115
116 using ContainerType = TargetNode;
117 class Internal;
118 private:
119 // enable with syntax.
120 friend class Internal;
121 friend class With<Target>;
122 /*!
123 * \brief Push a new target context onto the thread local stack.
124 * The Target on top of the stack is used to determine which
125 * specialization to use when invoking a GenericFunc.
126 */
127 TVM_DLL void EnterWithScope();
128 /*!
129 * \brief Pop a target off the thread local context stack,
130 * restoring the previous target as the current context.
131 */
132 TVM_DLL void ExitWithScope();
133 };
134
135 /*! \brief This namespace provides functions to construct Target instances */
136 namespace target {
137 /*! \return A target for LLVM */
138 TVM_DLL Target llvm(const std::vector<std::string>& options =
139 std::vector<std::string>());
140
141 /*! \return A target for CUDA */
142 TVM_DLL Target cuda(const std::vector<std::string>& options =
143 std::vector<std::string>());
144
145 /*! \return A target for ROCm */
146 TVM_DLL Target rocm(const std::vector<std::string>& options =
147 std::vector<std::string>());
148
149 /*! \return A target for OpenCL */
150 TVM_DLL Target opencl(const std::vector<std::string>& options =
151 std::vector<std::string>());
152
153 /*! \return A target for Metal */
154 TVM_DLL Target metal(const std::vector<std::string>& options =
155 std::vector<std::string>());
156
157 /*! \return A target for rasp */
158 TVM_DLL Target rasp(const std::vector<std::string>& options =
159 std::vector<std::string>());
160
161 /*! \return A target for Mali */
162 TVM_DLL Target mali(const std::vector<std::string>& options =
163 std::vector<std::string>());
164
165 /*! \return A target for Intel Graphics */
166 TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
167 std::vector<std::string>());
168
169 /*! \return A target for stackvm */
170 TVM_DLL Target stackvm(const std::vector<std::string>& options =
171 std::vector<std::string>());
172
173 } // namespace target
174
175 /*!
176 * \brief Container for build configuration options
177 */
178 class BuildConfigNode : public Node {
179 public:
180 /*!
181 * \brief The data alignment to use when constructing buffers. If this is set to
182 * -1, then TVM's internal default will be used
183 */
184 int data_alignment = -1;
185 /*!
186 * \brief The offset factor to use when constructing buffers. If this is set to
187 * 0, then the offset field is not used.
188 */
189 int offset_factor = 0;
190
191 /*!
192 * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
193 * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
194 */
195 int double_buffer_split_loop = 1;
196 /*! \brief Threshold of number of steps in the loop to be automatically unrolled */
197 int auto_unroll_max_step = 0;
198 /*! \brief The maximum nested level of loops that can be automatically unrolled */
199 int auto_unroll_max_depth = 8;
200 /*! \brief The maximum extent of loop that will be unrolled */
201 int auto_unroll_max_extent = 0;
202 /*!
203 * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
204 * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
205 */
206 bool unroll_explicit = true;
207
208 /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
209 bool restricted_func = true;
210
211 /*! \brief Whether to detect global barrier */
212 bool detect_global_barrier = false;
213
214 /*! \brief Whether to partition const loop */
215 bool partition_const_loop = false;
216
217 /*! \brief Whether to dump the IR of each pass (only when building from python) */
218 std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
219
220 /*! \brief Whether to dump the IR of each pass (only when building from python) */
221 bool dump_pass_ir = false;
222
223 /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
224 bool instrument_bound_checkers = false;
225
226 /*! \brief Whether to disable select rewriting. */
227 bool disable_select_rewriting = false;
228
229 /*! \brief Whether to disable loop vectorization. */
230 bool disable_vectorize = false;
231
232 /*! \brief Whether to disable assert stmt generation. */
233 bool disable_assert = false;
234
VisitAttrs(AttrVisitor * v)235 void VisitAttrs(AttrVisitor* v) {
236 v->Visit("data_alignment", &data_alignment);
237 v->Visit("offset_factor", &offset_factor);
238 v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
239 v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
240 v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
241 v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
242 v->Visit("unroll_explicit", &unroll_explicit);
243 v->Visit("restricted_func", &restricted_func);
244 v->Visit("detect_global_barrier", &detect_global_barrier);
245 v->Visit("partition_const_loop", &partition_const_loop);
246 v->Visit("dump_pass_ir", &dump_pass_ir);
247 v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
248 v->Visit("disable_select_rewriting", &disable_select_rewriting);
249 v->Visit("disable_vectorize", &disable_vectorize);
250 v->Visit("disable_assert", &disable_assert);
251 }
252
253 static constexpr const char* _type_key = "BuildConfig";
254 TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
255 };
256
257 /*!
258 * \brief Build configuration for compilations.
259 */
260 class BuildConfig : public ::tvm::NodeRef {
261 public:
BuildConfig()262 BuildConfig() {}
BuildConfig(ObjectPtr<Object> n)263 explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
264 const BuildConfigNode* operator->() const {
265 return static_cast<const BuildConfigNode*>(get());
266 }
267 BuildConfigNode* operator->() {
268 return static_cast<BuildConfigNode*>(get_mutable());
269 }
270 /*!
271 * \brief Construct a BuildConfig containing a empty build config node.
272 * \return The new BuildConfig
273 */
274 TVM_DLL static BuildConfig Create();
275 /*!
276 * \brief Get the current BuildConfig context from thread local storage, or a default
277 * configuration if a BuildConfig scope has not been entered.
278 * \return The configuration that is the current context.
279 */
280 TVM_DLL static BuildConfig Current();
281
282 using ContainerType = BuildConfigNode;
283 class Internal;
284
285 private:
286 // Enable with syntax.
287 friend class With<BuildConfig>;
288 /*!
289 * \brief Push a new BuildConfig context onto the thread local stack.
290 */
291 TVM_DLL void EnterWithScope();
292
293 /*!
294 * \brief Pop a build config off the thread local context stack,
295 * restoring the previous configuration as the current context.
296 */
297 TVM_DLL void ExitWithScope();
298 };
299
300 /*!
301 * \brief Build a LoweredFunc given a schedule, args and binds
302 * \param sch The schedule to lower.
303 * \param args The arguments to the function.
304 * \param name The name of the lowered function.
305 * \param binds Buffer assignments.
306 * \param config The build configuration.
307 * \return The lowered function.
308 */
309 TVM_DLL Array<LoweredFunc> lower(Schedule sch,
310 const Array<Tensor>& args,
311 const std::string& name,
312 const std::unordered_map<Tensor, Buffer>& binds,
313 const BuildConfig& config);
314 /*!
315 * \brief Split host/device function and running necessary pass before build
316 * \param funcs The functions to be built.
317 * \param target The target device to build for.
318 * \param target_host The target for building host code. To use the default, pass Target()
319 * \param config The build configuration.
320 * \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
321 second is device function array
322 */
323 TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
324 const Target& target,
325 const Target& target_host,
326 const BuildConfig& config);
327
328 /*!
329 * \brief Build a device and host module for a specific target from an array of lowered functions.
330 * \param funcs The functions to be built.
331 * \param target The target device to build for.
332 * \param target_host The target for building host code. To use the default, pass Target()
333 * \param config The build configuration.
334 * \return The built module.
335 */
336 TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
337 const Target& target,
338 const Target& target_host,
339 const BuildConfig& config);
340
341 /*!
342 * \brief Build a device and host module for a specific target from a map
343 * contains target to a list of lowered functions pairs. This function is used
344 * for heterogeneous build.
345 * \param input The map contains target to a list of lowered functions pairs.
346 * \param target_host The target for building host code. To use the default,
347 * pass Target().
348 * \param config The build configuration.
349 * \return The built module that contains code for different processors.
350 */
351 TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
352 const Target& target_host,
353 const BuildConfig& config);
354
355 /*!
356 * \brief Build a device and host module for a specific target from a map
357 * contains target to a list of lowered functions pairs. This function is used
358 * for heterogeneous build.
359 * \param input The map contains target string to a list of lowered functions
360 * pairs.
361 * \param target_host The target for building host code. To use the default,
362 * pass Target().
363 * \param config The build configuration.
364 * \return The built module that contains code for different processors.
365 */
366 TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input,
367 const Target& target_host,
368 const BuildConfig& config);
369
370 class GenericFuncNode;
371
372 /*!
373 * \brief Generic function that can be specialized on a per-target basis.
374 */
375 class GenericFunc : public NodeRef {
376 public:
GenericFunc()377 GenericFunc() {}
GenericFunc(ObjectPtr<Object> n)378 explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}
379
380 /*!
381 * \brief Set the default function implementaiton.
382 * \param value The default function
383 * \param allow_override If true, this call may override a previously registered function. If
384 * false, an error will be logged if the call would override a previously registered function.
385 * \return reference to self.
386 */
387 TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
388 bool allow_override = false);
389 /*!
390 * \brief Register a specialized function
391 * \param tags The tags for this specialization
392 * \param value The specialized function
393 * \param allow_override If true, this call may override previously registered tags. If false,
394 * an error will be logged if the call would override previously registered tags.
395 * \return reference to self.
396 */
397 TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
398 const runtime::PackedFunc value,
399 bool allow_override = false);
400 /*!
401 * \brief Call generic function by directly passing in unpacked format.
402 * \param args Arguments to be passed.
403 * \tparam Args arguments to be passed.
404 *
405 * \code
406 * // Example code on how to call generic function
407 * void CallGeneirc(GenericFunc f) {
408 * // call like normal functions by pass in arguments
409 * // return value is automatically converted back
410 * int rvalue = f(1, 2.0);
411 * }
412 * \endcode
413 */
414 template<typename... Args>
415 inline runtime::TVMRetValue operator()(Args&& ...args) const;
416 /*!
417 * \brief Invoke the relevant function for the current target context, set by set_target_context.
418 * Arguments are passed in packed format.
419 * \param args The arguments to pass to the function.
420 * \param ret The return value
421 */
422 TVM_DLL void CallPacked(runtime::TVMArgs args,
423 runtime::TVMRetValue* ret) const;
424
425 /*!
426 * \brief Find or register the GenericFunc instance corresponding to the give name
427 * \param name The name of the registered GenericFunc
428 * \return The GenericFunc instance
429 */
430 TVM_DLL static GenericFunc Get(const std::string& name);
431
432 /*!
433 * \brief Add a GenericFunc instance to the registry
434 * \param func The GenericFunc instance
435 * \param name The name of the registered GenericFunc
436 */
437 TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
438
439 /*!
440 * \brief access the internal node container
441 * \return the pointer to the internal node container
442 */
443 inline GenericFuncNode* operator->();
444
445 // declare container type
446 using ContainerType = GenericFuncNode;
447
448 // Internal class.
449 struct Manager;
450
451 private:
452 friend struct Manager;
453 };
454
455 template<typename... Args>
operator()456 inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
457 const int kNumArgs = sizeof...(Args);
458 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
459 TVMValue values[kArraySize];
460 int type_codes[kArraySize];
461 runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
462 std::forward<Args>(args)...);
463 runtime::TVMRetValue rv;
464 CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
465 return rv;
466 }
467
468 /*!
469 * \brief Represents a generic function that can be specialized on a per-target basis.
470 */
471 class GenericFuncNode : public Node {
472 public:
473 /*! \brief name of the function */
474 std::string name_;
475 /* \brief the generic builder */
476 runtime::PackedFunc generic_func_;
477 /* \brief map from keys to registered functions */
478 std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
479
VisitAttrs(AttrVisitor * v)480 void VisitAttrs(AttrVisitor* v) {}
481
482 static constexpr const char* _type_key = "GenericFunc";
483 TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
484 };
485
486 inline GenericFuncNode* GenericFunc::operator->() {
487 return static_cast<GenericFuncNode*>(get_mutable());
488 }
489
/*!
 * \brief Declaration prefix used by TVM_REGISTER_GENERIC_FUNC.
 * Expands to a static, unused-attribute ::tvm::GenericFunc reference whose name
 * is the pasted token `__mk_TVM`; TVM_REGISTER_GENERIC_FUNC appends __COUNTER__
 * via TVM_STR_CONCAT, presumably yielding a unique name per registration site.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF                                   \
  static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 * of the corresponding function.
 *
 * Binds a static reference to ::tvm::GenericFunc::Get(#name); because
 * set_default and register_func return a reference to self, calls can be
 * chained directly onto the macro.
 *
 * \param name The name of the function
 */
#define TVM_REGISTER_GENERIC_FUNC(name)                                \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) =          \
      ::tvm::GenericFunc::Get(#name)
503
504
505 } // namespace tvm
506
507 #endif // TVM_BUILD_MODULE_H_
508