1 #ifndef HALIDE_GENERATOR_H_
2 #define HALIDE_GENERATOR_H_
3
4 /** \file
5 *
6 * Generator is a class used to encapsulate the building of Funcs in user
7 * pipelines. A Generator is agnostic to JIT vs AOT compilation; it can be used for
8 * either purpose, but is especially convenient to use for AOT compilation.
9 *
10 * A Generator explicitly declares the Inputs and Outputs associated for a given
 * pipeline, and (optionally) separates the code for constructing the outputs from the code for
12 * scheduling them. For instance:
13 *
14 * \code
15 * class Blur : public Generator<Blur> {
16 * public:
17 * Input<Func> input{"input", UInt(16), 2};
18 * Output<Func> output{"output", UInt(16), 2};
19 * void generate() {
20 * blur_x(x, y) = (input(x, y) + input(x+1, y) + input(x+2, y))/3;
21 * blur_y(x, y) = (blur_x(x, y) + blur_x(x, y+1) + blur_x(x, y+2))/3;
 * output(x, y) = blur_y(x, y);
23 * }
24 * void schedule() {
25 * blur_y.split(y, y, yi, 8).parallel(y).vectorize(x, 8);
26 * blur_x.store_at(blur_y, y).compute_at(blur_y, yi).vectorize(x, 8);
27 * }
28 * private:
29 * Var x, y, xi, yi;
30 * Func blur_x, blur_y;
31 * };
32 * \endcode
33 *
34 * Halide can compile a Generator into the correct pipeline by introspecting these
35 * values and constructing an appropriate signature based on them.
36 *
37 * A Generator provides implementations of two methods:
38 *
39 * - generate(), which must fill in all Output Func(s); it may optionally also do scheduling
40 * if no schedule() method is present.
41 * - schedule(), which (if present) should contain all scheduling code.
42 *
43 * Inputs can be any C++ scalar type:
44 *
45 * \code
46 * Input<float> radius{"radius"};
47 * Input<int32_t> increment{"increment"};
48 * \endcode
49 *
50 * An Input<Func> is (essentially) like an ImageParam, except that it may (or may
 * not) be backed by an actual buffer, and thus has no defined extents.
52 *
53 * \code
54 * Input<Func> input{"input", Float(32), 2};
55 * \endcode
56 *
57 * You can optionally make the type and/or dimensions of Input<Func> unspecified,
58 * in which case the value is simply inferred from the actual Funcs passed to them.
59 * Of course, if you specify an explicit Type or Dimension, we still require the
60 * input Func to match, or a compilation error results.
61 *
62 * \code
63 * Input<Func> input{ "input", 3 }; // require 3-dimensional Func,
64 * // but leave Type unspecified
65 * \endcode
66 *
67 * A Generator must explicitly list the output(s) it produces:
68 *
69 * \code
70 * Output<Func> output{"output", Float(32), 2};
71 * \endcode
72 *
73 * You can specify an output that returns a Tuple by specifying a list of Types:
74 *
75 * \code
 * class Tupler : public Generator<Tupler> {
77 * Input<Func> input{"input", Int(32), 2};
78 * Output<Func> output{"output", {Float(32), UInt(8)}, 2};
79 * void generate() {
80 * Var x, y;
81 * Expr a = cast<float>(input(x, y));
82 * Expr b = cast<uint8_t>(input(x, y));
83 * output(x, y) = Tuple(a, b);
84 * }
85 * };
86 * \endcode
87 *
88 * You can also specify Output<X> for any scalar type (except for Handle types);
89 * this is merely syntactic sugar on top of a zero-dimensional Func, but can be
90 * quite handy, especially when used with multiple outputs:
91 *
92 * \code
93 * Output<float> sum{"sum"}; // equivalent to Output<Func> {"sum", Float(32), 0}
94 * \endcode
95 *
96 * As with Input<Func>, you can optionally make the type and/or dimensions of an
97 * Output<Func> unspecified; any unspecified types must be resolved via an
98 * implicit GeneratorParam in order to use top-level compilation.
99 *
100 * You can also declare an *array* of Input or Output, by using an array type
101 * as the type parameter:
102 *
103 * \code
104 * // Takes exactly 3 images and outputs exactly 3 sums.
 * class SumRowsAndColumns : public Generator<SumRowsAndColumns> {
106 * Input<Func[3]> inputs{"inputs", Float(32), 2};
107 * Input<int32_t[2]> extents{"extents"};
108 * Output<Func[3]> sums{"sums", Float(32), 1};
109 * void generate() {
110 * assert(inputs.size() == sums.size());
111 * // assume all inputs are same extent
 * Expr width = extents[0];
 * Expr height = extents[1];
114 * for (size_t i = 0; i < inputs.size(); ++i) {
115 * RDom r(0, width, 0, height);
116 * sums[i]() = 0.f;
117 * sums[i]() += inputs[i](r.x, r.y);
118 * }
119 * }
120 * };
121 * \endcode
122 *
123 * You can also leave array size unspecified, with some caveats:
124 * - For ahead-of-time compilation, Inputs must have a concrete size specified
125 * via a GeneratorParam at build time (e.g., pyramid.size=3)
126 * - For JIT compilation via a Stub, Inputs array sizes will be inferred
127 * from the vector passed.
128 * - For ahead-of-time compilation, Outputs may specify a concrete size
129 * via a GeneratorParam at build time (e.g., pyramid.size=3), or the
130 * size can be specified via a resize() method.
131 *
132 * \code
133 * class Pyramid : public Generator<Pyramid> {
134 * public:
135 * GeneratorParam<int32_t> levels{"levels", 10};
136 * Input<Func> input{ "input", Float(32), 2 };
137 * Output<Func[]> pyramid{ "pyramid", Float(32), 2 };
138 * void generate() {
139 * pyramid.resize(levels);
140 * pyramid[0](x, y) = input(x, y);
141 * for (int i = 1; i < pyramid.size(); i++) {
142 * pyramid[i](x, y) = (pyramid[i-1](2*x, 2*y) +
143 * pyramid[i-1](2*x+1, 2*y) +
144 * pyramid[i-1](2*x, 2*y+1) +
145 * pyramid[i-1](2*x+1, 2*y+1))/4;
146 * }
147 * }
148 * };
149 * \endcode
150 *
151 * A Generator can also be customized via compile-time parameters (GeneratorParams),
152 * which affect code generation.
153 *
154 * GeneratorParams, Inputs, and Outputs are (by convention) always
155 * public and always declared at the top of the Generator class, in the order
156 *
157 * \code
158 * GeneratorParam(s)
159 * Input<Func>(s)
160 * Input<non-Func>(s)
161 * Output<Func>(s)
162 * \endcode
163 *
164 * Note that the Inputs and Outputs will appear in the C function call in the order
165 * they are declared. All Input<Func> and Output<Func> are represented as halide_buffer_t;
166 * all other Input<> are the appropriate C++ scalar type. (GeneratorParams are
167 * always referenced by name, not position, so their order is irrelevant.)
168 *
169 * All Inputs and Outputs must have explicit names, and all such names must match
170 * the regex [A-Za-z][A-Za-z_0-9]* (i.e., essentially a C/C++ variable name, with
171 * some extra restrictions on underscore use). By convention, the name should match
172 * the member-variable name.
173 *
174 * You can dynamically add Inputs and Outputs to your Generator via adding a
175 * configure() method; if present, it will be called before generate(). It can
176 * examine GeneratorParams but it may not examine predeclared Inputs or Outputs;
177 * the only thing it should do is call add_input<>() and/or add_output<>().
178 * Added inputs will be appended (in order) after predeclared Inputs but before
179 * any Outputs; added outputs will be appended after predeclared Outputs.
180 *
181 * Note that the pointers returned by add_input() and add_output() are owned
182 * by the Generator and will remain valid for the Generator's lifetime; user code
183 * should not attempt to delete or free them.
184 *
185 * \code
186 * class MultiSum : public Generator<MultiSum> {
187 * public:
188 * GeneratorParam<int32_t> input_count{"input_count", 10};
189 * Output<Func> output{ "output", Float(32), 2 };
190 *
191 * void configure() {
192 * for (int i = 0; i < input_count; ++i) {
 * extra_inputs.push_back(
 * add_input<Func>("input_" + std::to_string(i), Float(32), 2));
195 * }
196 * }
197 *
198 * void generate() {
199 * Expr sum = 0.f;
200 * for (int i = 0; i < input_count; ++i) {
 * sum += (*extra_inputs[i])(x, y);
202 * }
203 * output(x, y) = sum;
204 * }
205 * private:
 * std::vector<Input<Func> *> extra_inputs;
207 * };
208 * \endcode
209 *
210 * All Generators have three GeneratorParams that are implicitly provided
211 * by the base class:
212 *
213 * GeneratorParam<Target> target{"target", Target()};
214 * GeneratorParam<bool> auto_schedule{"auto_schedule", false};
215 * GeneratorParam<MachineParams> machine_params{"machine_params", MachineParams::generic()};
216 *
217 * - 'target' is the Halide::Target for which the Generator is producing code.
218 * It is read-only during the Generator's lifetime, and must not be modified;
219 * its value should always be filled in by the calling code: either the Halide
220 * build system (for ahead-of-time compilation), or ordinary C++ code
221 * (for JIT compilation).
222 * - 'auto_schedule' indicates whether the auto-scheduler should be run for this
223 * Generator:
224 * - if 'false', the Generator should schedule its Funcs as it sees fit.
225 * - if 'true', the Generator should only provide estimate()s for its Funcs,
226 * and not call any other scheduling methods.
227 * - 'machine_params' is only used if auto_schedule is true; it is ignored
228 * if auto_schedule is false. It provides details about the machine architecture
229 * being targeted which may be used to enhance the automatically-generated
230 * schedule.
231 *
232 * Generators are added to a global registry to simplify AOT build mechanics; this
233 * is done by simply using the HALIDE_REGISTER_GENERATOR macro at global scope:
234 *
235 * \code
236 * HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example)
237 * \endcode
238 *
 * The registered name of the Generator must follow the same rules as
240 * Input names, above.
241 *
242 * Note that the class name of the generated Stub class will match the registered
243 * name by default; if you want to vary it (typically, to include namespaces),
244 * you can add it as an optional third argument:
245 *
246 * \code
247 * HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example, SomeNamespace::JitExampleStub)
248 * \endcode
249 *
250 * Note that a Generator is always executed with a specific Target assigned to it,
251 * that you can access via the get_target() method. (You should *not* use the
252 * global get_target_from_environment(), etc. methods provided in Target.h)
253 *
254 * (Note that there are older variations of Generator that differ from what's
255 * documented above; these are still supported but not described here. See
256 * https://github.com/halide/Halide/wiki/Old-Generator-Documentation for
257 * more information.)
258 */
259
#include <algorithm>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
271
272 #include "ExternalCode.h"
273 #include "Func.h"
274 #include "ImageParam.h"
275 #include "Introspection.h"
276 #include "ObjectInstanceRegistry.h"
277 #include "Target.h"
278
279 namespace Halide {
280
281 template<typename T>
282 class Buffer;
283
284 namespace Internal {
285
286 void generator_test();
287
288 /**
289 * ValueTracker is an internal utility class that attempts to track and flag certain
290 * obvious Stub-related errors at Halide compile time: it tracks the constraints set
291 * on any Parameter-based argument (i.e., Input<Buffer> and Output<Buffer>) to
292 * ensure that incompatible values aren't set.
293 *
294 * e.g.: if a Generator A requires stride[0] == 1,
295 * and Generator B uses Generator A via stub, but requires stride[0] == 4,
296 * we should be able to detect this at Halide compilation time, and fail immediately,
297 * rather than producing code that fails at runtime and/or runs slowly due to
298 * vectorization being unavailable.
299 *
300 * We do this by tracking the active values at entrance and exit to all user-provided
301 * Generator methods (build()/generate()/schedule()); if we ever find more than two unique
302 * values active, we know we have a potential conflict. ("two" here because the first
303 * value is the default value for a given constraint.)
304 *
305 * Note that this won't catch all cases:
306 * -- JIT compilation has no way to check for conflicts at the top-level
307 * -- constraints that match the default value (e.g. if dim(0).set_stride(1) is the
308 * first value seen by the tracker) will be ignored, so an explicit requirement set
309 * this way can be missed
310 *
311 * Nevertheless, this is likely to be much better than nothing when composing multiple
312 * layers of Stubs in a single fused result.
313 */
314 class ValueTracker {
315 private:
316 std::map<std::string, std::vector<std::vector<Expr>>> values_history;
317 const size_t max_unique_values;
318
319 public:
320 explicit ValueTracker(size_t max_unique_values = 2)
max_unique_values(max_unique_values)321 : max_unique_values(max_unique_values) {
322 }
323 void track_values(const std::string &name, const std::vector<Expr> &values);
324 };
325
326 std::vector<Expr> parameter_constraints(const Parameter &p);
327
328 template<typename T>
enum_to_string(const std::map<std::string,T> & enum_map,const T & t)329 HALIDE_NO_USER_CODE_INLINE std::string enum_to_string(const std::map<std::string, T> &enum_map, const T &t) {
330 for (auto key_value : enum_map) {
331 if (t == key_value.second) {
332 return key_value.first;
333 }
334 }
335 user_error << "Enumeration value not found.\n";
336 return "";
337 }
338
339 template<typename T>
enum_from_string(const std::map<std::string,T> & enum_map,const std::string & s)340 T enum_from_string(const std::map<std::string, T> &enum_map, const std::string &s) {
341 auto it = enum_map.find(s);
342 user_assert(it != enum_map.end()) << "Enumeration value not found: " << s << "\n";
343 return it->second;
344 }
345
346 extern const std::map<std::string, Halide::Type> &get_halide_type_enum_map();
halide_type_to_enum_string(const Type & t)347 inline std::string halide_type_to_enum_string(const Type &t) {
348 return enum_to_string(get_halide_type_enum_map(), t);
349 }
350
351 // Convert a Halide Type into a string representation of its C source.
352 // e.g., Int(32) -> "Halide::Int(32)"
353 std::string halide_type_to_c_source(const Type &t);
354
355 // Convert a Halide Type into a string representation of its C Source.
356 // e.g., Int(32) -> "int32_t"
357 std::string halide_type_to_c_type(const Type &t);
358
359 /** generate_filter_main() is a convenient wrapper for GeneratorRegistry::create() +
360 * compile_to_files(); it can be trivially wrapped by a "real" main() to produce a
361 * command-line utility for ahead-of-time filter compilation. */
362 int generate_filter_main(int argc, char **argv, std::ostream &cerr);
363
364 // select_type<> is to std::conditional as switch is to if:
365 // it allows a multiway compile-time type definition via the form
366 //
367 // select_type<cond<condition1, type1>,
368 // cond<condition2, type2>,
369 // ....
370 // cond<conditionN, typeN>>::type
371 //
372 // Note that the conditions are evaluated in order; the first evaluating to true
373 // is chosen.
374 //
// Note that if no conditions evaluate to true, the resulting type is void,
// which will typically produce a compilation error wherever it is used. (You
// can provide a default by simply using cond<true, SomeType> as the final entry.)
/** One (condition, type) arm for select_type<>. */
template<bool B, typename T>
struct cond {
    using type = T;
    static constexpr bool value = B;
};

/** General case: pick First::type if First::value holds, otherwise recurse on Rest. */
template<typename First, typename... Rest>
struct select_type {
    using type = typename std::conditional<First::value,
                                           typename First::type,
                                           typename select_type<Rest...>::type>::type;
};

/** Last arm: yields First::type if its condition holds, otherwise void. */
template<typename First>
struct select_type<First> : std::conditional<First::value, typename First::type, void> {};
389
390 class GeneratorBase;
391 class GeneratorParamInfo;
392
/** Type-erased base class for every GeneratorParam implementation.
 *
 * Provides:
 * - a pure-virtual typed set() for each supported C++ type (declared via the
 *   HALIDE_GENERATOR_PARAM_TYPED_SETTER macro below);
 * - string-based setting via set_from_string();
 * - hooks used by the friend classes (GeneratorBase, GeneratorParamInfo,
 *   StubEmitter) to emit C++ source describing the param.
 *
 * Instances are neither copyable nor movable.
 */
class GeneratorParamBase {
public:
    explicit GeneratorParamBase(const std::string &name);
    virtual ~GeneratorParamBase();

    // The name this GeneratorParam is referenced by.
    const std::string name;

    // overload the set() function to call the right virtual method based on type.
    // This allows us to attempt to set a GeneratorParam via a
    // plain C++ type, even if we don't know the specific templated
    // subclass. Attempting to set the wrong type will assert.
    // Notice that there is no typed setter for Enums, for obvious reasons;
    // setting enums in an unknown type must fallback to using set_from_string.
    //
    // It's always a bit iffy to use macros for this, but IMHO it clarifies the situation here.
#define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
    virtual void set(const TYPE &new_value) = 0;

    HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)

#undef HALIDE_GENERATOR_PARAM_TYPED_SETTER

    // Add overloads for string and char*
    // (both simply forward to set_from_string(), so the concrete subclass
    // is responsible for parsing the text).
    void set(const std::string &new_value) {
        set_from_string(new_value);
    }
    void set(const char *new_value) {
        set_from_string(std::string(new_value));
    }

protected:
    friend class GeneratorBase;
    friend class GeneratorParamInfo;
    friend class StubEmitter;

    // Defined out of line. NOTE(review): presumably these reject reads/writes
    // at the wrong point in the Generator's lifecycle (GeneratorParam_LoopLevel::set
    // deliberately skips the writable check so it can be set after generate()) -- confirm.
    void check_value_readable() const;
    void check_value_writable() const;

    // All GeneratorParams are settable from string.
    virtual void set_from_string(const std::string &value_string) = 0;

    // Returns C++ source text that converts the expression `v` (of this
    // param's C++ type) into a std::string, e.g. "v.to_string()".
    virtual std::string call_to_string(const std::string &v) const = 0;
    // Returns the C++ type name used for this param in emitted source, e.g. "float".
    virtual std::string get_c_type() const = 0;

    // Returns any extra C++ declarations the type from get_c_type() needs
    // (empty by default; enum params emit their enum class here).
    virtual std::string get_type_decls() const {
        return "";
    }

    // Returns the param's current value rendered as a C++ source expression.
    virtual std::string get_default_value() const = 0;

    // Defaults to false; subclasses may override.
    virtual bool is_synthetic_param() const {
        return false;
    }

    // Defaults to false; GeneratorParam_LoopLevel overrides this to return true.
    virtual bool is_looplevel_param() const {
        return false;
    }

    // Reports an error when a typed set() receives a value whose type is
    // incompatible with this param; `type` names the offending C++ type.
    void fail_wrong_type(const char *type);

private:
    // No copy
    GeneratorParamBase(const GeneratorParamBase &) = delete;
    void operator=(const GeneratorParamBase &) = delete;
    // No move
    GeneratorParamBase(GeneratorParamBase &&) = delete;
    void operator=(GeneratorParamBase &&) = delete;

    // Generator which owns this GeneratorParam. Note that this will be null
    // initially; the GeneratorBase itself will set this field when it initially
    // builds its info about params. However, since it (generally) isn't
    // appropriate for GeneratorParam<> to be declared outside of a Generator,
    // all reasonable non-testing code should expect this to be non-null.
    GeneratorBase *generator{nullptr};
};
482
// Syntactic sugar for converting FROM -> TO without tripping compiler
// conversion warnings: conversions to bool are written as a comparison
// against zero instead of a cast.
template<typename FROM, typename TO>
struct Convert {
    static TO value(const FROM &from) {
        return convert_impl(from, std::integral_constant<bool, std::is_same<TO, bool>::value>());
    }

private:
    // Non-bool destination: a plain static_cast.
    static TO convert_impl(const FROM &from, std::false_type) {
        return static_cast<TO>(from);
    }

    // Bool destination: compare against zero rather than casting.
    static TO convert_impl(const FROM &from, std::true_type) {
        return from != 0;
    }
};
496
497 template<typename T>
498 class GeneratorParamImpl : public GeneratorParamBase {
499 public:
500 using type = T;
501
502 GeneratorParamImpl(const std::string &name, const T &value)
503 : GeneratorParamBase(name), value_(value) {
504 }
505
506 T value() const {
507 this->check_value_readable();
508 return value_;
509 }
510
511 operator T() const {
512 return this->value();
513 }
514
515 operator Expr() const {
516 return make_const(type_of<T>(), this->value());
517 }
518
519 #define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
520 void set(const TYPE &new_value) override { \
521 typed_setter_impl<TYPE>(new_value, #TYPE); \
522 }
523
524 HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
525 HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
526 HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
527 HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
528 HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
529 HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
530 HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
531 HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
532 HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
533 HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
534 HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
535 HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
536 HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
537 HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
538 HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)
539
540 #undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
541
542 // Overload for std::string.
543 void set(const std::string &new_value) {
544 check_value_writable();
545 value_ = new_value;
546 }
547
548 protected:
549 virtual void set_impl(const T &new_value) {
550 check_value_writable();
551 value_ = new_value;
552 }
553
554 // Needs to be protected to allow GeneratorParam<LoopLevel>::set() override
555 T value_;
556
557 private:
558 // If FROM->T is not legal, fail
559 template<typename FROM, typename std::enable_if<
560 !std::is_convertible<FROM, T>::value>::type * = nullptr>
561 HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &, const char *msg) {
562 fail_wrong_type(msg);
563 }
564
565 // If FROM and T are identical, just assign
566 template<typename FROM, typename std::enable_if<
567 std::is_same<FROM, T>::value>::type * = nullptr>
568 HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
569 check_value_writable();
570 value_ = value;
571 }
572
573 // If both FROM->T and T->FROM are legal, ensure it's lossless
574 template<typename FROM, typename std::enable_if<
575 !std::is_same<FROM, T>::value &&
576 std::is_convertible<FROM, T>::value &&
577 std::is_convertible<T, FROM>::value>::type * = nullptr>
578 HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
579 check_value_writable();
580 const T t = Convert<FROM, T>::value(value);
581 const FROM value2 = Convert<T, FROM>::value(t);
582 if (value2 != value) {
583 fail_wrong_type(msg);
584 }
585 value_ = t;
586 }
587
588 // If FROM->T is legal but T->FROM is not, just assign
589 template<typename FROM, typename std::enable_if<
590 !std::is_same<FROM, T>::value &&
591 std::is_convertible<FROM, T>::value &&
592 !std::is_convertible<T, FROM>::value>::type * = nullptr>
593 HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
594 check_value_writable();
595 value_ = value;
596 }
597 };
598
599 // Stubs for type-specific implementations of GeneratorParam, to avoid
600 // many complex enable_if<> statements that were formerly spread through the
601 // implementation. Note that not all of these need to be templated classes,
602 // (e.g. for GeneratorParam_Target, T == Target always), but are declared
603 // that way for symmetry of declaration.
604 template<typename T>
605 class GeneratorParam_Target : public GeneratorParamImpl<T> {
606 public:
607 GeneratorParam_Target(const std::string &name, const T &value)
608 : GeneratorParamImpl<T>(name, value) {
609 }
610
611 void set_from_string(const std::string &new_value_string) override {
612 this->set(Target(new_value_string));
613 }
614
615 std::string get_default_value() const override {
616 return this->value().to_string();
617 }
618
619 std::string call_to_string(const std::string &v) const override {
620 std::ostringstream oss;
621 oss << v << ".to_string()";
622 return oss.str();
623 }
624
625 std::string get_c_type() const override {
626 return "Target";
627 }
628 };
629
630 template<typename T>
631 class GeneratorParam_MachineParams : public GeneratorParamImpl<T> {
632 public:
633 GeneratorParam_MachineParams(const std::string &name, const T &value)
634 : GeneratorParamImpl<T>(name, value) {
635 }
636
637 void set_from_string(const std::string &new_value_string) override {
638 this->set(MachineParams(new_value_string));
639 }
640
641 std::string get_default_value() const override {
642 return this->value().to_string();
643 }
644
645 std::string call_to_string(const std::string &v) const override {
646 std::ostringstream oss;
647 oss << v << ".to_string()";
648 return oss.str();
649 }
650
651 std::string get_c_type() const override {
652 return "MachineParams";
653 }
654 };
655
656 class GeneratorParam_LoopLevel : public GeneratorParamImpl<LoopLevel> {
657 public:
658 GeneratorParam_LoopLevel(const std::string &name, const LoopLevel &value)
659 : GeneratorParamImpl<LoopLevel>(name, value) {
660 }
661
662 using GeneratorParamImpl<LoopLevel>::set;
663
664 void set(const LoopLevel &value) override {
665 // Don't call check_value_writable(): It's OK to set a LoopLevel after generate().
666 // check_value_writable();
667
668 // This looks odd, but is deliberate:
669
670 // First, mutate the existing contents to match the value passed in,
671 // so that any existing usage of the LoopLevel now uses the newer value.
672 // (Strictly speaking, this is really only necessary if this method
673 // is called after generate(): before generate(), there is no usage
674 // to be concerned with.)
675 value_.set(value);
676
677 // Then, reset the value itself so that it points to the same LoopLevelContents
678 // as the value passed in. (Strictly speaking, this is really only
679 // useful if this method is called before generate(): afterwards, it's
680 // too late to alter the code to refer to a different LoopLevelContents.)
681 value_ = value;
682 }
683
684 void set_from_string(const std::string &new_value_string) override {
685 if (new_value_string == "root") {
686 this->set(LoopLevel::root());
687 } else if (new_value_string == "inlined") {
688 this->set(LoopLevel::inlined());
689 } else {
690 user_error << "Unable to parse " << this->name << ": " << new_value_string;
691 }
692 }
693
694 std::string get_default_value() const override {
695 // This is dodgy but safe in this case: we want to
696 // see what the value of our LoopLevel is *right now*,
697 // so we make a copy and lock the copy so we can inspect it.
698 // (Note that ordinarily this is a bad idea, since LoopLevels
699 // can be mutated later on; however, this method is only
700 // called by the Generator infrastructure, on LoopLevels that
701 // will never be mutated, so this is really just an elaborate way
702 // to avoid runtime assertions.)
703 LoopLevel copy;
704 copy.set(this->value());
705 copy.lock();
706 if (copy.is_inlined()) {
707 return "LoopLevel::inlined()";
708 } else if (copy.is_root()) {
709 return "LoopLevel::root()";
710 } else {
711 internal_error;
712 return "";
713 }
714 }
715
716 std::string call_to_string(const std::string &v) const override {
717 internal_error;
718 return std::string();
719 }
720
721 std::string get_c_type() const override {
722 return "LoopLevel";
723 }
724
725 bool is_looplevel_param() const override {
726 return true;
727 }
728 };
729
730 template<typename T>
731 class GeneratorParam_Arithmetic : public GeneratorParamImpl<T> {
732 public:
733 GeneratorParam_Arithmetic(const std::string &name,
734 const T &value,
735 const T &min = std::numeric_limits<T>::lowest(),
736 const T &max = std::numeric_limits<T>::max())
737 : GeneratorParamImpl<T>(name, value), min(min), max(max) {
738 // call set() to ensure value is clamped to min/max
739 this->set(value);
740 }
741
742 void set_impl(const T &new_value) override {
743 user_assert(new_value >= min && new_value <= max) << "Value out of range: " << new_value;
744 GeneratorParamImpl<T>::set_impl(new_value);
745 }
746
747 void set_from_string(const std::string &new_value_string) override {
748 std::istringstream iss(new_value_string);
749 T t;
750 // All one-byte ints int8 and uint8 should be parsed as integers, not chars --
751 // including 'char' itself. (Note that sizeof(bool) is often-but-not-always-1,
752 // so be sure to exclude that case.)
753 if (sizeof(T) == sizeof(char) && !std::is_same<T, bool>::value) {
754 int i;
755 iss >> i;
756 t = (T)i;
757 } else {
758 iss >> t;
759 }
760 user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << new_value_string;
761 this->set(t);
762 }
763
764 std::string get_default_value() const override {
765 std::ostringstream oss;
766 oss << this->value();
767 if (std::is_same<T, float>::value) {
768 // If the constant has no decimal point ("1")
769 // we must append one before appending "f"
770 if (oss.str().find(".") == std::string::npos) {
771 oss << ".";
772 }
773 oss << "f";
774 }
775 return oss.str();
776 }
777
778 std::string call_to_string(const std::string &v) const override {
779 std::ostringstream oss;
780 oss << "std::to_string(" << v << ")";
781 return oss.str();
782 }
783
784 std::string get_c_type() const override {
785 std::ostringstream oss;
786 if (std::is_same<T, float>::value) {
787 return "float";
788 } else if (std::is_same<T, double>::value) {
789 return "double";
790 } else if (std::is_integral<T>::value) {
791 if (std::is_unsigned<T>::value) {
792 oss << "u";
793 }
794 oss << "int" << (sizeof(T) * 8) << "_t";
795 return oss.str();
796 } else {
797 user_error << "Unknown arithmetic type\n";
798 return "";
799 }
800 }
801
802 private:
803 const T min, max;
804 };
805
806 template<typename T>
807 class GeneratorParam_Bool : public GeneratorParam_Arithmetic<T> {
808 public:
809 GeneratorParam_Bool(const std::string &name, const T &value)
810 : GeneratorParam_Arithmetic<T>(name, value) {
811 }
812
813 void set_from_string(const std::string &new_value_string) override {
814 bool v = false;
815 if (new_value_string == "true" || new_value_string == "True") {
816 v = true;
817 } else if (new_value_string == "false" || new_value_string == "False") {
818 v = false;
819 } else {
820 user_assert(false) << "Unable to parse bool: " << new_value_string;
821 }
822 this->set(v);
823 }
824
825 std::string get_default_value() const override {
826 return this->value() ? "true" : "false";
827 }
828
829 std::string call_to_string(const std::string &v) const override {
830 std::ostringstream oss;
831 oss << "std::string((" << v << ") ? \"true\" : \"false\")";
832 return oss.str();
833 }
834
835 std::string get_c_type() const override {
836 return "bool";
837 }
838 };
839
840 template<typename T>
841 class GeneratorParam_Enum : public GeneratorParamImpl<T> {
842 public:
843 GeneratorParam_Enum(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
844 : GeneratorParamImpl<T>(name, value), enum_map(enum_map) {
845 }
846
847 // define a "set" that takes our specific enum (but don't hide the inherited virtual functions)
848 using GeneratorParamImpl<T>::set;
849
850 template<typename T2 = T, typename std::enable_if<!std::is_same<T2, Type>::value>::type * = nullptr>
851 void set(const T &e) {
852 this->set_impl(e);
853 }
854
855 void set_from_string(const std::string &new_value_string) override {
856 auto it = enum_map.find(new_value_string);
857 user_assert(it != enum_map.end()) << "Enumeration value not found: " << new_value_string;
858 this->set_impl(it->second);
859 }
860
861 std::string call_to_string(const std::string &v) const override {
862 return "Enum_" + this->name + "_map().at(" + v + ")";
863 }
864
865 std::string get_c_type() const override {
866 return "Enum_" + this->name;
867 }
868
869 std::string get_default_value() const override {
870 return "Enum_" + this->name + "::" + enum_to_string(enum_map, this->value());
871 }
872
873 std::string get_type_decls() const override {
874 std::ostringstream oss;
875 oss << "enum class Enum_" << this->name << " {\n";
876 for (auto key_value : enum_map) {
877 oss << " " << key_value.first << ",\n";
878 }
879 oss << "};\n";
880 oss << "\n";
881
882 // TODO: since we generate the enums, we could probably just use a vector (or array!) rather than a map,
883 // since we can ensure that the enum values are a nice tight range.
884 oss << "inline HALIDE_NO_USER_CODE_INLINE const std::map<Enum_" << this->name << ", std::string>& Enum_" << this->name << "_map() {\n";
885 oss << " static const std::map<Enum_" << this->name << ", std::string> m = {\n";
886 for (auto key_value : enum_map) {
887 oss << " { Enum_" << this->name << "::" << key_value.first << ", \"" << key_value.first << "\"},\n";
888 }
889 oss << " };\n";
890 oss << " return m;\n";
891 oss << "};\n";
892 return oss.str();
893 }
894
895 private:
896 const std::map<std::string, T> enum_map;
897 };
898
899 template<typename T>
900 class GeneratorParam_Type : public GeneratorParam_Enum<T> {
901 public:
902 GeneratorParam_Type(const std::string &name, const T &value)
903 : GeneratorParam_Enum<T>(name, value, get_halide_type_enum_map()) {
904 }
905
906 std::string call_to_string(const std::string &v) const override {
907 return "Halide::Internal::halide_type_to_enum_string(" + v + ")";
908 }
909
910 std::string get_c_type() const override {
911 return "Type";
912 }
913
914 std::string get_default_value() const override {
915 return halide_type_to_c_source(this->value());
916 }
917
918 std::string get_type_decls() const override {
919 return "";
920 }
921 };
922
923 template<typename T>
924 class GeneratorParam_String : public Internal::GeneratorParamImpl<T> {
925 public:
926 GeneratorParam_String(const std::string &name, const std::string &value)
927 : GeneratorParamImpl<T>(name, value) {
928 }
929 void set_from_string(const std::string &new_value_string) override {
930 this->set(new_value_string);
931 }
932
933 std::string get_default_value() const override {
934 return "\"" + this->value() + "\"";
935 }
936
937 std::string call_to_string(const std::string &v) const override {
938 return v;
939 }
940
941 std::string get_c_type() const override {
942 return "std::string";
943 }
944 };
945
// Maps a GeneratorParam value type T to the implementation class it derives
// from. NOTE(review): this presumes select_type<> yields the first cond<>
// whose predicate is true (select_type is defined elsewhere — confirm), so
// the order below matters: e.g. bool satisfies std::is_arithmetic and must be
// tested before the generic arithmetic case, and the specific class types
// (Target, MachineParams, LoopLevel, std::string, Type) come first.
template<typename T>
using GeneratorParamImplBase =
    typename select_type<
        cond<std::is_same<T, Target>::value, GeneratorParam_Target<T>>,
        cond<std::is_same<T, MachineParams>::value, GeneratorParam_MachineParams<T>>,
        cond<std::is_same<T, LoopLevel>::value, GeneratorParam_LoopLevel>,
        cond<std::is_same<T, std::string>::value, GeneratorParam_String<T>>,
        cond<std::is_same<T, Type>::value, GeneratorParam_Type<T>>,
        cond<std::is_same<T, bool>::value, GeneratorParam_Bool<T>>,
        cond<std::is_arithmetic<T>::value, GeneratorParam_Arithmetic<T>>,
        cond<std::is_enum<T>::value, GeneratorParam_Enum<T>>>::type;
957
958 } // namespace Internal
959
960 /** GeneratorParam is a templated class that can be used to modify the behavior
961 * of the Generator at code-generation time. GeneratorParams are commonly
962 * specified in build files (e.g. Makefile) to customize the behavior of
963 * a given Generator, thus they have a very constrained set of types to allow
 * for efficient specification via command-line flags. A GeneratorParam can be:
 * - any float or int type.
 * - bool
 * - enum
 * - Halide::Target
 * - Halide::Type
 * - Halide::LoopLevel
 * - Halide::MachineParams
 * - std::string
971 * Please don't use std::string unless there's no way to do what you want with some
972 * other type; in particular, don't use this if you can use enum instead.
973 * All GeneratorParams have a default value. Arithmetic types can also
974 * optionally specify min and max. Enum types must specify a string-to-value
975 * map.
976 *
977 * Halide::Type is treated as though it were an enum, with the mappings:
978 *
979 * "int8" Halide::Int(8)
980 * "int16" Halide::Int(16)
981 * "int32" Halide::Int(32)
982 * "uint8" Halide::UInt(8)
983 * "uint16" Halide::UInt(16)
984 * "uint32" Halide::UInt(32)
985 * "float32" Halide::Float(32)
986 * "float64" Halide::Float(64)
987 *
988 * No vector Types are currently supported by this mapping.
989 *
990 */
991 template<typename T>
992 class GeneratorParam : public Internal::GeneratorParamImplBase<T> {
993 public:
994 template<typename T2 = T, typename std::enable_if<!std::is_same<T2, std::string>::value>::type * = nullptr>
995 GeneratorParam(const std::string &name, const T &value)
996 : Internal::GeneratorParamImplBase<T>(name, value) {
997 }
998
999 GeneratorParam(const std::string &name, const T &value, const T &min, const T &max)
1000 : Internal::GeneratorParamImplBase<T>(name, value, min, max) {
1001 }
1002
1003 GeneratorParam(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
1004 : Internal::GeneratorParamImplBase<T>(name, value, enum_map) {
1005 }
1006
1007 GeneratorParam(const std::string &name, const std::string &value)
1008 : Internal::GeneratorParamImplBase<T>(name, value) {
1009 }
1010 };
1011
1012 /** Addition between GeneratorParam<T> and any type that supports operator+ with T.
1013 * Returns type of underlying operator+. */
1014 // @{
1015 template<typename Other, typename T>
1016 auto operator+(const Other &a, const GeneratorParam<T> &b) -> decltype(a + (T)b) {
1017 return a + (T)b;
1018 }
1019 template<typename Other, typename T>
1020 auto operator+(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a + b) {
1021 return (T)a + b;
1022 }
1023 // @}
1024
1025 /** Subtraction between GeneratorParam<T> and any type that supports operator- with T.
1026 * Returns type of underlying operator-. */
1027 // @{
1028 template<typename Other, typename T>
1029 auto operator-(const Other &a, const GeneratorParam<T> &b) -> decltype(a - (T)b) {
1030 return a - (T)b;
1031 }
1032 template<typename Other, typename T>
1033 auto operator-(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a - b) {
1034 return (T)a - b;
1035 }
1036 // @}
1037
1038 /** Multiplication between GeneratorParam<T> and any type that supports operator* with T.
1039 * Returns type of underlying operator*. */
1040 // @{
1041 template<typename Other, typename T>
1042 auto operator*(const Other &a, const GeneratorParam<T> &b) -> decltype(a * (T)b) {
1043 return a * (T)b;
1044 }
1045 template<typename Other, typename T>
1046 auto operator*(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a * b) {
1047 return (T)a * b;
1048 }
1049 // @}
1050
1051 /** Division between GeneratorParam<T> and any type that supports operator/ with T.
1052 * Returns type of underlying operator/. */
1053 // @{
1054 template<typename Other, typename T>
1055 auto operator/(const Other &a, const GeneratorParam<T> &b) -> decltype(a / (T)b) {
1056 return a / (T)b;
1057 }
1058 template<typename Other, typename T>
1059 auto operator/(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a / b) {
1060 return (T)a / b;
1061 }
1062 // @}
1063
1064 /** Modulo between GeneratorParam<T> and any type that supports operator% with T.
1065 * Returns type of underlying operator%. */
1066 // @{
1067 template<typename Other, typename T>
1068 auto operator%(const Other &a, const GeneratorParam<T> &b) -> decltype(a % (T)b) {
1069 return a % (T)b;
1070 }
1071 template<typename Other, typename T>
1072 auto operator%(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a % b) {
1073 return (T)a % b;
1074 }
1075 // @}
1076
1077 /** Greater than comparison between GeneratorParam<T> and any type that supports operator> with T.
1078 * Returns type of underlying operator>. */
1079 // @{
1080 template<typename Other, typename T>
1081 auto operator>(const Other &a, const GeneratorParam<T> &b) -> decltype(a > (T)b) {
1082 return a > (T)b;
1083 }
1084 template<typename Other, typename T>
1085 auto operator>(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a > b) {
1086 return (T)a > b;
1087 }
1088 // @}
1089
1090 /** Less than comparison between GeneratorParam<T> and any type that supports operator< with T.
1091 * Returns type of underlying operator<. */
1092 // @{
1093 template<typename Other, typename T>
1094 auto operator<(const Other &a, const GeneratorParam<T> &b) -> decltype(a < (T)b) {
1095 return a < (T)b;
1096 }
1097 template<typename Other, typename T>
1098 auto operator<(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a < b) {
1099 return (T)a < b;
1100 }
1101 // @}
1102
1103 /** Greater than or equal comparison between GeneratorParam<T> and any type that supports operator>= with T.
1104 * Returns type of underlying operator>=. */
1105 // @{
1106 template<typename Other, typename T>
1107 auto operator>=(const Other &a, const GeneratorParam<T> &b) -> decltype(a >= (T)b) {
1108 return a >= (T)b;
1109 }
1110 template<typename Other, typename T>
1111 auto operator>=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a >= b) {
1112 return (T)a >= b;
1113 }
1114 // @}
1115
1116 /** Less than or equal comparison between GeneratorParam<T> and any type that supports operator<= with T.
1117 * Returns type of underlying operator<=. */
1118 // @{
1119 template<typename Other, typename T>
1120 auto operator<=(const Other &a, const GeneratorParam<T> &b) -> decltype(a <= (T)b) {
1121 return a <= (T)b;
1122 }
1123 template<typename Other, typename T>
1124 auto operator<=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a <= b) {
1125 return (T)a <= b;
1126 }
1127 // @}
1128
1129 /** Equality comparison between GeneratorParam<T> and any type that supports operator== with T.
1130 * Returns type of underlying operator==. */
1131 // @{
1132 template<typename Other, typename T>
1133 auto operator==(const Other &a, const GeneratorParam<T> &b) -> decltype(a == (T)b) {
1134 return a == (T)b;
1135 }
1136 template<typename Other, typename T>
1137 auto operator==(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a == b) {
1138 return (T)a == b;
1139 }
1140 // @}
1141
1142 /** Inequality comparison between between GeneratorParam<T> and any type that supports operator!= with T.
1143 * Returns type of underlying operator!=. */
1144 // @{
1145 template<typename Other, typename T>
1146 auto operator!=(const Other &a, const GeneratorParam<T> &b) -> decltype(a != (T)b) {
1147 return a != (T)b;
1148 }
1149 template<typename Other, typename T>
1150 auto operator!=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a != b) {
1151 return (T)a != b;
1152 }
1153 // @}
1154
1155 /** Logical and between between GeneratorParam<T> and any type that supports operator&& with T.
1156 * Returns type of underlying operator&&. */
1157 // @{
1158 template<typename Other, typename T>
1159 auto operator&&(const Other &a, const GeneratorParam<T> &b) -> decltype(a && (T)b) {
1160 return a && (T)b;
1161 }
1162 template<typename Other, typename T>
1163 auto operator&&(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a && b) {
1164 return (T)a && b;
1165 }
1166 template<typename T>
1167 auto operator&&(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a && (T)b) {
1168 return (T)a && (T)b;
1169 }
1170 // @}
1171
1172 /** Logical or between between GeneratorParam<T> and any type that supports operator|| with T.
1173 * Returns type of underlying operator||. */
1174 // @{
1175 template<typename Other, typename T>
1176 auto operator||(const Other &a, const GeneratorParam<T> &b) -> decltype(a || (T)b) {
1177 return a || (T)b;
1178 }
1179 template<typename Other, typename T>
1180 auto operator||(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a || b) {
1181 return (T)a || b;
1182 }
1183 template<typename T>
1184 auto operator||(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a || (T)b) {
1185 return (T)a || (T)b;
1186 }
1187 // @}
1188
1189 /* min and max are tricky as the language support for these is in the std
1190 * namespace. In order to make this work, forwarding functions are used that
1191 * are declared in a namespace that has std::min and std::max in scope.
1192 */
1193 namespace Internal {
1194 namespace GeneratorMinMax {
1195
1196 using std::max;
1197 using std::min;
1198
1199 template<typename Other, typename T>
1200 auto min_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(min(a, (T)b)) {
1201 return min(a, (T)b);
1202 }
1203 template<typename Other, typename T>
1204 auto min_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(min((T)a, b)) {
1205 return min((T)a, b);
1206 }
1207
1208 template<typename Other, typename T>
1209 auto max_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(max(a, (T)b)) {
1210 return max(a, (T)b);
1211 }
1212 template<typename Other, typename T>
1213 auto max_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(max((T)a, b)) {
1214 return max((T)a, b);
1215 }
1216
1217 } // namespace GeneratorMinMax
1218 } // namespace Internal
1219
1220 /** Compute minimum between GeneratorParam<T> and any type that supports min with T.
1221 * Will automatically import std::min. Returns type of underlying min call. */
1222 // @{
1223 template<typename Other, typename T>
1224 auto min(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
1225 return Internal::GeneratorMinMax::min_forward(a, b);
1226 }
1227 template<typename Other, typename T>
1228 auto min(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
1229 return Internal::GeneratorMinMax::min_forward(a, b);
1230 }
1231 // @}
1232
1233 /** Compute the maximum value between GeneratorParam<T> and any type that supports max with T.
1234 * Will automatically import std::max. Returns type of underlying max call. */
1235 // @{
1236 template<typename Other, typename T>
1237 auto max(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
1238 return Internal::GeneratorMinMax::max_forward(a, b);
1239 }
1240 template<typename Other, typename T>
1241 auto max(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
1242 return Internal::GeneratorMinMax::max_forward(a, b);
1243 }
1244 // @}
1245
1246 /** Not operator for GeneratorParam */
1247 template<typename T>
1248 auto operator!(const GeneratorParam<T> &a) -> decltype(!(T)a) {
1249 return !(T)a;
1250 }
1251
1252 namespace Internal {
1253
1254 template<typename T2>
1255 class GeneratorInput_Buffer;
1256
// What a Generator input/output (or a StubInput) carries: a scalar Expr,
// a Func, or a Buffer-backed Parameter.
enum class IOKind {
    Scalar,
    Function,
    Buffer
};
1260
1261 /**
1262 * StubInputBuffer is the placeholder that a Stub uses when it requires
1263 * a Buffer for an input (rather than merely a Func or Expr). It is constructed
1264 * to allow only two possible sorts of input:
1265 * -- Assignment of an Input<Buffer<>>, with compatible type and dimensions,
1266 * essentially allowing us to pipe a parameter from an enclosing Generator to an internal Stub.
1267 * -- Assignment of a Buffer<>, with compatible type and dimensions,
1268 * causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
1269 */
1270 template<typename T = void>
1271 class StubInputBuffer {
1272 friend class StubInput;
1273 template<typename T2>
1274 friend class GeneratorInput_Buffer;
1275
1276 Parameter parameter_;
1277
1278 HALIDE_NO_USER_CODE_INLINE explicit StubInputBuffer(const Parameter &p)
1279 : parameter_(p) {
1280 // Create an empty 1-element buffer with the right runtime typing and dimensions,
1281 // which we'll use only to pass to can_convert_from() to verify this
1282 // Parameter is compatible with our constraints.
1283 Buffer<> other(p.type(), nullptr, std::vector<int>(p.dimensions(), 1));
1284 internal_assert((Buffer<T>::can_convert_from(other)));
1285 }
1286
1287 template<typename T2>
1288 HALIDE_NO_USER_CODE_INLINE static Parameter parameter_from_buffer(const Buffer<T2> &b) {
1289 user_assert((Buffer<T>::can_convert_from(b)));
1290 Parameter p(b.type(), true, b.dimensions());
1291 p.set_buffer(b);
1292 return p;
1293 }
1294
1295 public:
1296 StubInputBuffer() = default;
1297
1298 // *not* explicit -- this ctor should only be used when you want
1299 // to pass a literal Buffer<> for a Stub Input; this Buffer<> will be
1300 // compiled into the Generator's product, rather than becoming
1301 // a runtime Parameter.
1302 template<typename T2>
1303 StubInputBuffer(const Buffer<T2> &b)
1304 : parameter_(parameter_from_buffer(b)) {
1305 }
1306 };
1307
1308 class StubOutputBufferBase {
1309 protected:
1310 Func f;
1311 std::shared_ptr<GeneratorBase> generator;
1312
1313 void check_scheduled(const char *m) const;
1314 Target get_target() const;
1315
1316 explicit StubOutputBufferBase(const Func &f, std::shared_ptr<GeneratorBase> generator)
1317 : f(f), generator(std::move(generator)) {
1318 }
1319 StubOutputBufferBase() = default;
1320
1321 public:
1322 Realization realize(std::vector<int32_t> sizes) {
1323 check_scheduled("realize");
1324 return f.realize(std::move(sizes), get_target());
1325 }
1326
1327 template<typename... Args>
1328 Realization realize(Args &&... args) {
1329 check_scheduled("realize");
1330 return f.realize(std::forward<Args>(args)..., get_target());
1331 }
1332
1333 template<typename Dst>
1334 void realize(Dst dst) {
1335 check_scheduled("realize");
1336 f.realize(dst, get_target());
1337 }
1338 };
1339
1340 /**
1341 * StubOutputBuffer is the placeholder that a Stub uses when it requires
1342 * a Buffer for an output (rather than merely a Func). It is constructed
1343 * to allow only two possible sorts of things:
1344 * -- Assignment to an Output<Buffer<>>, with compatible type and dimensions,
1345 * essentially allowing us to pipe a parameter from the result of a Stub to an
1346 * enclosing Generator
1347 * -- Realization into a Buffer<>; this is useful only in JIT compilation modes
1348 * (and shouldn't be usable otherwise)
1349 *
1350 * It is deliberate that StubOutputBuffer is not (easily) convertible to Func.
1351 */
1352 template<typename T = void>
1353 class StubOutputBuffer : public StubOutputBufferBase {
1354 template<typename T2>
1355 friend class GeneratorOutput_Buffer;
1356 friend class GeneratorStub;
1357 explicit StubOutputBuffer(const Func &f, const std::shared_ptr<GeneratorBase> &generator)
1358 : StubOutputBufferBase(f, generator) {
1359 }
1360
1361 public:
1362 StubOutputBuffer() = default;
1363 };
1364
1365 // This is a union-like class that allows for convenient initialization of Stub Inputs
1366 // via C++11 initializer-list syntax; it is only used in situations where the
1367 // downstream consumer will be able to explicitly check that each value is
1368 // of the expected/required kind.
1369 class StubInput {
1370 const IOKind kind_;
1371 // Exactly one of the following fields should be defined:
1372 const Parameter parameter_;
1373 const Func func_;
1374 const Expr expr_;
1375
1376 public:
1377 // *not* explicit.
1378 template<typename T2>
1379 StubInput(const StubInputBuffer<T2> &b)
1380 : kind_(IOKind::Buffer), parameter_(b.parameter_), func_(), expr_() {
1381 }
1382 StubInput(const Func &f)
1383 : kind_(IOKind::Function), parameter_(), func_(f), expr_() {
1384 }
1385 StubInput(const Expr &e)
1386 : kind_(IOKind::Scalar), parameter_(), func_(), expr_(e) {
1387 }
1388
1389 private:
1390 friend class GeneratorInputBase;
1391
1392 IOKind kind() const {
1393 return kind_;
1394 }
1395
1396 Parameter parameter() const {
1397 internal_assert(kind_ == IOKind::Buffer);
1398 return parameter_;
1399 }
1400
1401 Func func() const {
1402 internal_assert(kind_ == IOKind::Function);
1403 return func_;
1404 }
1405
1406 Expr expr() const {
1407 internal_assert(kind_ == IOKind::Scalar);
1408 return expr_;
1409 }
1410 };
1411
/** GIOBase is the base class for all GeneratorInput<> and GeneratorOutput<>
 * instantiations; it is not part of the public API and should never be
 * used directly by user code.
 *
 * Every GIOBase instance can be either a single value or an array-of-values;
 * each of these values can be an Expr or a Func. (Note that for an
 * array-of-values, the types/dimensions of all values in the array must match.)
 *
 * A GIOBase can have multiple Types, in which case it represents a Tuple.
 * (Note that Tuples are currently only supported for GeneratorOutput, but
 * it is likely that GeneratorInput will be extended to support Tuple as well.)
 *
 * The array-size, type(s), and dimensions can all be left "unspecified" at
 * creation time, in which case they may assume values provided by a Stub.
 * (It is important to note that attempting to use a GIOBase with unspecified
 * values will assert-fail; you must ensure that all unspecified values are
 * filled in prior to use.)
 */
class GIOBase {
public:
    // Array size accessors; the *_defined() queries report whether the value
    // has been specified (see the "unspecified" note in the class comment).
    bool array_size_defined() const;
    size_t array_size() const;
    // Overridden by GeneratorInputImpl to report whether T is an array type.
    virtual bool is_array() const;

    const std::string &name() const;
    IOKind kind() const;

    // Type(s) of the value; more than one Type means a Tuple.
    bool types_defined() const;
    const std::vector<Type> &types() const;
    Type type() const;

    bool dims_defined() const;
    int dims() const;

    // The values themselves; see funcs_/exprs_ below.
    const std::vector<Func> &funcs() const;
    const std::vector<Expr> &exprs() const;

    virtual ~GIOBase();

protected:
    // NOTE(review): array_size is size_t here, but GeneratorInputImpl passes
    // -1 for unsized arrays and array_size_ is an int; presumably the value is
    // narrowed back to -1 when stored — confirm in the implementation.
    GIOBase(size_t array_size,
            const std::string &name,
            IOKind kind,
            const std::vector<Type> &types,
            int dims);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    mutable int array_size_;  // always 1 if is_array() == false.
                              // -1 if is_array() == true but unspecified.

    const std::string name_;
    const IOKind kind_;
    mutable std::vector<Type> types_;  // empty if type is unspecified
    mutable int dims_;                 // -1 if dim is unspecified

    // Exactly one of these will have nonzero length
    std::vector<Func> funcs_;
    std::vector<Expr> exprs_;

    // Generator which owns this Input or Output. Note that this will be null
    // initially; the GeneratorBase itself will set this field when it initially
    // builds its info about params. However, since it isn't
    // appropriate for Input<> or Output<> to be declared outside of a Generator,
    // all reasonable non-testing code should expect this to be non-null.
    GeneratorBase *generator{nullptr};

    // Name for the i'th element of an array input/output.
    std::string array_name(size_t i) const;

    virtual void verify_internals();

    // Consistency checks against a size/types/dims supplied from elsewhere
    // (e.g. by a Stub); see the implementation for the exact failure mode.
    void check_matching_array_size(size_t size) const;
    void check_matching_types(const std::vector<Type> &t) const;
    void check_matching_dims(int d) const;

    // Returns exprs() or funcs(); explicitly specialized for Expr and Func
    // immediately after this class.
    template<typename ElemType>
    const std::vector<ElemType> &get_values() const;

    // Called by the accessors/forwarders of subclasses before touching the
    // underlying values. NOTE(review): presumably validates that access
    // happens at an allowed phase — implementation not visible here.
    void check_gio_access() const;

    virtual void check_value_writable() const = 0;

    // "Input" or "Output" (GeneratorInputBase returns "Input").
    virtual const char *input_or_output() const = 0;

private:
    template<typename T>
    friend class GeneratorParam_Synthetic;

    // No copy
    GIOBase(const GIOBase &) = delete;
    void operator=(const GIOBase &) = delete;
    // No move
    GIOBase(GIOBase &&) = delete;
    void operator=(GIOBase &&) = delete;
};
1508
1509 template<>
1510 inline const std::vector<Expr> &GIOBase::get_values<Expr>() const {
1511 return exprs();
1512 }
1513
1514 template<>
1515 inline const std::vector<Func> &GIOBase::get_values<Func>() const {
1516 return funcs();
1517 }
1518
/** Non-template base shared by all Generator Inputs. It adds Parameter
 * storage, Stub-input binding and estimate plumbing on top of GIOBase;
 * typed inputs derive from it through GeneratorInputImpl<T, ValueType>. */
class GeneratorInputBase : public GIOBase {
protected:
    // Array form: used by GeneratorInputImpl for T[N] (N) and T[] (-1).
    GeneratorInputBase(size_t array_size,
                       const std::string &name,
                       IOKind kind,
                       const std::vector<Type> &t,
                       int d);

    // Non-array form (used by GeneratorInputImpl when T is not an array).
    GeneratorInputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    // One Parameter per array element (GeneratorInput_Buffer indexes it with
    // the same index it uses for funcs()).
    std::vector<Parameter> parameters_;

    // NOTE(review): presumably the single Parameter of a non-array input —
    // confirm in the implementation.
    Parameter parameter() const;

    void init_internals();
    // Bind values supplied by a Stub (StubInput exposes its payload to this
    // class via friendship).
    void set_inputs(const std::vector<StubInput> &inputs);

    virtual void set_def_min_max();
    virtual Expr get_def_expr() const;

    void verify_internals() override;

    friend class StubEmitter;

    // C++ type name of this input as used in generated Stub code.
    virtual std::string get_c_type() const = 0;

    void check_value_writable() const override;

    const char *input_or_output() const override {
        return "Input";
    }

    // Shared implementations behind the subclasses' set_estimate()/set_estimates().
    void set_estimate_impl(const Var &var, const Expr &min, const Expr &extent);
    void set_estimates_impl(const Region &estimates);

public:
    ~GeneratorInputBase() override;
};
1560
1561 template<typename T, typename ValueType>
1562 class GeneratorInputImpl : public GeneratorInputBase {
1563 protected:
1564 using TBase = typename std::remove_all_extents<T>::type;
1565
1566 bool is_array() const override {
1567 return std::is_array<T>::value;
1568 }
1569
1570 template<typename T2 = T, typename std::enable_if<
1571 // Only allow T2 not-an-array
1572 !std::is_array<T2>::value>::type * = nullptr>
1573 GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1574 : GeneratorInputBase(name, kind, t, d) {
1575 }
1576
1577 template<typename T2 = T, typename std::enable_if<
1578 // Only allow T2[kSomeConst]
1579 std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
1580 GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1581 : GeneratorInputBase(std::extent<T2, 0>::value, name, kind, t, d) {
1582 }
1583
1584 template<typename T2 = T, typename std::enable_if<
1585 // Only allow T2[]
1586 std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
1587 GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1588 : GeneratorInputBase(-1, name, kind, t, d) {
1589 }
1590
1591 public:
1592 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1593 size_t size() const {
1594 this->check_gio_access();
1595 return get_values<ValueType>().size();
1596 }
1597
1598 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1599 const ValueType &operator[](size_t i) const {
1600 this->check_gio_access();
1601 return get_values<ValueType>()[i];
1602 }
1603
1604 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1605 const ValueType &at(size_t i) const {
1606 this->check_gio_access();
1607 return get_values<ValueType>().at(i);
1608 }
1609
1610 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1611 typename std::vector<ValueType>::const_iterator begin() const {
1612 this->check_gio_access();
1613 return get_values<ValueType>().begin();
1614 }
1615
1616 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1617 typename std::vector<ValueType>::const_iterator end() const {
1618 this->check_gio_access();
1619 return get_values<ValueType>().end();
1620 }
1621 };
1622
// When forwarding methods to ImageParam, Func, etc., we must take
// care with the return types: many of the methods return a reference-to-self
// (e.g., ImageParam&); since we create temporaries for most of these forwards,
// returning a ref will crater because it refers to a now-defunct section of the
// stack. Happily, simply removing the reference solves this, since all of the
// types in question satisfy the property of copies referring to the same underlying
// structure (returning references is just an optimization). Since this is verbose
// and used in several places, we'll use a helper macro:
//
// HALIDE_FORWARD_METHOD(Class, Method) defines a non-const member template
// Method(args...) that converts *this to Class via this->as<Class>() (a
// temporary) and calls Class::Method on it, returning the result by value.
// It does not call check_gio_access() itself. NOTE(review): it relies on the
// as<Class>() conversion to do so (GeneratorInput_Buffer's operator
// ImageParam() does) — confirm for every class that uses this macro.
#define HALIDE_FORWARD_METHOD(Class, Method) \
    template<typename... Args> \
    inline auto Method(Args &&... args)->typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        return this->template as<Class>().Method(std::forward<Args>(args)...); \
    }

// HALIDE_FORWARD_METHOD_CONST(Class, Method) is the const counterpart: it
// calls check_gio_access() first, then forwards the same way (by value).
#define HALIDE_FORWARD_METHOD_CONST(Class, Method) \
    template<typename... Args> \
    inline auto Method(Args &&... args) const-> \
        typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        this->check_gio_access(); \
        return this->template as<Class>().Method(std::forward<Args>(args)...); \
    }
1644
1645 template<typename T>
1646 class GeneratorInput_Buffer : public GeneratorInputImpl<T, Func> {
1647 private:
1648 using Super = GeneratorInputImpl<T, Func>;
1649
1650 protected:
1651 using TBase = typename Super::TBase;
1652
1653 friend class ::Halide::Func;
1654 friend class ::Halide::Stage;
1655
1656 std::string get_c_type() const override {
1657 if (TBase::has_static_halide_type) {
1658 return "Halide::Internal::StubInputBuffer<" +
1659 halide_type_to_c_type(TBase::static_halide_type()) +
1660 ">";
1661 } else {
1662 return "Halide::Internal::StubInputBuffer<>";
1663 }
1664 }
1665
1666 template<typename T2>
1667 inline T2 as() const {
1668 return (T2) * this;
1669 }
1670
1671 public:
1672 GeneratorInput_Buffer(const std::string &name)
1673 : Super(name, IOKind::Buffer,
1674 TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{},
1675 -1) {
1676 }
1677
1678 GeneratorInput_Buffer(const std::string &name, const Type &t, int d = -1)
1679 : Super(name, IOKind::Buffer, {t}, d) {
1680 static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Input<Buffer<T>> if T is void or omitted.");
1681 }
1682
1683 GeneratorInput_Buffer(const std::string &name, int d)
1684 : Super(name, IOKind::Buffer, TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{}, d) {
1685 }
1686
1687 template<typename... Args>
1688 Expr operator()(Args &&... args) const {
1689 this->check_gio_access();
1690 return Func(*this)(std::forward<Args>(args)...);
1691 }
1692
1693 Expr operator()(std::vector<Expr> args) const {
1694 this->check_gio_access();
1695 return Func(*this)(std::move(args));
1696 }
1697
1698 template<typename T2>
1699 operator StubInputBuffer<T2>() const {
1700 user_assert(!this->is_array()) << "Cannot assign an array type to a non-array type for Input " << this->name();
1701 return StubInputBuffer<T2>(this->parameters_.at(0));
1702 }
1703
1704 operator Func() const {
1705 this->check_gio_access();
1706 return this->funcs().at(0);
1707 }
1708
1709 operator ExternFuncArgument() const {
1710 this->check_gio_access();
1711 return ExternFuncArgument(this->parameters_.at(0));
1712 }
1713
1714 GeneratorInput_Buffer<T> &set_estimate(Var var, Expr min, Expr extent) {
1715 this->check_gio_access();
1716 this->set_estimate_impl(var, min, extent);
1717 return *this;
1718 }
1719
1720 HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
1721 GeneratorInput_Buffer<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
1722 return set_estimate(var, min, extent);
1723 }
1724
1725 GeneratorInput_Buffer<T> &set_estimates(const Region &estimates) {
1726 this->check_gio_access();
1727 this->set_estimates_impl(estimates);
1728 return *this;
1729 }
1730
1731 Func in() {
1732 this->check_gio_access();
1733 return Func(*this).in();
1734 }
1735
1736 Func in(const Func &other) {
1737 this->check_gio_access();
1738 return Func(*this).in(other);
1739 }
1740
1741 Func in(const std::vector<Func> &others) {
1742 this->check_gio_access();
1743 return Func(*this).in(others);
1744 }
1745
1746 operator ImageParam() const {
1747 this->check_gio_access();
1748 user_assert(!this->is_array()) << "Cannot convert an Input<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
1749 return ImageParam(this->parameters_.at(0), Func(*this));
1750 }
1751
1752 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1753 size_t size() const {
1754 this->check_gio_access();
1755 return this->parameters_.size();
1756 }
1757
1758 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1759 ImageParam operator[](size_t i) const {
1760 this->check_gio_access();
1761 return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1762 }
1763
1764 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1765 ImageParam at(size_t i) const {
1766 this->check_gio_access();
1767 return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1768 }
1769
1770 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1771 typename std::vector<ImageParam>::const_iterator begin() const {
1772 user_error << "Input<Buffer<>>::begin() is not supported.";
1773 return {};
1774 }
1775
1776 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1777 typename std::vector<ImageParam>::const_iterator end() const {
1778 user_error << "Input<Buffer<>>::end() is not supported.";
1779 return {};
1780 }
1781
    /** Forward methods to the ImageParam.
     *
     * Each line below generates a member that forwards to the ImageParam
     * method of the same name (the macros are defined elsewhere in this
     * file). The _CONST variants produce const members. dim() is forwarded
     * both ways so it is usable on const and non-const Inputs. */
    // @{
    HALIDE_FORWARD_METHOD(ImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, host_alignment)
    HALIDE_FORWARD_METHOD(ImageParam, set_host_alignment)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, dimensions)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, left)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, right)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, top)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, bottom)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, width)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, height)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, channels)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, trace_loads)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, add_trace_tag)
    // }@
1799 };
1800
1801 template<typename T>
1802 class GeneratorInput_Func : public GeneratorInputImpl<T, Func> {
1803 private:
1804 using Super = GeneratorInputImpl<T, Func>;
1805
1806 protected:
1807 using TBase = typename Super::TBase;
1808
1809 std::string get_c_type() const override {
1810 return "Func";
1811 }
1812
1813 template<typename T2>
1814 inline T2 as() const {
1815 return (T2) * this;
1816 }
1817
1818 public:
1819 GeneratorInput_Func(const std::string &name, const Type &t, int d)
1820 : Super(name, IOKind::Function, {t}, d) {
1821 }
1822
1823 // unspecified type
1824 GeneratorInput_Func(const std::string &name, int d)
1825 : Super(name, IOKind::Function, {}, d) {
1826 }
1827
1828 // unspecified dimension
1829 GeneratorInput_Func(const std::string &name, const Type &t)
1830 : Super(name, IOKind::Function, {t}, -1) {
1831 }
1832
1833 // unspecified type & dimension
1834 GeneratorInput_Func(const std::string &name)
1835 : Super(name, IOKind::Function, {}, -1) {
1836 }
1837
1838 GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t, int d)
1839 : Super(array_size, name, IOKind::Function, {t}, d) {
1840 }
1841
1842 // unspecified type
1843 GeneratorInput_Func(size_t array_size, const std::string &name, int d)
1844 : Super(array_size, name, IOKind::Function, {}, d) {
1845 }
1846
1847 // unspecified dimension
1848 GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t)
1849 : Super(array_size, name, IOKind::Function, {t}, -1) {
1850 }
1851
1852 // unspecified type & dimension
1853 GeneratorInput_Func(size_t array_size, const std::string &name)
1854 : Super(array_size, name, IOKind::Function, {}, -1) {
1855 }
1856
1857 template<typename... Args>
1858 Expr operator()(Args &&... args) const {
1859 this->check_gio_access();
1860 return this->funcs().at(0)(std::forward<Args>(args)...);
1861 }
1862
1863 Expr operator()(const std::vector<Expr> &args) const {
1864 this->check_gio_access();
1865 return this->funcs().at(0)(args);
1866 }
1867
1868 operator Func() const {
1869 this->check_gio_access();
1870 return this->funcs().at(0);
1871 }
1872
1873 operator ExternFuncArgument() const {
1874 this->check_gio_access();
1875 return ExternFuncArgument(this->parameters_.at(0));
1876 }
1877
1878 GeneratorInput_Func<T> &set_estimate(Var var, Expr min, Expr extent) {
1879 this->check_gio_access();
1880 this->set_estimate_impl(var, min, extent);
1881 return *this;
1882 }
1883
1884 HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
1885 GeneratorInput_Func<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
1886 return set_estimate(var, min, extent);
1887 }
1888
1889 GeneratorInput_Func<T> &set_estimates(const Region &estimates) {
1890 this->check_gio_access();
1891 this->set_estimates_impl(estimates);
1892 return *this;
1893 }
1894
1895 Func in() {
1896 this->check_gio_access();
1897 return Func(*this).in();
1898 }
1899
1900 Func in(const Func &other) {
1901 this->check_gio_access();
1902 return Func(*this).in(other);
1903 }
1904
1905 Func in(const std::vector<Func> &others) {
1906 this->check_gio_access();
1907 return Func(*this).in(others);
1908 }
1909
1910 /** Forward const methods to the underlying Func. (Non-const methods
1911 * aren't available for Input<Func>.) */
1912 // @{
1913 HALIDE_FORWARD_METHOD_CONST(Func, args)
1914 HALIDE_FORWARD_METHOD_CONST(Func, defined)
1915 HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
1916 HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
1917 HALIDE_FORWARD_METHOD_CONST(Func, output_types)
1918 HALIDE_FORWARD_METHOD_CONST(Func, outputs)
1919 HALIDE_FORWARD_METHOD_CONST(Func, rvars)
1920 HALIDE_FORWARD_METHOD_CONST(Func, update_args)
1921 HALIDE_FORWARD_METHOD_CONST(Func, update_value)
1922 HALIDE_FORWARD_METHOD_CONST(Func, update_values)
1923 HALIDE_FORWARD_METHOD_CONST(Func, value)
1924 HALIDE_FORWARD_METHOD_CONST(Func, values)
1925 // }@
1926 };
1927
1928 template<typename T>
1929 class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {
1930 private:
1931 using Super = GeneratorInputImpl<T, Expr>;
1932
1933 protected:
1934 using TBase = typename Super::TBase;
1935
1936 const TBase def_{TBase()};
1937 const Expr def_expr_;
1938
1939 protected:
1940 Expr get_def_expr() const override {
1941 return def_expr_;
1942 }
1943
1944 void set_def_min_max() override {
1945 for (Parameter &p : this->parameters_) {
1946 p.set_scalar<TBase>(def_);
1947 }
1948 }
1949
1950 std::string get_c_type() const override {
1951 return "Expr";
1952 }
1953
1954 // Expr() doesn't accept a pointer type in its ctor; add a SFINAE adapter
1955 // so that pointer (aka handle) Inputs will get cast to uint64.
1956 template<typename TBase2 = TBase, typename std::enable_if<!std::is_pointer<TBase2>::value>::type * = nullptr>
1957 static Expr TBaseToExpr(const TBase2 &value) {
1958 return Expr(value);
1959 }
1960
1961 template<typename TBase2 = TBase, typename std::enable_if<std::is_pointer<TBase2>::value>::type * = nullptr>
1962 static Expr TBaseToExpr(const TBase2 &value) {
1963 return Expr((uint64_t)value);
1964 }
1965
1966 public:
1967 explicit GeneratorInput_Scalar(const std::string &name)
1968 : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
1969 }
1970
1971 GeneratorInput_Scalar(const std::string &name, const TBase &def)
1972 : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
1973 }
1974
1975 GeneratorInput_Scalar(size_t array_size,
1976 const std::string &name)
1977 : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
1978 }
1979
1980 GeneratorInput_Scalar(size_t array_size,
1981 const std::string &name,
1982 const TBase &def)
1983 : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
1984 }
1985
1986 /** You can use this Input as an expression in a halide
1987 * function definition */
1988 operator Expr() const {
1989 this->check_gio_access();
1990 return this->exprs().at(0);
1991 }
1992
1993 /** Using an Input as the argument to an external stage treats it
1994 * as an Expr */
1995 operator ExternFuncArgument() const {
1996 this->check_gio_access();
1997 return ExternFuncArgument(this->exprs().at(0));
1998 }
1999
2000 template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>::type * = nullptr>
2001 void set_estimate(const TBase &value) {
2002 this->check_gio_access();
2003 user_assert(value == nullptr) << "nullptr is the only valid estimate for Input<PointerType>";
2004 Expr e = reinterpret(type_of<T2>(), cast<uint64_t>(0));
2005 for (Parameter &p : this->parameters_) {
2006 p.set_estimate(e);
2007 }
2008 }
2009
2010 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value && !std::is_pointer<T2>::value>::type * = nullptr>
2011 void set_estimate(const TBase &value) {
2012 this->check_gio_access();
2013 Expr e = Expr(value);
2014 if (std::is_same<T2, bool>::value) {
2015 e = cast<bool>(e);
2016 }
2017 for (Parameter &p : this->parameters_) {
2018 p.set_estimate(e);
2019 }
2020 }
2021
2022 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2023 void set_estimate(size_t index, const TBase &value) {
2024 this->check_gio_access();
2025 Expr e = Expr(value);
2026 if (std::is_same<T2, bool>::value) {
2027 e = cast<bool>(e);
2028 }
2029 this->parameters_.at(index).set_estimate(e);
2030 }
2031 };
2032
2033 template<typename T>
2034 class GeneratorInput_Arithmetic : public GeneratorInput_Scalar<T> {
2035 private:
2036 using Super = GeneratorInput_Scalar<T>;
2037
2038 protected:
2039 using TBase = typename Super::TBase;
2040
2041 const Expr min_, max_;
2042
2043 protected:
2044 void set_def_min_max() override {
2045 Super::set_def_min_max();
2046 // Don't set min/max for bool
2047 if (!std::is_same<TBase, bool>::value) {
2048 for (Parameter &p : this->parameters_) {
2049 if (min_.defined()) p.set_min_value(min_);
2050 if (max_.defined()) p.set_max_value(max_);
2051 }
2052 }
2053 }
2054
2055 public:
2056 explicit GeneratorInput_Arithmetic(const std::string &name)
2057 : Super(name), min_(Expr()), max_(Expr()) {
2058 }
2059
2060 GeneratorInput_Arithmetic(const std::string &name,
2061 const TBase &def)
2062 : Super(name, def), min_(Expr()), max_(Expr()) {
2063 }
2064
2065 GeneratorInput_Arithmetic(size_t array_size,
2066 const std::string &name)
2067 : Super(array_size, name), min_(Expr()), max_(Expr()) {
2068 }
2069
2070 GeneratorInput_Arithmetic(size_t array_size,
2071 const std::string &name,
2072 const TBase &def)
2073 : Super(array_size, name, def), min_(Expr()), max_(Expr()) {
2074 }
2075
2076 GeneratorInput_Arithmetic(const std::string &name,
2077 const TBase &def,
2078 const TBase &min,
2079 const TBase &max)
2080 : Super(name, def), min_(min), max_(max) {
2081 }
2082
2083 GeneratorInput_Arithmetic(size_t array_size,
2084 const std::string &name,
2085 const TBase &def,
2086 const TBase &min,
2087 const TBase &max)
2088 : Super(array_size, name, def), min_(min), max_(max) {
2089 }
2090 };
2091
/** Maps any type to void; used to detect well-formed expressions via SFINAE. */
template<typename>
struct type_sink {
    using type = void;
};
2094
2095 template<typename T2, typename = void>
2096 struct has_static_halide_type_method : std::false_type {};
2097
2098 template<typename T2>
2099 struct has_static_halide_type_method<T2, typename type_sink<decltype(T2::static_halide_type())>::type> : std::true_type {};
2100
/** Picks the implementation class for Input<T> from TBase, the element type
 * of T with all array extents removed:
 *  - types providing static_halide_type() (e.g. Buffer types) -> GeneratorInput_Buffer
 *  - Func                                                     -> GeneratorInput_Func
 *  - arithmetic types                                         -> GeneratorInput_Arithmetic
 *  - other scalar types (e.g. pointers)                       -> GeneratorInput_Scalar
 * Arithmetic types are also std::is_scalar, so this relies on select_type
 * choosing the first matching cond -- presumably its semantics (defined
 * elsewhere); TODO confirm. */
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorInputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorInput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorInput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorInput_Arithmetic<T>>,
        cond<std::is_scalar<TBase>::value, GeneratorInput_Scalar<T>>>::type;
2108
2109 } // namespace Internal
2110
/** User-facing Input<T> class. All behavior comes from the implementation
 * class chosen by Internal::GeneratorInputImplBase<T>; this class only
 * declares the union of constructor shapes and forwards each to Super.
 * Constructors that the selected Super does not provide are only an error
 * if actually used. */
template<typename T>
class GeneratorInput : public Internal::GeneratorInputImplBase<T> {
private:
    using Super = Internal::GeneratorInputImplBase<T>;

protected:
    using TBase = typename Super::TBase;

    // Trick to avoid ambiguous ctor between Func-with-dim and int-with-default-value;
    // since we can't use std::enable_if on ctors, define the argument to be one that
    // can only be properly resolved for TBase=Func.
    // In practice IntIfNonScalar is `int` for Buffer-like and Func element types
    // and the never-defined type `Unused` otherwise, so e.g. Input<int>{"x", 3}
    // binds to the (name, def) ctor instead of the (name, dims) ctor.
    struct Unused;
    using IntIfNonScalar =
        typename Internal::select_type<
            Internal::cond<Internal::has_static_halide_type_method<TBase>::value, int>,
            Internal::cond<std::is_same<TBase, Func>::value, int>,
            Internal::cond<true, Unused>>::type;

public:
    explicit GeneratorInput(const std::string &name)
        : Super(name) {
    }

    // Scalar Inputs: default value.
    GeneratorInput(const std::string &name, const TBase &def)
        : Super(name, def) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const TBase &def)
        : Super(array_size, name, def) {
    }

    // Arithmetic Inputs: default value plus min/max bounds.
    GeneratorInput(const std::string &name,
                   const TBase &def, const TBase &min, const TBase &max)
        : Super(name, def, min, max) {
    }

    GeneratorInput(size_t array_size, const std::string &name,
                   const TBase &def, const TBase &min, const TBase &max)
        : Super(array_size, name, def, min, max) {
    }

    // Func/Buffer Inputs: element type and dimensionality.
    GeneratorInput(const std::string &name, const Type &t, int d)
        : Super(name, t, d) {
    }

    GeneratorInput(const std::string &name, const Type &t)
        : Super(name, t) {
    }

    // Avoid ambiguity between Func-with-dim and int-with-default
    GeneratorInput(const std::string &name, IntIfNonScalar d)
        : Super(name, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const Type &t, int d)
        : Super(array_size, name, t, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const Type &t)
        : Super(array_size, name, t) {
    }

    // Avoid ambiguity between Func-with-dim and int-with-default
    // (std::enable_if can't be used here, so IntIfNonScalar is used instead).
    GeneratorInput(size_t array_size, const std::string &name, IntIfNonScalar d)
        : Super(array_size, name, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name)
        : Super(array_size, name) {
    }
};
2183
2184 namespace Internal {
2185
/** Non-templated base for every Output<T> implementation.
 *
 * Holds the Func(s) backing an Output (funcs_) and exposes a large set of
 * Func scheduling/query methods that operate on the single backing Func.
 * Presumably the forwarding macros reach it via as<Func>(), as estimate()
 * does explicitly -- TODO confirm against the macro definitions. */
class GeneratorOutputBase : public GIOBase {
protected:
    // Returns the single backing Func. Fails with a user error unless exactly
    // one Func is held (array Outputs must be indexed with [] instead).
    template<typename T2, typename std::enable_if<std::is_same<T2, Func>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE T2 as() const {
        static_assert(std::is_same<T2, Func>::value, "Only Func allowed here");
        internal_assert(kind() != IOKind::Scalar);
        internal_assert(exprs_.empty());
        user_assert(funcs_.size() == 1) << "Use [] to access individual Funcs in Output<Func[]>";
        return funcs_[0];
    }

public:
    HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
    GeneratorOutputBase &estimate(const Var &var, const Expr &min, const Expr &extent) {
        this->as<Func>().set_estimate(var, min, extent);
        return *this;
    }

    /** Forward schedule-related methods to the underlying Func. */
    // @{
    HALIDE_FORWARD_METHOD(Func, add_trace_tag)
    HALIDE_FORWARD_METHOD(Func, align_bounds)
    HALIDE_FORWARD_METHOD(Func, align_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, args)
    HALIDE_FORWARD_METHOD(Func, bound)
    HALIDE_FORWARD_METHOD(Func, bound_extent)
    HALIDE_FORWARD_METHOD(Func, compute_at)
    HALIDE_FORWARD_METHOD(Func, compute_inline)
    HALIDE_FORWARD_METHOD(Func, compute_root)
    HALIDE_FORWARD_METHOD(Func, compute_with)
    HALIDE_FORWARD_METHOD(Func, copy_to_device)
    HALIDE_FORWARD_METHOD(Func, copy_to_host)
    HALIDE_FORWARD_METHOD(Func, define_extern)
    HALIDE_FORWARD_METHOD_CONST(Func, defined)
    HALIDE_FORWARD_METHOD(Func, fold_storage)
    HALIDE_FORWARD_METHOD(Func, fuse)
    HALIDE_FORWARD_METHOD(Func, glsl)
    HALIDE_FORWARD_METHOD(Func, gpu)
    HALIDE_FORWARD_METHOD(Func, gpu_blocks)
    HALIDE_FORWARD_METHOD(Func, gpu_single_thread)
    HALIDE_FORWARD_METHOD(Func, gpu_threads)
    HALIDE_FORWARD_METHOD(Func, gpu_tile)
    HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
    HALIDE_FORWARD_METHOD(Func, hexagon)
    HALIDE_FORWARD_METHOD(Func, in)
    HALIDE_FORWARD_METHOD(Func, memoize)
    HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
    HALIDE_FORWARD_METHOD_CONST(Func, output_types)
    HALIDE_FORWARD_METHOD_CONST(Func, outputs)
    HALIDE_FORWARD_METHOD(Func, parallel)
    HALIDE_FORWARD_METHOD(Func, prefetch)
    HALIDE_FORWARD_METHOD(Func, print_loop_nest)
    HALIDE_FORWARD_METHOD(Func, rename)
    HALIDE_FORWARD_METHOD(Func, reorder)
    HALIDE_FORWARD_METHOD(Func, reorder_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, rvars)
    HALIDE_FORWARD_METHOD(Func, serial)
    HALIDE_FORWARD_METHOD(Func, set_estimate)
    HALIDE_FORWARD_METHOD(Func, shader)
    HALIDE_FORWARD_METHOD(Func, specialize)
    HALIDE_FORWARD_METHOD(Func, specialize_fail)
    HALIDE_FORWARD_METHOD(Func, split)
    HALIDE_FORWARD_METHOD(Func, store_at)
    HALIDE_FORWARD_METHOD(Func, store_root)
    HALIDE_FORWARD_METHOD(Func, tile)
    HALIDE_FORWARD_METHOD(Func, trace_stores)
    HALIDE_FORWARD_METHOD(Func, unroll)
    HALIDE_FORWARD_METHOD(Func, update)
    HALIDE_FORWARD_METHOD_CONST(Func, update_args)
    HALIDE_FORWARD_METHOD_CONST(Func, update_value)
    HALIDE_FORWARD_METHOD_CONST(Func, update_values)
    HALIDE_FORWARD_METHOD_CONST(Func, value)
    HALIDE_FORWARD_METHOD_CONST(Func, values)
    HALIDE_FORWARD_METHOD(Func, vectorize)
    // }@

// NOTE(review): these #undefs name HALIDE_OUTPUT_FORWARD*, not the
// HALIDE_FORWARD_METHOD* macros used above. The forwarding macros therefore
// stay defined, which later classes in this file rely on
// (e.g. GeneratorOutput_Buffer). They look like stale leftovers; confirm
// before removing.
#undef HALIDE_OUTPUT_FORWARD
#undef HALIDE_OUTPUT_FORWARD_CONST

protected:
    GeneratorOutputBase(size_t array_size,
                        const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    GeneratorOutputBase(const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    friend class GeneratorBase;
    friend class StubEmitter;

    void init_internals();
    void resize(size_t size);

    // C++ type name used for this Output (e.g. in generated stub code --
    // presumably, given StubEmitter is a friend; TODO confirm).
    virtual std::string get_c_type() const {
        return "Func";
    }

    void check_value_writable() const override;

    const char *input_or_output() const override {
        return "Output";
    }

public:
    ~GeneratorOutputBase() override;
};
2296
/** Templated layer shared by all Output<T> implementations.
 *
 * T may be a plain type, a fixed-size array T[N] (N > 0), or an unsized
 * array T[] (passed to the base with array_size -1, i.e. SIZE_MAX after
 * conversion, and sized later via resize()). The SFINAE guards choose
 * between single-value accessors (call, convert to Func/Stage) and array
 * accessors (size, [], at, begin/end). */
template<typename T>
class GeneratorOutputImpl : public GeneratorOutputBase {
protected:
    using TBase = typename std::remove_all_extents<T>::type;
    using ValueType = Func;

    bool is_array() const override {
        return std::is_array<T>::value;
    }

    // NOTE(review): every constructor here takes (name, kind, t, d); none
    // takes an array_size. Subclass constructors that call
    // Super(array_size, name, ...) will not compile if they are ever
    // instantiated. Confirm whether such an overload is intended.
    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2 not-an-array
                                  !std::is_array<T2>::value>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[kSomeConst]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(std::extent<T2, 0>::value, name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(-1, name, kind, t, d) {
    }

public:
    // Non-array: call the single backing Func to produce a FuncRef (for
    // defining or referencing it).
    template<typename... Args, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    FuncRef operator()(Args &&... args) const {
        this->check_gio_access();
        return get_values<ValueType>().at(0)(std::forward<Args>(args)...);
    }

    template<typename ExprOrVar, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    FuncRef operator()(std::vector<ExprOrVar> args) const {
        this->check_gio_access();
        return get_values<ValueType>().at(0)(args);
    }

    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    operator Func() const {
        this->check_gio_access();
        return get_values<ValueType>().at(0);
    }

    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    operator Stage() const {
        this->check_gio_access();
        return get_values<ValueType>().at(0);
    }

    // Array accessors. Note: operator[] does not bounds-check (it uses
    // vector operator[]); at() does (vector::at).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    size_t size() const {
        this->check_gio_access();
        return get_values<ValueType>().size();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &operator[](size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>()[i];
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &at(size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>().at(i);
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator begin() const {
        this->check_gio_access();
        return get_values<ValueType>().begin();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator end() const {
        this->check_gio_access();
        return get_values<ValueType>().end();
    }

    // Only unsized arrays (T[]) may be resized; fixed-size arrays get their
    // size from the type.
    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
    void resize(size_t size) {
        this->check_gio_access();
        GeneratorOutputBase::resize(size);
    }
};
2391
2392 template<typename T>
2393 class GeneratorOutput_Buffer : public GeneratorOutputImpl<T> {
2394 private:
2395 using Super = GeneratorOutputImpl<T>;
2396
2397 HALIDE_NO_USER_CODE_INLINE void assign_from_func(const Func &f) {
2398 this->check_value_writable();
2399
2400 internal_assert(f.defined());
2401
2402 if (this->types_defined()) {
2403 const auto &my_types = this->types();
2404 user_assert(my_types.size() == f.output_types().size())
2405 << "Cannot assign Func \"" << f.name()
2406 << "\" to Output \"" << this->name() << "\"\n"
2407 << "Output " << this->name()
2408 << " is declared to have " << my_types.size() << " tuple elements"
2409 << " but Func " << f.name()
2410 << " has " << f.output_types().size() << " tuple elements.\n";
2411 for (size_t i = 0; i < my_types.size(); i++) {
2412 user_assert(my_types[i] == f.output_types().at(i))
2413 << "Cannot assign Func \"" << f.name()
2414 << "\" to Output \"" << this->name() << "\"\n"
2415 << (my_types.size() > 1 ? "In tuple element " + std::to_string(i) + ", " : "")
2416 << "Output " << this->name()
2417 << " has declared type " << my_types[i]
2418 << " but Func " << f.name()
2419 << " has type " << f.output_types().at(i) << "\n";
2420 }
2421 }
2422 if (this->dims_defined()) {
2423 user_assert(f.dimensions() == this->dims())
2424 << "Cannot assign Func \"" << f.name()
2425 << "\" to Output \"" << this->name() << "\"\n"
2426 << "Output " << this->name()
2427 << " has declared dimensionality " << this->dims()
2428 << " but Func " << f.name()
2429 << " has dimensionality " << f.dimensions() << "\n";
2430 }
2431
2432 internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
2433 user_assert(!this->funcs_.at(0).defined());
2434 this->funcs_[0] = f;
2435 }
2436
2437 protected:
2438 using TBase = typename Super::TBase;
2439
2440 static std::vector<Type> my_types(const std::vector<Type> &t) {
2441 if (TBase::has_static_halide_type) {
2442 user_assert(t.empty()) << "Cannot pass a Type argument for an Output<Buffer> with a non-void static type\n";
2443 return std::vector<Type>{TBase::static_halide_type()};
2444 }
2445 return t;
2446 }
2447
2448 protected:
2449 GeneratorOutput_Buffer(const std::string &name, const std::vector<Type> &t = {}, int d = -1)
2450 : Super(name, IOKind::Buffer, my_types(t), d) {
2451 }
2452
2453 GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector<Type> &t = {}, int d = -1)
2454 : Super(array_size, name, IOKind::Buffer, my_types(t), d) {
2455 }
2456
2457 HALIDE_NO_USER_CODE_INLINE std::string get_c_type() const override {
2458 if (TBase::has_static_halide_type) {
2459 return "Halide::Internal::StubOutputBuffer<" +
2460 halide_type_to_c_type(TBase::static_halide_type()) +
2461 ">";
2462 } else {
2463 return "Halide::Internal::StubOutputBuffer<>";
2464 }
2465 }
2466
2467 template<typename T2, typename std::enable_if<!std::is_same<T2, Func>::value>::type * = nullptr>
2468 HALIDE_NO_USER_CODE_INLINE T2 as() const {
2469 return (T2) * this;
2470 }
2471
2472 public:
2473 // Allow assignment from a Buffer<> to an Output<Buffer<>>;
2474 // this allows us to use a statically-compiled buffer inside a Generator
2475 // to assign to an output.
2476 // TODO: This used to take the buffer as a const ref. This no longer works as
2477 // using it in a Pipeline might change the dev field so it is currently
2478 // not considered const. We should consider how this really ought to work.
2479 template<typename T2>
2480 HALIDE_NO_USER_CODE_INLINE GeneratorOutput_Buffer<T> &operator=(Buffer<T2> &buffer) {
2481 this->check_gio_access();
2482 this->check_value_writable();
2483
2484 user_assert(T::can_convert_from(buffer))
2485 << "Cannot assign to the Output \"" << this->name()
2486 << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";
2487
2488 if (this->types_defined()) {
2489 user_assert(Type(buffer.type()) == this->type())
2490 << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n";
2491 }
2492 if (this->dims_defined()) {
2493 user_assert(buffer.dimensions() == this->dims())
2494 << "Output " << this->name() << " should have dim=" << this->dims() << " but saw dim=" << buffer.dimensions() << "\n";
2495 }
2496
2497 internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
2498 user_assert(!this->funcs_.at(0).defined());
2499 this->funcs_.at(0)(_) = buffer(_);
2500
2501 return *this;
2502 }
2503
2504 // Allow assignment from a StubOutputBuffer to an Output<Buffer>;
2505 // this allows us to pipeline the results of a Stub to the results
2506 // of the enclosing Generator.
2507 template<typename T2>
2508 GeneratorOutput_Buffer<T> &operator=(const StubOutputBuffer<T2> &stub_output_buffer) {
2509 this->check_gio_access();
2510 assign_from_func(stub_output_buffer.f);
2511 return *this;
2512 }
2513
2514 // Allow assignment from a Func to an Output<Buffer>;
2515 // this allows us to use helper functions that return a plain Func
2516 // to simply set the output(s) without needing a wrapper Func.
2517 GeneratorOutput_Buffer<T> &operator=(const Func &f) {
2518 this->check_gio_access();
2519 assign_from_func(f);
2520 return *this;
2521 }
2522
2523 operator OutputImageParam() const {
2524 this->check_gio_access();
2525 user_assert(!this->is_array()) << "Cannot convert an Output<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
2526 internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
2527 return this->funcs_.at(0).output_buffer();
2528 }
2529
2530 // 'perfect forwarding' won't work with initializer lists,
2531 // so hand-roll our own forwarding method for set_estimates,
2532 // rather than using HALIDE_FORWARD_METHOD.
2533 GeneratorOutput_Buffer<T> &set_estimates(const Region &estimates) {
2534 this->as<OutputImageParam>().set_estimates(estimates);
2535 return *this;
2536 }
2537
2538 /** Forward methods to the OutputImageParam. */
2539 // @{
2540 HALIDE_FORWARD_METHOD(OutputImageParam, dim)
2541 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dim)
2542 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, host_alignment)
2543 HALIDE_FORWARD_METHOD(OutputImageParam, set_host_alignment)
2544 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dimensions)
2545 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, left)
2546 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, right)
2547 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, top)
2548 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, bottom)
2549 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, width)
2550 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, height)
2551 HALIDE_FORWARD_METHOD_CONST(OutputImageParam, channels)
2552 // }@
2553 };
2554
2555 template<typename T>
2556 class GeneratorOutput_Func : public GeneratorOutputImpl<T> {
2557 private:
2558 using Super = GeneratorOutputImpl<T>;
2559
2560 HALIDE_NO_USER_CODE_INLINE Func &get_assignable_func_ref(size_t i) {
2561 internal_assert(this->exprs_.empty() && this->funcs_.size() > i);
2562 return this->funcs_.at(i);
2563 }
2564
2565 protected:
2566 using TBase = typename Super::TBase;
2567
2568 protected:
2569 GeneratorOutput_Func(const std::string &name)
2570 : Super(name, IOKind::Function, std::vector<Type>{}, -1) {
2571 }
2572
2573 GeneratorOutput_Func(const std::string &name, const std::vector<Type> &t, int d = -1)
2574 : Super(name, IOKind::Function, t, d) {
2575 }
2576
2577 GeneratorOutput_Func(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
2578 : Super(array_size, name, IOKind::Function, t, d) {
2579 }
2580
2581 public:
2582 // Allow Output<Func> = Func
2583 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2584 GeneratorOutput_Func<T> &operator=(const Func &f) {
2585 this->check_gio_access();
2586 this->check_value_writable();
2587
2588 // Don't bother verifying the Func type, dimensions, etc., here:
2589 // That's done later, when we produce the pipeline.
2590 get_assignable_func_ref(0) = f;
2591 return *this;
2592 }
2593
2594 // Allow Output<Func[]> = Func
2595 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2596 Func &operator[](size_t i) {
2597 this->check_gio_access();
2598 this->check_value_writable();
2599 return get_assignable_func_ref(i);
2600 }
2601
2602 // Allow Func = Output<Func[]>
2603 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2604 const Func &operator[](size_t i) const {
2605 this->check_gio_access();
2606 return Super::operator[](i);
2607 }
2608
2609 GeneratorOutput_Func<T> &set_estimate(const Var &var, const Expr &min, const Expr &extent) {
2610 this->check_gio_access();
2611 internal_assert(this->exprs_.empty() && !this->funcs_.empty());
2612 for (Func &f : this->funcs_) {
2613 f.set_estimate(var, min, extent);
2614 }
2615 return *this;
2616 }
2617
2618 HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
2619 GeneratorOutput_Func<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
2620 return set_estimate(var, min, extent);
2621 }
2622
2623 GeneratorOutput_Func<T> &set_estimates(const Region &estimates) {
2624 this->check_gio_access();
2625 internal_assert(this->exprs_.empty() && !this->funcs_.empty());
2626 for (Func &f : this->funcs_) {
2627 f.set_estimates(estimates);
2628 }
2629 return *this;
2630 }
2631 };
2632
2633 template<typename T>
2634 class GeneratorOutput_Arithmetic : public GeneratorOutputImpl<T> {
2635 private:
2636 using Super = GeneratorOutputImpl<T>;
2637
2638 protected:
2639 using TBase = typename Super::TBase;
2640
2641 protected:
2642 explicit GeneratorOutput_Arithmetic(const std::string &name)
2643 : Super(name, IOKind::Function, {type_of<TBase>()}, 0) {
2644 }
2645
2646 GeneratorOutput_Arithmetic(size_t array_size, const std::string &name)
2647 : Super(array_size, name, IOKind::Function, {type_of<TBase>()}, 0) {
2648 }
2649 };
2650
// Chooses the concrete Output<> implementation for T, keyed on TBase (T with
// all array extents removed, so Output<Func[]> dispatches like Output<Func>):
//   - types with a static halide_type method (Buffer<>) -> GeneratorOutput_Buffer
//   - Func                                               -> GeneratorOutput_Func
//   - arithmetic scalars                                 -> GeneratorOutput_Arithmetic
// NOTE(review): select_type presumably picks the first matching cond<>, so the
// order of these entries matters if a type ever satisfies more than one.
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorOutputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorOutput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorOutput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorOutput_Arithmetic<T>>>::type;
2657
2658 } // namespace Internal
2659
2660 template<typename T>
2661 class GeneratorOutput : public Internal::GeneratorOutputImplBase<T> {
2662 private:
2663 using Super = Internal::GeneratorOutputImplBase<T>;
2664
2665 protected:
2666 using TBase = typename Super::TBase;
2667
2668 public:
2669 explicit GeneratorOutput(const std::string &name)
2670 : Super(name) {
2671 }
2672
2673 explicit GeneratorOutput(const char *name)
2674 : GeneratorOutput(std::string(name)) {
2675 }
2676
2677 GeneratorOutput(size_t array_size, const std::string &name)
2678 : Super(array_size, name) {
2679 }
2680
2681 GeneratorOutput(const std::string &name, int d)
2682 : Super(name, {}, d) {
2683 }
2684
2685 GeneratorOutput(const std::string &name, const Type &t, int d)
2686 : Super(name, {t}, d) {
2687 }
2688
2689 GeneratorOutput(const std::string &name, const std::vector<Type> &t, int d)
2690 : Super(name, t, d) {
2691 }
2692
2693 GeneratorOutput(size_t array_size, const std::string &name, int d)
2694 : Super(array_size, name, {}, d) {
2695 }
2696
2697 GeneratorOutput(size_t array_size, const std::string &name, const Type &t, int d)
2698 : Super(array_size, name, {t}, d) {
2699 }
2700
2701 GeneratorOutput(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
2702 : Super(array_size, name, t, d) {
2703 }
2704
2705 // TODO: This used to take the buffer as a const ref. This no longer works as
2706 // using it in a Pipeline might change the dev field so it is currently
2707 // not considered const. We should consider how this really ought to work.
2708 template<typename T2>
2709 GeneratorOutput<T> &operator=(Buffer<T2> &buffer) {
2710 Super::operator=(buffer);
2711 return *this;
2712 }
2713
2714 template<typename T2>
2715 GeneratorOutput<T> &operator=(const Internal::StubOutputBuffer<T2> &stub_output_buffer) {
2716 Super::operator=(stub_output_buffer);
2717 return *this;
2718 }
2719
2720 GeneratorOutput<T> &operator=(const Func &f) {
2721 Super::operator=(f);
2722 return *this;
2723 }
2724 };
2725
2726 namespace Internal {
2727
2728 template<typename T>
2729 T parse_scalar(const std::string &value) {
2730 std::istringstream iss(value);
2731 T t;
2732 iss >> t;
2733 user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << value;
2734 return t;
2735 }
2736
2737 std::vector<Type> parse_halide_type_list(const std::string &types);
2738
// Which property of an Input/Output a synthetic GeneratorParam controls.
enum class SyntheticParamType {
    Type,      // the element type(s), e.g. "image.type"
    Dim,       // the dimensionality, e.g. "image.dim"
    ArraySize  // the number of elements of an array Input/Output
};
2742
2743 // This is a type of GeneratorParam used internally to create 'synthetic' params
2744 // (e.g. image.type, image.dim); it is not possible for user code to instantiate it.
2745 template<typename T>
2746 class GeneratorParam_Synthetic : public GeneratorParamImpl<T> {
2747 public:
2748 void set_from_string(const std::string &new_value_string) override {
2749 // If error_msg is not empty, this is unsettable:
2750 // display error_msg as a user error.
2751 if (!error_msg.empty()) {
2752 user_error << error_msg;
2753 }
2754 set_from_string_impl<T>(new_value_string);
2755 }
2756
2757 std::string get_default_value() const override {
2758 internal_error;
2759 return std::string();
2760 }
2761
2762 std::string call_to_string(const std::string &v) const override {
2763 internal_error;
2764 return std::string();
2765 }
2766
2767 std::string get_c_type() const override {
2768 internal_error;
2769 return std::string();
2770 }
2771
2772 bool is_synthetic_param() const override {
2773 return true;
2774 }
2775
2776 private:
2777 friend class GeneratorParamInfo;
2778
2779 static std::unique_ptr<Internal::GeneratorParamBase> make(
2780 GeneratorBase *generator,
2781 const std::string &generator_name,
2782 const std::string &gpname,
2783 GIOBase &gio,
2784 SyntheticParamType which,
2785 bool defined) {
2786 std::string error_msg = defined ? "Cannot set the GeneratorParam " + gpname + " for " + generator_name + " because the value is explicitly specified in the C++ source." : "";
2787 return std::unique_ptr<GeneratorParam_Synthetic<T>>(
2788 new GeneratorParam_Synthetic<T>(gpname, gio, which, error_msg));
2789 }
2790
2791 GeneratorParam_Synthetic(const std::string &name, GIOBase &gio, SyntheticParamType which, const std::string &error_msg = "")
2792 : GeneratorParamImpl<T>(name, T()), gio(gio), which(which), error_msg(error_msg) {
2793 }
2794
2795 template<typename T2 = T, typename std::enable_if<std::is_same<T2, ::Halide::Type>::value>::type * = nullptr>
2796 void set_from_string_impl(const std::string &new_value_string) {
2797 internal_assert(which == SyntheticParamType::Type);
2798 gio.types_ = parse_halide_type_list(new_value_string);
2799 }
2800
2801 template<typename T2 = T, typename std::enable_if<std::is_integral<T2>::value>::type * = nullptr>
2802 void set_from_string_impl(const std::string &new_value_string) {
2803 if (which == SyntheticParamType::Dim) {
2804 gio.dims_ = parse_scalar<T2>(new_value_string);
2805 } else if (which == SyntheticParamType::ArraySize) {
2806 gio.array_size_ = parse_scalar<T2>(new_value_string);
2807 } else {
2808 internal_error;
2809 }
2810 }
2811
2812 GIOBase &gio;
2813 const SyntheticParamType which;
2814 const std::string error_msg;
2815 };
2816
2817 class GeneratorStub;
2818
2819 } // namespace Internal
2820
2821 /** GeneratorContext is a base class that is used when using Generators (or Stubs) directly;
2822 * it is used to allow the outer context (typically, either a Generator or "top-level" code)
2823 * to specify certain information to the inner context to ensure that inner and outer
2824 * Generators are compiled in a compatible way.
2825 *
2826 * If you are using this at "top level" (e.g. with the JIT), you can construct a GeneratorContext
2827 * with a Target:
2828 * \code
2829 * auto my_stub = MyStub(
2830 * GeneratorContext(get_target_from_environment()),
2831 * // inputs
2832 * { ... },
2833 * // generator params
2834 * { ... }
2835 * );
2836 * \endcode
2837 *
2838 * Note that all Generators inherit from GeneratorContext, so if you are using a Stub
2839 * from within a Generator, you can just pass 'this' for the GeneratorContext:
2840 * \code
2841 * struct SomeGen : Generator<SomeGen> {
2842 * void generate() {
2843 * ...
2844 * auto my_stub = MyStub(
2845 * this, // GeneratorContext
2846 * // inputs
2847 * { ... },
2848 * // generator params
2849 * { ... }
2850 * );
2851 * ...
2852 * }
2853 * };
2854 * \endcode
2855 */
2856 class GeneratorContext {
2857 public:
2858 using ExternsMap = std::map<std::string, ExternalCode>;
2859
2860 explicit GeneratorContext(const Target &t,
2861 bool auto_schedule = false,
2862 const MachineParams &machine_params = MachineParams::generic());
2863 virtual ~GeneratorContext();
2864
2865 inline Target get_target() const {
2866 return target;
2867 }
2868 inline bool get_auto_schedule() const {
2869 return auto_schedule;
2870 }
2871 inline MachineParams get_machine_params() const {
2872 return machine_params;
2873 }
2874
2875 /** Generators can register ExternalCode objects onto
2876 * themselves. The Generator infrastructure will arrange to have
2877 * this ExternalCode appended to the Module that is finally
2878 * compiled using the Generator. This allows encapsulating
2879 * functionality that depends on external libraries or handwritten
2880 * code for various targets. The name argument should match the
2881 * name of the ExternalCode block and is used to ensure the same
2882 * code block is not duplicated in the output. Halide does not do
2883 * anything other than to compare names for equality. To guarantee
2884 * uniqueness in public code, we suggest using a Java style
2885 * inverted domain name followed by organization specific
2886 * naming. E.g.:
2887 * com.yoyodyne.overthruster.0719acd19b66df2a9d8d628a8fefba911a0ab2b7
2888 *
2889 * See test/generator/external_code_generator.cpp for example use. */
2890 inline std::shared_ptr<ExternsMap> get_externs_map() const {
2891 return externs_map;
2892 }
2893
2894 template<typename T>
2895 inline std::unique_ptr<T> create() const {
2896 return T::create(*this);
2897 }
2898
2899 template<typename T, typename... Args>
2900 inline std::unique_ptr<T> apply(const Args &... args) const {
2901 auto t = this->create<T>();
2902 t->apply(args...);
2903 return t;
2904 }
2905
2906 protected:
2907 GeneratorParam<Target> target;
2908 GeneratorParam<bool> auto_schedule;
2909 GeneratorParam<MachineParams> machine_params;
2910 std::shared_ptr<ExternsMap> externs_map;
2911 std::shared_ptr<Internal::ValueTracker> value_tracker;
2912
2913 GeneratorContext()
2914 : GeneratorContext(Target()) {
2915 }
2916
2917 virtual void init_from_context(const Halide::GeneratorContext &context);
2918
2919 inline std::shared_ptr<Internal::ValueTracker> get_value_tracker() const {
2920 return value_tracker;
2921 }
2922
2923 // No copy
2924 GeneratorContext(const GeneratorContext &) = delete;
2925 void operator=(const GeneratorContext &) = delete;
2926 // No move
2927 GeneratorContext(GeneratorContext &&) = delete;
2928 void operator=(GeneratorContext &&) = delete;
2929 };
2930
2931 class NamesInterface {
2932 // Names in this class are only intended for use in derived classes.
2933 protected:
2934 // Import a consistent list of Halide names that can be used in
2935 // Halide generators without qualification.
2936 using Expr = Halide::Expr;
2937 using ExternFuncArgument = Halide::ExternFuncArgument;
2938 using Func = Halide::Func;
2939 using GeneratorContext = Halide::GeneratorContext;
2940 using ImageParam = Halide::ImageParam;
2941 using LoopLevel = Halide::LoopLevel;
2942 using MemoryType = Halide::MemoryType;
2943 using NameMangling = Halide::NameMangling;
2944 using Pipeline = Halide::Pipeline;
2945 using PrefetchBoundStrategy = Halide::PrefetchBoundStrategy;
2946 using RDom = Halide::RDom;
2947 using RVar = Halide::RVar;
2948 using TailStrategy = Halide::TailStrategy;
2949 using Target = Halide::Target;
2950 using Tuple = Halide::Tuple;
2951 using Type = Halide::Type;
2952 using Var = Halide::Var;
2953 template<typename T>
2954 static Expr cast(Expr e) {
2955 return Halide::cast<T>(e);
2956 }
2957 static inline Expr cast(Halide::Type t, Expr e) {
2958 return Halide::cast(t, std::move(e));
2959 }
2960 template<typename T>
2961 using GeneratorParam = Halide::GeneratorParam<T>;
2962 template<typename T = void>
2963 using Buffer = Halide::Buffer<T>;
2964 template<typename T>
2965 using Param = Halide::Param<T>;
2966 static inline Type Bool(int lanes = 1) {
2967 return Halide::Bool(lanes);
2968 }
2969 static inline Type Float(int bits, int lanes = 1) {
2970 return Halide::Float(bits, lanes);
2971 }
2972 static inline Type Int(int bits, int lanes = 1) {
2973 return Halide::Int(bits, lanes);
2974 }
2975 static inline Type UInt(int bits, int lanes = 1) {
2976 return Halide::UInt(bits, lanes);
2977 }
2978 };
2979
2980 namespace Internal {
2981
2982 template<typename... Args>
2983 struct NoRealizations : std::false_type {};
2984
2985 template<>
2986 struct NoRealizations<> : std::true_type {};
2987
2988 template<typename T, typename... Args>
2989 struct NoRealizations<T, Args...> {
2990 static const bool value = !std::is_convertible<T, Realization>::value && NoRealizations<Args...>::value;
2991 };
2992
2993 class GeneratorStub;
2994 class SimpleGeneratorFactory;
2995
2996 // Note that these functions must never return null:
2997 // if they cannot return a valid Generator, they must assert-fail.
2998 using GeneratorFactory = std::function<std::unique_ptr<GeneratorBase>(const GeneratorContext &)>;
2999
3000 struct StringOrLoopLevel {
3001 std::string string_value;
3002 LoopLevel loop_level;
3003
3004 StringOrLoopLevel() = default;
3005 /*not-explicit*/ StringOrLoopLevel(const char *s)
3006 : string_value(s) {
3007 }
3008 /*not-explicit*/ StringOrLoopLevel(const std::string &s)
3009 : string_value(s) {
3010 }
3011 /*not-explicit*/ StringOrLoopLevel(const LoopLevel &loop_level)
3012 : loop_level(loop_level) {
3013 }
3014 };
3015 using GeneratorParamsMap = std::map<std::string, StringOrLoopLevel>;
3016
// Per-Generator registry of its GeneratorParams, Inputs, and Outputs, in
// declaration order. GeneratorBase (a friend) appends to the filter_* lists and
// owns dynamically-created entries via the owned_* vectors; everyone else gets
// read-only access through the getters below.
class GeneratorParamInfo {
    // names used across all params, inputs, and outputs.
    std::set<std::string> names;

    // Ordered-list of non-null ptrs to GeneratorParam<> fields.
    std::vector<Internal::GeneratorParamBase *> filter_generator_params;

    // Ordered-list of non-null ptrs to Input<> fields.
    std::vector<Internal::GeneratorInputBase *> filter_inputs;

    // Ordered-list of non-null ptrs to Output<> fields; empty if old-style Generator.
    std::vector<Internal::GeneratorOutputBase *> filter_outputs;

    // list of synthetic GP's that we dynamically created; this list only exists to simplify
    // lifetime management, and shouldn't be accessed directly outside of our ctor/dtor,
    // regardless of friend access.
    // NOTE(review): members are destroyed in reverse declaration order, so owned_extras
    // (declared below) is destroyed before these synthetic params, which may hold GIOBase
    // references into it -- keep that in mind before reordering members.
    std::vector<std::unique_ptr<Internal::GeneratorParamBase>> owned_synthetic_params;

    // list of dynamically-added inputs and outputs, here only for lifetime management.
    std::vector<std::unique_ptr<Internal::GIOBase>> owned_extras;

public:
    friend class GeneratorBase;

    GeneratorParamInfo(GeneratorBase *generator, const size_t size);

    // Read-only, declaration-ordered views of the registered params/inputs/outputs.
    const std::vector<Internal::GeneratorParamBase *> &generator_params() const {
        return filter_generator_params;
    }
    const std::vector<Internal::GeneratorInputBase *> &inputs() const {
        return filter_inputs;
    }
    const std::vector<Internal::GeneratorOutputBase *> &outputs() const {
        return filter_outputs;
    }
};
3053
3054 class GeneratorBase : public NamesInterface, public GeneratorContext {
3055 public:
3056 ~GeneratorBase() override;
3057
3058 void set_generator_param_values(const GeneratorParamsMap ¶ms);
3059
3060 /** Given a data type, return an estimate of the "natural" vector size
3061 * for that data type when compiling for the current target. */
3062 int natural_vector_size(Halide::Type t) const {
3063 return get_target().natural_vector_size(t);
3064 }
3065
3066 /** Given a data type, return an estimate of the "natural" vector size
3067 * for that data type when compiling for the current target. */
3068 template<typename data_t>
3069 int natural_vector_size() const {
3070 return get_target().natural_vector_size<data_t>();
3071 }
3072
3073 void emit_cpp_stub(const std::string &stub_file_path);
3074
3075 // Call build() and produce a Module for the result.
3076 // If function_name is empty, generator_name() will be used for the function.
3077 Module build_module(const std::string &function_name = "",
3078 const LinkageType linkage_type = LinkageType::ExternalPlusMetadata);
3079
3080 /**
3081 * Build a module that is suitable for using for gradient descent calculation in TensorFlow or PyTorch.
3082 *
3083 * Essentially:
3084 * - A new Pipeline is synthesized from the current Generator (according to the rules below)
3085 * - The new Pipeline is autoscheduled (if autoscheduling is requested, but it would be odd not to do so)
3086 * - The Pipeline is compiled to a Module and returned
3087 *
3088 * The new Pipeline is adjoint to the original; it has:
3089 * - All the same inputs as the original, in the same order
3090 * - Followed by one grad-input for each original output
3091 * - Followed by one output for each unique pairing of original-output + original-input.
3092 * (For the common case of just one original-output, this amounts to being one output for each original-input.)
3093 */
3094 Module build_gradient_module(const std::string &function_name);
3095
3096 /**
3097 * set_inputs is a variadic wrapper around set_inputs_vector, which makes usage much simpler
3098 * in many cases, as it constructs the relevant entries for the vector for you, which
3099 * is often a bit unintuitive at present. The arguments are passed in Input<>-declaration-order,
3100 * and the types must be compatible. Array inputs are passed as std::vector<> of the relevant type.
3101 *
3102 * Note: at present, scalar input types must match *exactly*, i.e., for Input<uint8_t>, you
3103 * must pass an argument that is actually uint8_t; an argument that is int-that-will-fit-in-uint8
3104 * will assert-fail at Halide compile time.
3105 */
3106 template<typename... Args>
3107 void set_inputs(const Args &... args) {
3108 // set_inputs_vector() checks this too, but checking it here allows build_inputs() to avoid out-of-range checks.
3109 GeneratorParamInfo &pi = this->param_info();
3110 user_assert(sizeof...(args) == pi.inputs().size())
3111 << "Expected exactly " << pi.inputs().size()
3112 << " inputs but got " << sizeof...(args) << "\n";
3113 set_inputs_vector(build_inputs(std::forward_as_tuple<const Args &...>(args...), make_index_sequence<sizeof...(Args)>{}));
3114 }
3115
3116 Realization realize(std::vector<int32_t> sizes) {
3117 this->check_scheduled("realize");
3118 return get_pipeline().realize(std::move(sizes), get_target());
3119 }
3120
3121 // Only enable if none of the args are Realization; otherwise we can incorrectly
3122 // select this method instead of the Realization-as-outparam variant
3123 template<typename... Args, typename std::enable_if<NoRealizations<Args...>::value>::type * = nullptr>
3124 Realization realize(Args &&... args) {
3125 this->check_scheduled("realize");
3126 return get_pipeline().realize(std::forward<Args>(args)..., get_target());
3127 }
3128
3129 void realize(Realization r) {
3130 this->check_scheduled("realize");
3131 get_pipeline().realize(r, get_target());
3132 }
3133
3134 // Return the Pipeline that has been built by the generate() method.
3135 // This method can only be used from a Generator that has a generate()
3136 // method (vs a build() method), and currently can only be called from
3137 // the schedule() method. (This may be relaxed in the future to allow
3138 // calling from generate() as long as all Outputs have been defined.)
3139 Pipeline get_pipeline();
3140
3141 // Create Input<Buffer> or Input<Func> with dynamic type
3142 template<typename T,
3143 typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
3144 GeneratorInput<T> *add_input(const std::string &name, const Type &t, int dimensions) {
3145 check_exact_phase(GeneratorBase::ConfigureCalled);
3146 auto *p = new GeneratorInput<T>(name, t, dimensions);
3147 p->generator = this;
3148 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3149 param_info_ptr->filter_inputs.push_back(p);
3150 return p;
3151 }
3152
3153 // Create a Input<Buffer> or Input<Func> with compile-time type
3154 template<typename T,
3155 typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
3156 GeneratorInput<T> *add_input(const std::string &name, int dimensions) {
3157 check_exact_phase(GeneratorBase::ConfigureCalled);
3158 auto *p = new GeneratorInput<T>(name, dimensions);
3159 p->generator = this;
3160 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3161 param_info_ptr->filter_inputs.push_back(p);
3162 return p;
3163 }
3164
3165 // Create Input<scalar>
3166 template<typename T,
3167 typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3168 GeneratorInput<T> *add_input(const std::string &name) {
3169 check_exact_phase(GeneratorBase::ConfigureCalled);
3170 auto *p = new GeneratorInput<T>(name);
3171 p->generator = this;
3172 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3173 param_info_ptr->filter_inputs.push_back(p);
3174 return p;
3175 }
3176
3177 // Create Output<Buffer> or Output<Func> with dynamic type
3178 template<typename T,
3179 typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
3180 GeneratorOutput<T> *add_output(const std::string &name, const Type &t, int dimensions) {
3181 check_exact_phase(GeneratorBase::ConfigureCalled);
3182 auto *p = new GeneratorOutput<T>(name, t, dimensions);
3183 p->generator = this;
3184 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3185 param_info_ptr->filter_outputs.push_back(p);
3186 return p;
3187 }
3188
3189 // Create a Output<Buffer> or Output<Func> with compile-time type
3190 template<typename T,
3191 typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
3192 GeneratorOutput<T> *add_output(const std::string &name, int dimensions) {
3193 check_exact_phase(GeneratorBase::ConfigureCalled);
3194 auto *p = new GeneratorOutput<T>(name, dimensions);
3195 p->generator = this;
3196 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3197 param_info_ptr->filter_outputs.push_back(p);
3198 return p;
3199 }
3200
3201 template<typename... Args>
3202 HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&... args) {
3203 get_pipeline().add_requirement(condition, std::forward<Args>(args)...);
3204 }
3205
3206 void trace_pipeline() {
3207 get_pipeline().trace_pipeline();
3208 }
3209
3210 protected:
3211 GeneratorBase(size_t size, const void *introspection_helper);
3212 void set_generator_names(const std::string ®istered_name, const std::string &stub_name);
3213
3214 void init_from_context(const Halide::GeneratorContext &context) override;
3215
3216 virtual Pipeline build_pipeline() = 0;
3217 virtual void call_configure() = 0;
3218 virtual void call_generate() = 0;
3219 virtual void call_schedule() = 0;
3220
3221 void track_parameter_values(bool include_outputs);
3222
3223 void pre_build();
3224 void post_build();
3225 void pre_configure();
3226 void post_configure();
3227 void pre_generate();
3228 void post_generate();
3229 void pre_schedule();
3230 void post_schedule();
3231
3232 template<typename T>
3233 using Input = GeneratorInput<T>;
3234
3235 template<typename T>
3236 using Output = GeneratorOutput<T>;
3237
3238 // A Generator's creation and usage must go in a certain phase to ensure correctness;
3239 // the state machine here is advanced and checked at various points to ensure
3240 // this is the case.
3241 enum Phase {
3242 // Generator has just come into being.
3243 Created,
3244
3245 // Generator has had its configure() method called. (For Generators without
3246 // a configure() method, this phase will be skipped and will advance
3247 // directly to InputsSet.)
3248 ConfigureCalled,
3249
3250 // All Input<>/Param<> fields have been set. (Applicable only in JIT mode;
3251 // in AOT mode, this can be skipped, going Created->GenerateCalled directly.)
3252 InputsSet,
3253
3254 // Generator has had its generate() method called. (For Generators with
3255 // a build() method instead of generate(), this phase will be skipped
3256 // and will advance directly to ScheduleCalled.)
3257 GenerateCalled,
3258
3259 // Generator has had its schedule() method (if any) called.
3260 ScheduleCalled,
3261 } phase{Created};
3262
3263 void check_exact_phase(Phase expected_phase) const;
3264 void check_min_phase(Phase expected_phase) const;
3265 void advance_phase(Phase new_phase);
3266
3267 private:
3268 friend void ::Halide::Internal::generator_test();
3269 friend class GeneratorParamBase;
3270 friend class GIOBase;
3271 friend class GeneratorInputBase;
3272 friend class GeneratorOutputBase;
3273 friend class GeneratorParamInfo;
3274 friend class GeneratorStub;
3275 friend class SimpleGeneratorFactory;
3276 friend class StubOutputBufferBase;
3277
3278 const size_t size;
3279
3280 // Lazily-allocated-and-inited struct with info about our various Params.
3281 // Do not access directly: use the param_info() getter.
3282 std::unique_ptr<GeneratorParamInfo> param_info_ptr;
3283
3284 mutable std::shared_ptr<ExternsMap> externs_map;
3285
3286 bool inputs_set{false};
3287 std::string generator_registered_name, generator_stub_name;
3288 Pipeline pipeline;
3289
3290 // Return our GeneratorParamInfo.
3291 GeneratorParamInfo ¶m_info();
3292
3293 Internal::GeneratorOutputBase *find_output_by_name(const std::string &name);
3294
3295 void check_scheduled(const char *m) const;
3296
3297 void build_params(bool force = false);
3298
3299 // Provide private, unimplemented, wrong-result-type methods here
3300 // so that Generators don't attempt to call the global methods
3301 // of the same name by accident: use the get_target() method instead.
3302 void get_host_target();
3303 void get_jit_target_from_environment();
3304 void get_target_from_environment();
3305
3306 // Return the Output<Func> or Output<Buffer> with the given name,
3307 // which must be a singular (non-array) Func or Buffer output.
3308 // If no such name exists (or is non-array), assert; this method never returns an undefined Func.
3309 Func get_output(const std::string &n);
3310
3311 // Return the Output<Func[]> with the given name, which must be an
3312 // array-of-Func output. If no such name exists (or is non-array), assert;
3313 // this method never returns undefined Funcs.
3314 std::vector<Func> get_array_output(const std::string &n);
3315
3316 void set_inputs_vector(const std::vector<std::vector<StubInput>> &inputs);
3317
3318 static void check_input_is_singular(Internal::GeneratorInputBase *in);
3319 static void check_input_is_array(Internal::GeneratorInputBase *in);
3320 static void check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind);
3321
3322 // Allow Buffer<> if:
3323 // -- we are assigning it to an Input<Buffer<>> (with compatible type and dimensions),
3324 // causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
3325 // -- we are assigningit to an Input<Func>, in which case we just Func-wrap the Buffer<>.
3326 template<typename T>
3327 std::vector<StubInput> build_input(size_t i, const Buffer<T> &arg) {
3328 auto *in = param_info().inputs().at(i);
3329 check_input_is_singular(in);
3330 const auto k = in->kind();
3331 if (k == Internal::IOKind::Buffer) {
3332 Halide::Buffer<> b = arg;
3333 StubInputBuffer<> sib(b);
3334 StubInput si(sib);
3335 return {si};
3336 } else if (k == Internal::IOKind::Function) {
3337 Halide::Func f(arg.name() + "_im");
3338 f(Halide::_) = arg(Halide::_);
3339 StubInput si(f);
3340 return {si};
3341 } else {
3342 check_input_kind(in, Internal::IOKind::Buffer); // just to trigger assertion
3343 return {};
3344 }
3345 }
3346
3347 // Allow Input<Buffer<>> if:
3348 // -- we are assigning it to another Input<Buffer<>> (with compatible type and dimensions),
3349 // allowing us to simply pipe a parameter from an enclosing Generator to the Invoker.
3350 // -- we are assigningit to an Input<Func>, in which case we just Func-wrap the Input<Buffer<>>.
3351 template<typename T>
3352 std::vector<StubInput> build_input(size_t i, const GeneratorInput<Buffer<T>> &arg) {
3353 auto *in = param_info().inputs().at(i);
3354 check_input_is_singular(in);
3355 const auto k = in->kind();
3356 if (k == Internal::IOKind::Buffer) {
3357 StubInputBuffer<> sib = arg;
3358 StubInput si(sib);
3359 return {si};
3360 } else if (k == Internal::IOKind::Function) {
3361 Halide::Func f = arg.funcs().at(0);
3362 StubInput si(f);
3363 return {si};
3364 } else {
3365 check_input_kind(in, Internal::IOKind::Buffer); // just to trigger assertion
3366 return {};
3367 }
3368 }
3369
3370 // Allow Func iff we are assigning it to an Input<Func> (with compatible type and dimensions).
3371 std::vector<StubInput> build_input(size_t i, const Func &arg) {
3372 auto *in = param_info().inputs().at(i);
3373 check_input_kind(in, Internal::IOKind::Function);
3374 check_input_is_singular(in);
3375 const Halide::Func &f = arg;
3376 StubInput si(f);
3377 return {si};
3378 }
3379
3380 // Allow vector<Func> iff we are assigning it to an Input<Func[]> (with compatible type and dimensions).
3381 std::vector<StubInput> build_input(size_t i, const std::vector<Func> &arg) {
3382 auto *in = param_info().inputs().at(i);
3383 check_input_kind(in, Internal::IOKind::Function);
3384 check_input_is_array(in);
3385 // My kingdom for a list comprehension...
3386 std::vector<StubInput> siv;
3387 siv.reserve(arg.size());
3388 for (const auto &f : arg) {
3389 siv.emplace_back(f);
3390 }
3391 return siv;
3392 }
3393
3394 // Expr must be Input<Scalar>.
3395 std::vector<StubInput> build_input(size_t i, const Expr &arg) {
3396 auto *in = param_info().inputs().at(i);
3397 check_input_kind(in, Internal::IOKind::Scalar);
3398 check_input_is_singular(in);
3399 StubInput si(arg);
3400 return {si};
3401 }
3402
3403 // (Array form)
3404 std::vector<StubInput> build_input(size_t i, const std::vector<Expr> &arg) {
3405 auto *in = param_info().inputs().at(i);
3406 check_input_kind(in, Internal::IOKind::Scalar);
3407 check_input_is_array(in);
3408 std::vector<StubInput> siv;
3409 siv.reserve(arg.size());
3410 for (const auto &value : arg) {
3411 siv.emplace_back(value);
3412 }
3413 return siv;
3414 }
3415
3416 // Any other type must be convertible to Expr and must be associated with an Input<Scalar>.
3417 // Use is_arithmetic since some Expr conversions are explicit.
3418 template<typename T,
3419 typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3420 std::vector<StubInput> build_input(size_t i, const T &arg) {
3421 auto *in = param_info().inputs().at(i);
3422 check_input_kind(in, Internal::IOKind::Scalar);
3423 check_input_is_singular(in);
3424 // We must use an explicit Expr() ctor to preserve the type
3425 Expr e(arg);
3426 StubInput si(e);
3427 return {si};
3428 }
3429
3430 // (Array form)
3431 template<typename T,
3432 typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3433 std::vector<StubInput> build_input(size_t i, const std::vector<T> &arg) {
3434 auto *in = param_info().inputs().at(i);
3435 check_input_kind(in, Internal::IOKind::Scalar);
3436 check_input_is_array(in);
3437 std::vector<StubInput> siv;
3438 siv.reserve(arg.size());
3439 for (const auto &value : arg) {
3440 // We must use an explicit Expr() ctor to preserve the type;
3441 // otherwise, implicit conversions can downgrade (e.g.) float -> int
3442 Expr e(value);
3443 siv.emplace_back(e);
3444 }
3445 return siv;
3446 }
3447
3448 template<typename... Args, size_t... Indices>
3449 std::vector<std::vector<StubInput>> build_inputs(const std::tuple<const Args &...> &t, index_sequence<Indices...>) {
3450 return {build_input(Indices, std::get<Indices>(t))...};
3451 }
3452
3453 // No copy
3454 GeneratorBase(const GeneratorBase &) = delete;
3455 void operator=(const GeneratorBase &) = delete;
3456 // No move
3457 GeneratorBase(GeneratorBase &&that) = delete;
3458 void operator=(GeneratorBase &&that) = delete;
3459 };
3460
3461 class GeneratorRegistry {
3462 public:
3463 static void register_factory(const std::string &name, GeneratorFactory generator_factory);
3464 static void unregister_factory(const std::string &name);
3465 static std::vector<std::string> enumerate();
3466 // Note that this method will never return null:
3467 // if it cannot return a valid Generator, it should assert-fail.
3468 static std::unique_ptr<GeneratorBase> create(const std::string &name,
3469 const Halide::GeneratorContext &context);
3470
3471 private:
3472 using GeneratorFactoryMap = std::map<const std::string, GeneratorFactory>;
3473
3474 GeneratorFactoryMap factories;
3475 std::mutex mutex;
3476
3477 static GeneratorRegistry &get_registry();
3478
3479 GeneratorRegistry() = default;
3480 GeneratorRegistry(const GeneratorRegistry &) = delete;
3481 void operator=(const GeneratorRegistry &) = delete;
3482 };
3483
3484 } // namespace Internal
3485
3486 template<class T>
3487 class Generator : public Internal::GeneratorBase {
3488 protected:
3489 Generator()
3490 : Internal::GeneratorBase(sizeof(T),
3491 Internal::Introspection::get_introspection_helper<T>()) {
3492 }
3493
3494 public:
3495 static std::unique_ptr<T> create(const Halide::GeneratorContext &context) {
3496 // We must have an object of type T (not merely GeneratorBase) to call a protected method,
3497 // because CRTP is a weird beast.
3498 auto g = std::unique_ptr<T>(new T());
3499 g->init_from_context(context);
3500 return g;
3501 }
3502
3503 // This is public but intended only for use by the HALIDE_REGISTER_GENERATOR() macro.
3504 static std::unique_ptr<T> create(const Halide::GeneratorContext &context,
3505 const std::string ®istered_name,
3506 const std::string &stub_name) {
3507 auto g = create(context);
3508 g->set_generator_names(registered_name, stub_name);
3509 return g;
3510 }
3511
3512 using Internal::GeneratorBase::apply;
3513 using Internal::GeneratorBase::create;
3514
3515 template<typename... Args>
3516 void apply(const Args &... args) {
3517 #ifndef _MSC_VER
3518 // VS2015 apparently has some SFINAE issues, so this can inappropriately
3519 // trigger there. (We'll still fail when generate() is called, just
3520 // with a less-helpful error message.)
3521 static_assert(has_generate_method<T>::value, "apply() is not supported for old-style Generators.");
3522 #endif
3523 call_configure();
3524 set_inputs(args...);
3525 call_generate();
3526 call_schedule();
3527 }
3528
3529 private:
3530 // std::is_member_function_pointer will fail if there is no member of that name,
3531 // so we use a little SFINAE to detect if there are method-shaped members.
3532 template<typename>
3533 struct type_sink { typedef void type; };
3534
3535 template<typename T2, typename = void>
3536 struct has_configure_method : std::false_type {};
3537
3538 template<typename T2>
3539 struct has_configure_method<T2, typename type_sink<decltype(std::declval<T2>().configure())>::type> : std::true_type {};
3540
3541 template<typename T2, typename = void>
3542 struct has_generate_method : std::false_type {};
3543
3544 template<typename T2>
3545 struct has_generate_method<T2, typename type_sink<decltype(std::declval<T2>().generate())>::type> : std::true_type {};
3546
3547 template<typename T2, typename = void>
3548 struct has_schedule_method : std::false_type {};
3549
3550 template<typename T2>
3551 struct has_schedule_method<T2, typename type_sink<decltype(std::declval<T2>().schedule())>::type> : std::true_type {};
3552
3553 template<typename T2 = T,
3554 typename std::enable_if<!has_generate_method<T2>::value>::type * = nullptr>
3555
3556 // Implementations for build_pipeline_impl(), specialized on whether we
3557 // have build() or generate()/schedule() methods.
3558
3559 // MSVC apparently has some weirdness with the usual sfinae tricks
3560 // for detecting method-shaped things, so we can't actually use
3561 // the helpers above outside of static_assert. Instead we make as
3562 // many overloads as we can exist, and then use C++'s preference
3563 // for treating a 0 as an int rather than a double to choose one
3564 // of them.
3565 Pipeline build_pipeline_impl(double) {
3566 static_assert(!has_configure_method<T2>::value, "The configure() method is ignored if you define a build() method; use generate() instead.");
3567 static_assert(!has_schedule_method<T2>::value, "The schedule() method is ignored if you define a build() method; use generate() instead.");
3568 pre_build();
3569 Pipeline p = ((T *)this)->build();
3570 post_build();
3571 return p;
3572 }
3573
3574 template<typename T2 = T,
3575 typename = decltype(std::declval<T2>().generate())>
3576 Pipeline build_pipeline_impl(int) {
3577 // No: configure() must be called prior to this
3578 // (and in fact, prior to calling set_inputs).
3579 //
3580 // ((T *)this)->call_configure_impl(0, 0);
3581
3582 ((T *)this)->call_generate_impl(0);
3583 ((T *)this)->call_schedule_impl(0, 0);
3584 return get_pipeline();
3585 }
3586
3587 // Implementations for call_configure_impl(), specialized on whether we
3588 // have build() or configure()/generate()/schedule() methods.
3589
3590 void call_configure_impl(double, double) {
3591 // Called as a side effect for build()-method Generators; quietly do nothing.
3592 }
3593
3594 template<typename T2 = T,
3595 typename = decltype(std::declval<T2>().generate())>
3596 void call_configure_impl(double, int) {
3597 // Generator has a generate() method but no configure() method. This is ok. Just advance the phase.
3598 pre_configure();
3599 static_assert(!has_configure_method<T2>::value, "Did not expect a configure method here.");
3600 post_configure();
3601 }
3602
3603 template<typename T2 = T,
3604 typename = decltype(std::declval<T2>().generate()),
3605 typename = decltype(std::declval<T2>().configure())>
3606 void call_configure_impl(int, int) {
3607 T *t = (T *)this;
3608 static_assert(std::is_void<decltype(t->configure())>::value, "configure() must return void");
3609 pre_configure();
3610 t->configure();
3611 post_configure();
3612 }
3613
3614 // Implementations for call_generate_impl(), specialized on whether we
3615 // have build() or configure()/generate()/schedule() methods.
3616
3617 void call_generate_impl(double) {
3618 user_error << "Unimplemented";
3619 }
3620
3621 template<typename T2 = T,
3622 typename = decltype(std::declval<T2>().generate())>
3623 void call_generate_impl(int) {
3624 T *t = (T *)this;
3625 static_assert(std::is_void<decltype(t->generate())>::value, "generate() must return void");
3626 pre_generate();
3627 t->generate();
3628 post_generate();
3629 }
3630
3631 // Implementations for call_schedule_impl(), specialized on whether we
3632 // have build() or configure()generate()/schedule() methods.
3633
3634 void call_schedule_impl(double, double) {
3635 user_error << "Unimplemented";
3636 }
3637
3638 template<typename T2 = T,
3639 typename = decltype(std::declval<T2>().generate())>
3640 void call_schedule_impl(double, int) {
3641 // Generator has a generate() method but no schedule() method. This is ok. Just advance the phase.
3642 pre_schedule();
3643 post_schedule();
3644 }
3645
3646 template<typename T2 = T,
3647 typename = decltype(std::declval<T2>().generate()),
3648 typename = decltype(std::declval<T2>().schedule())>
3649 void call_schedule_impl(int, int) {
3650 T *t = (T *)this;
3651 static_assert(std::is_void<decltype(t->schedule())>::value, "schedule() must return void");
3652 pre_schedule();
3653 t->schedule();
3654 post_schedule();
3655 }
3656
3657 protected:
3658 Pipeline build_pipeline() override {
3659 return this->build_pipeline_impl(0);
3660 }
3661
3662 void call_configure() override {
3663 this->call_configure_impl(0, 0);
3664 }
3665
3666 void call_generate() override {
3667 this->call_generate_impl(0);
3668 }
3669
3670 void call_schedule() override {
3671 this->call_schedule_impl(0, 0);
3672 }
3673
3674 private:
3675 friend void ::Halide::Internal::generator_test();
3676 friend class Internal::SimpleGeneratorFactory;
3677 friend void ::Halide::Internal::generator_test();
3678 friend class ::Halide::GeneratorContext;
3679
3680 // No copy
3681 Generator(const Generator &) = delete;
3682 void operator=(const Generator &) = delete;
3683 // No move
3684 Generator(Generator &&that) = delete;
3685 void operator=(Generator &&that) = delete;
3686 };
3687
3688 namespace Internal {
3689
3690 class RegisterGenerator {
3691 public:
3692 RegisterGenerator(const char *registered_name, GeneratorFactory generator_factory) {
3693 Internal::GeneratorRegistry::register_factory(registered_name, std::move(generator_factory));
3694 }
3695 };
3696
3697 class GeneratorStub : public NamesInterface {
3698 public:
3699 GeneratorStub(const GeneratorContext &context,
3700 const GeneratorFactory &generator_factory);
3701
3702 GeneratorStub(const GeneratorContext &context,
3703 const GeneratorFactory &generator_factory,
3704 const GeneratorParamsMap &generator_params,
3705 const std::vector<std::vector<Internal::StubInput>> &inputs);
3706 std::vector<std::vector<Func>> generate(const GeneratorParamsMap &generator_params,
3707 const std::vector<std::vector<Internal::StubInput>> &inputs);
3708
3709 // Output(s)
3710 // TODO: identify vars used
3711 Func get_output(const std::string &n) const {
3712 return generator->get_output(n);
3713 }
3714
3715 template<typename T2>
3716 T2 get_output_buffer(const std::string &n) const {
3717 return T2(get_output(n), generator);
3718 }
3719
3720 template<typename T2>
3721 std::vector<T2> get_array_output_buffer(const std::string &n) const {
3722 auto v = generator->get_array_output(n);
3723 std::vector<T2> result;
3724 for (auto &o : v) {
3725 result.push_back(T2(o, generator));
3726 }
3727 return result;
3728 }
3729
3730 std::vector<Func> get_array_output(const std::string &n) const {
3731 return generator->get_array_output(n);
3732 }
3733
3734 static std::vector<StubInput> to_stub_input_vector(const Expr &e) {
3735 return {StubInput(e)};
3736 }
3737
3738 static std::vector<StubInput> to_stub_input_vector(const Func &f) {
3739 return {StubInput(f)};
3740 }
3741
3742 template<typename T = void>
3743 static std::vector<StubInput> to_stub_input_vector(const StubInputBuffer<T> &b) {
3744 return {StubInput(b)};
3745 }
3746
3747 template<typename T>
3748 static std::vector<StubInput> to_stub_input_vector(const std::vector<T> &v) {
3749 std::vector<StubInput> r;
3750 std::copy(v.begin(), v.end(), std::back_inserter(r));
3751 return r;
3752 }
3753
3754 struct Names {
3755 std::vector<std::string> generator_params, inputs, outputs;
3756 };
3757 Names get_names() const;
3758
3759 std::shared_ptr<GeneratorBase> generator;
3760 };
3761
3762 } // namespace Internal
3763
3764 } // namespace Halide
3765
// Define this namespace at global scope so that anonymous namespaces won't
// defeat our static_assert check; define a dummy type inside so we can
// check for type aliasing injected by anonymous namespace usage.
// The registration macros compare this global halide_global_ns with whatever
// `halide_register_generator::halide_global_ns` resolves to at the expansion
// site. (No trailing ';' after the closing brace: an empty declaration at
// namespace scope triggers -Wextra-semi in user code including this header.)
namespace halide_register_generator {
struct halide_global_ns;
}  // namespace halide_register_generator
3772
// Implementation of HALIDE_REGISTER_GENERATOR(). The expansion contains:
//   - namespace halide_register_generator::GEN_REGISTRY_NAME##_ns holding a
//     factory() function that returns
//     GEN_CLASS_NAME::create(context, "GEN_REGISTRY_NAME", "FULLY_QUALIFIED_STUB_NAME");
//   - a static RegisterGenerator named reg_##GEN_REGISTRY_NAME whose
//     constructor registers that factory under the name "GEN_REGISTRY_NAME";
//   - a static_assert that fails when the macro is expanded inside another
//     namespace. In that case the expansion creates a nested
//     halide_register_generator with its own halide_global_ns, which is not
//     the same type as ::halide_register_generator::halide_global_ns.
// factory() is declared before it is defined (presumably to satisfy
// -Wmissing-declarations) and is not static. HALIDE_REGISTER_GENERATOR_ALIAS()
// relies on this: it redeclares and calls ORIGINAL_REGISTRY_NAME##_ns::factory.
#define _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
namespace halide_register_generator { \
struct halide_global_ns; \
namespace GEN_REGISTRY_NAME##_ns { \
std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) { \
return GEN_CLASS_NAME::create(context, #GEN_REGISTRY_NAME, #FULLY_QUALIFIED_STUB_NAME); \
} \
} \
static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
} \
static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
"HALIDE_REGISTER_GENERATOR must be used at global scope");
3786
// Two-argument form of HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME):
// the stub name passed to create() is the registry name itself.
#define _HALIDE_REGISTER_GENERATOR2(GEN_CLASS_NAME, GEN_REGISTRY_NAME) \
_HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, GEN_REGISTRY_NAME)

// Three-argument form: the stub name passed to create() is given explicitly as
// the third argument (stringized by the IMPL macro).
#define _HALIDE_REGISTER_GENERATOR3(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
_HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)
3792
// MSVC has a broken implementation of variadic macros: it expands __VA_ARGS__
// as a single token in argument lists (rather than multiple tokens).
// Jump through some hoops to work around this.
//
// Argument counting: _HALIDE_REGISTER_ARGCOUNT(a, b) becomes
// __HALIDE_REGISTER_ARGCOUNT_IMPL(a, b, 3, 2, 1, 0), whose fourth argument
// (COUNT) is 2; with three arguments it is 3. Passing the argument list as a
// single parenthesized ARGS and then writing `__HALIDE_REGISTER_ARGCOUNT_IMPL ARGS`
// forces a rescan, so the list is split into separate arguments.
#define __HALIDE_REGISTER_ARGCOUNT_IMPL(_1, _2, _3, COUNT, ...) \
COUNT

#define _HALIDE_REGISTER_ARGCOUNT_IMPL(ARGS) \
__HALIDE_REGISTER_ARGCOUNT_IMPL ARGS

#define _HALIDE_REGISTER_ARGCOUNT(...) \
_HALIDE_REGISTER_ARGCOUNT_IMPL((__VA_ARGS__, 3, 2, 1, 0))

// Chooser: pastes COUNT onto _HALIDE_REGISTER_GENERATOR to select the 2- or
// 3-argument form. `##` suppresses expansion of its operands, so the extra
// levels of indirection ensure COUNT is fully expanded to a digit before pasting.
// (Only the 2- and 3-argument forms are defined above.)
#define ___HALIDE_REGISTER_CHOOSER(COUNT) \
_HALIDE_REGISTER_GENERATOR##COUNT

#define __HALIDE_REGISTER_CHOOSER(COUNT) \
___HALIDE_REGISTER_CHOOSER(COUNT)

#define _HALIDE_REGISTER_CHOOSER(COUNT) \
__HALIDE_REGISTER_CHOOSER(COUNT)

// Juxtaposes the chosen macro name and its parenthesized arguments so that the
// result is rescanned and invoked as a function-like macro.
#define _HALIDE_REGISTER_GENERATOR_PASTE(A, B) \
A B

// Public entry point:
//   HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME)
//   HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)
// Must be used at global scope (enforced by a static_assert in the expansion).
#define HALIDE_REGISTER_GENERATOR(...) \
_HALIDE_REGISTER_GENERATOR_PASTE(_HALIDE_REGISTER_CHOOSER(_HALIDE_REGISTER_ARGCOUNT(__VA_ARGS__)), (__VA_ARGS__))
3819
// HALIDE_REGISTER_GENERATOR_ALIAS() can be used to create an alias-with-a-particular-set-of-param-values
// for a given Generator in the build system. Normally, you wouldn't want to do this;
// however, some existing Halide clients have build systems that make it challenging to
// specify GeneratorParams inside the build system, and this allows a somewhat simpler
// customization route for them. It's highly recommended you don't use this for new code.
//
// The final argument is really an initializer-list of GeneratorParams, in the form
// of an initializer-list for map<string, string>:
//
// { { "gp-name", "gp-value"} [, { "gp2-name", "gp2-value" }] }
//
// It is specified as a variadic macro argument to allow for the fact that the embedded commas
// would otherwise confuse the preprocessor; since (in this case) all we're going to do is
// pass it through as-is, this is fine (and even MSVC's 'broken' __VA_ARGS__ should be OK here).
//
// Mechanism: the expansion redeclares the original generator's
// halide_register_generator::ORIGINAL_REGISTRY_NAME##_ns::factory (defined by
// HALIDE_REGISTER_GENERATOR). It defines a new factory that calls the original one,
// then applies the given values via set_generator_param_values(), and registers
// the new factory under "GEN_REGISTRY_NAME". Like HALIDE_REGISTER_GENERATOR, it
// must be used at global scope.
#define HALIDE_REGISTER_GENERATOR_ALIAS(GEN_REGISTRY_NAME, ORIGINAL_REGISTRY_NAME, ...) \
namespace halide_register_generator { \
struct halide_global_ns; \
namespace ORIGINAL_REGISTRY_NAME##_ns { \
std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
} \
namespace GEN_REGISTRY_NAME##_ns { \
std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) { \
auto g = ORIGINAL_REGISTRY_NAME##_ns::factory(context); \
g->set_generator_param_values(__VA_ARGS__); \
return g; \
} \
} \
static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
} \
static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
"HALIDE_REGISTER_GENERATOR_ALIAS must be used at global scope");
3852
3853 #endif // HALIDE_GENERATOR_H_
3854