1 #ifndef HALIDE_GENERATOR_H_
2 #define HALIDE_GENERATOR_H_
3 
4 /** \file
5  *
6  * Generator is a class used to encapsulate the building of Funcs in user
7  * pipelines. A Generator is agnostic to JIT vs AOT compilation; it can be used for
8  * either purpose, but is especially convenient to use for AOT compilation.
9  *
 * A Generator explicitly declares the Inputs and Outputs associated with a given
 * pipeline, and (optionally) separates the code for constructing the outputs from the code for
 * scheduling them. For instance:
13  *
14  * \code
15  *     class Blur : public Generator<Blur> {
16  *     public:
17  *         Input<Func> input{"input", UInt(16), 2};
18  *         Output<Func> output{"output", UInt(16), 2};
19  *         void generate() {
20  *             blur_x(x, y) = (input(x, y) + input(x+1, y) + input(x+2, y))/3;
21  *             blur_y(x, y) = (blur_x(x, y) + blur_x(x, y+1) + blur_x(x, y+2))/3;
 *             output(x, y) = blur_y(x, y);
23  *         }
24  *         void schedule() {
25  *             blur_y.split(y, y, yi, 8).parallel(y).vectorize(x, 8);
26  *             blur_x.store_at(blur_y, y).compute_at(blur_y, yi).vectorize(x, 8);
27  *         }
28  *     private:
29  *         Var x, y, xi, yi;
30  *         Func blur_x, blur_y;
31  *     };
32  * \endcode
33  *
34  * Halide can compile a Generator into the correct pipeline by introspecting these
35  * values and constructing an appropriate signature based on them.
36  *
37  * A Generator provides implementations of two methods:
38  *
39  *   - generate(), which must fill in all Output Func(s); it may optionally also do scheduling
40  *   if no schedule() method is present.
41  *   - schedule(), which (if present) should contain all scheduling code.
42  *
43  * Inputs can be any C++ scalar type:
44  *
45  * \code
46  *     Input<float> radius{"radius"};
47  *     Input<int32_t> increment{"increment"};
48  * \endcode
49  *
 * An Input<Func> is (essentially) like an ImageParam, except that it may (or may
 * not) be backed by an actual buffer, and thus has no defined extents.
52  *
53  * \code
54  *     Input<Func> input{"input", Float(32), 2};
55  * \endcode
56  *
57  * You can optionally make the type and/or dimensions of Input<Func> unspecified,
58  * in which case the value is simply inferred from the actual Funcs passed to them.
59  * Of course, if you specify an explicit Type or Dimension, we still require the
60  * input Func to match, or a compilation error results.
61  *
62  * \code
63  *     Input<Func> input{ "input", 3 };  // require 3-dimensional Func,
64  *                                       // but leave Type unspecified
65  * \endcode
66  *
67  * A Generator must explicitly list the output(s) it produces:
68  *
69  * \code
70  *     Output<Func> output{"output", Float(32), 2};
71  * \endcode
72  *
73  * You can specify an output that returns a Tuple by specifying a list of Types:
74  *
75  * \code
76  *     class Tupler : Generator<Tupler> {
77  *       Input<Func> input{"input", Int(32), 2};
78  *       Output<Func> output{"output", {Float(32), UInt(8)}, 2};
79  *       void generate() {
80  *         Var x, y;
81  *         Expr a = cast<float>(input(x, y));
82  *         Expr b = cast<uint8_t>(input(x, y));
83  *         output(x, y) = Tuple(a, b);
84  *       }
85  *     };
86  * \endcode
87  *
88  * You can also specify Output<X> for any scalar type (except for Handle types);
89  * this is merely syntactic sugar on top of a zero-dimensional Func, but can be
90  * quite handy, especially when used with multiple outputs:
91  *
92  * \code
93  *     Output<float> sum{"sum"};  // equivalent to Output<Func> {"sum", Float(32), 0}
94  * \endcode
95  *
96  * As with Input<Func>, you can optionally make the type and/or dimensions of an
97  * Output<Func> unspecified; any unspecified types must be resolved via an
98  * implicit GeneratorParam in order to use top-level compilation.
99  *
100  * You can also declare an *array* of Input or Output, by using an array type
101  * as the type parameter:
102  *
103  * \code
104  *     // Takes exactly 3 images and outputs exactly 3 sums.
105  *     class SumRowsAndColumns : Generator<SumRowsAndColumns> {
106  *       Input<Func[3]> inputs{"inputs", Float(32), 2};
107  *       Input<int32_t[2]> extents{"extents"};
108  *       Output<Func[3]> sums{"sums", Float(32), 1};
109  *       void generate() {
110  *         assert(inputs.size() == sums.size());
111  *         // assume all inputs are same extent
 *         Expr width = extents[0];
 *         Expr height = extents[1];
114  *         for (size_t i = 0; i < inputs.size(); ++i) {
115  *           RDom r(0, width, 0, height);
116  *           sums[i]() = 0.f;
117  *           sums[i]() += inputs[i](r.x, r.y);
118  *          }
119  *       }
120  *     };
121  * \endcode
122  *
123  * You can also leave array size unspecified, with some caveats:
124  *   - For ahead-of-time compilation, Inputs must have a concrete size specified
125  *     via a GeneratorParam at build time (e.g., pyramid.size=3)
126  *   - For JIT compilation via a Stub, Inputs array sizes will be inferred
127  *     from the vector passed.
128  *   - For ahead-of-time compilation, Outputs may specify a concrete size
129  *     via a GeneratorParam at build time (e.g., pyramid.size=3), or the
130  *     size can be specified via a resize() method.
131  *
132  * \code
133  *     class Pyramid : public Generator<Pyramid> {
134  *     public:
135  *         GeneratorParam<int32_t> levels{"levels", 10};
136  *         Input<Func> input{ "input", Float(32), 2 };
137  *         Output<Func[]> pyramid{ "pyramid", Float(32), 2 };
138  *         void generate() {
139  *             pyramid.resize(levels);
140  *             pyramid[0](x, y) = input(x, y);
141  *             for (int i = 1; i < pyramid.size(); i++) {
142  *                 pyramid[i](x, y) = (pyramid[i-1](2*x, 2*y) +
143  *                                    pyramid[i-1](2*x+1, 2*y) +
144  *                                    pyramid[i-1](2*x, 2*y+1) +
145  *                                    pyramid[i-1](2*x+1, 2*y+1))/4;
146  *             }
147  *         }
148  *     };
149  * \endcode
150  *
151  * A Generator can also be customized via compile-time parameters (GeneratorParams),
152  * which affect code generation.
153  *
154  * GeneratorParams, Inputs, and Outputs are (by convention) always
155  * public and always declared at the top of the Generator class, in the order
156  *
157  * \code
158  *     GeneratorParam(s)
159  *     Input<Func>(s)
160  *     Input<non-Func>(s)
161  *     Output<Func>(s)
162  * \endcode
163  *
164  * Note that the Inputs and Outputs will appear in the C function call in the order
165  * they are declared. All Input<Func> and Output<Func> are represented as halide_buffer_t;
166  * all other Input<> are the appropriate C++ scalar type. (GeneratorParams are
167  * always referenced by name, not position, so their order is irrelevant.)
168  *
169  * All Inputs and Outputs must have explicit names, and all such names must match
170  * the regex [A-Za-z][A-Za-z_0-9]* (i.e., essentially a C/C++ variable name, with
171  * some extra restrictions on underscore use). By convention, the name should match
172  * the member-variable name.
173  *
174  * You can dynamically add Inputs and Outputs to your Generator via adding a
175  * configure() method; if present, it will be called before generate(). It can
176  * examine GeneratorParams but it may not examine predeclared Inputs or Outputs;
177  * the only thing it should do is call add_input<>() and/or add_output<>().
178  * Added inputs will be appended (in order) after predeclared Inputs but before
179  * any Outputs; added outputs will be appended after predeclared Outputs.
180  *
181  * Note that the pointers returned by add_input() and add_output() are owned
182  * by the Generator and will remain valid for the Generator's lifetime; user code
183  * should not attempt to delete or free them.
184  *
185  * \code
186  *     class MultiSum : public Generator<MultiSum> {
187  *     public:
188  *         GeneratorParam<int32_t> input_count{"input_count", 10};
189  *         Output<Func> output{ "output", Float(32), 2 };
190  *
191  *         void configure() {
192  *             for (int i = 0; i < input_count; ++i) {
 *                 extra_inputs.push_back(
 *                     add_input<Func>("input_" + std::to_string(i), Float(32), 2));
195  *             }
196  *         }
197  *
198  *         void generate() {
199  *             Expr sum = 0.f;
200  *             for (int i = 0; i < input_count; ++i) {
 *                 sum += (*extra_inputs[i])(x, y);
202  *             }
203  *             output(x, y) = sum;
204  *         }
205  *     private:
 *         std::vector<Input<Func> *> extra_inputs;
207  *     };
208  * \endcode
209  *
210  *  All Generators have three GeneratorParams that are implicitly provided
211  *  by the base class:
212  *
213  *      GeneratorParam<Target> target{"target", Target()};
214  *      GeneratorParam<bool> auto_schedule{"auto_schedule", false};
215  *      GeneratorParam<MachineParams> machine_params{"machine_params", MachineParams::generic()};
216  *
217  *  - 'target' is the Halide::Target for which the Generator is producing code.
218  *    It is read-only during the Generator's lifetime, and must not be modified;
219  *    its value should always be filled in by the calling code: either the Halide
220  *    build system (for ahead-of-time compilation), or ordinary C++ code
221  *    (for JIT compilation).
222  *  - 'auto_schedule' indicates whether the auto-scheduler should be run for this
223  *    Generator:
224  *      - if 'false', the Generator should schedule its Funcs as it sees fit.
225  *      - if 'true', the Generator should only provide estimate()s for its Funcs,
226  *        and not call any other scheduling methods.
227  *  - 'machine_params' is only used if auto_schedule is true; it is ignored
228  *    if auto_schedule is false. It provides details about the machine architecture
229  *    being targeted which may be used to enhance the automatically-generated
230  *    schedule.
231  *
232  * Generators are added to a global registry to simplify AOT build mechanics; this
233  * is done by simply using the HALIDE_REGISTER_GENERATOR macro at global scope:
234  *
235  * \code
236  *      HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example)
237  * \endcode
238  *
 * The registered name of the Generator must follow the same rules as
 * Input names, above.
241  *
242  * Note that the class name of the generated Stub class will match the registered
243  * name by default; if you want to vary it (typically, to include namespaces),
244  * you can add it as an optional third argument:
245  *
246  * \code
247  *      HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example, SomeNamespace::JitExampleStub)
248  * \endcode
249  *
250  * Note that a Generator is always executed with a specific Target assigned to it,
251  * that you can access via the get_target() method. (You should *not* use the
252  * global get_target_from_environment(), etc. methods provided in Target.h)
253  *
254  * (Note that there are older variations of Generator that differ from what's
255  * documented above; these are still supported but not described here. See
256  * https://github.com/halide/Halide/wiki/Old-Generator-Documentation for
257  * more information.)
258  */
259 
260 #include <algorithm>
261 #include <iterator>
262 #include <limits>
263 #include <memory>
264 #include <mutex>
265 #include <set>
266 #include <sstream>
267 #include <string>
268 #include <type_traits>
269 #include <utility>
270 #include <vector>
271 
272 #include "ExternalCode.h"
273 #include "Func.h"
274 #include "ImageParam.h"
275 #include "Introspection.h"
276 #include "ObjectInstanceRegistry.h"
277 #include "Target.h"
278 
279 namespace Halide {
280 
281 template<typename T>
282 class Buffer;
283 
284 namespace Internal {
285 
286 void generator_test();
287 
288 /**
289  * ValueTracker is an internal utility class that attempts to track and flag certain
290  * obvious Stub-related errors at Halide compile time: it tracks the constraints set
291  * on any Parameter-based argument (i.e., Input<Buffer> and Output<Buffer>) to
292  * ensure that incompatible values aren't set.
293  *
294  * e.g.: if a Generator A requires stride[0] == 1,
295  * and Generator B uses Generator A via stub, but requires stride[0] == 4,
296  * we should be able to detect this at Halide compilation time, and fail immediately,
297  * rather than producing code that fails at runtime and/or runs slowly due to
298  * vectorization being unavailable.
299  *
300  * We do this by tracking the active values at entrance and exit to all user-provided
301  * Generator methods (build()/generate()/schedule()); if we ever find more than two unique
302  * values active, we know we have a potential conflict. ("two" here because the first
303  * value is the default value for a given constraint.)
304  *
305  * Note that this won't catch all cases:
306  * -- JIT compilation has no way to check for conflicts at the top-level
307  * -- constraints that match the default value (e.g. if dim(0).set_stride(1) is the
308  * first value seen by the tracker) will be ignored, so an explicit requirement set
309  * this way can be missed
310  *
311  * Nevertheless, this is likely to be much better than nothing when composing multiple
312  * layers of Stubs in a single fused result.
313  */
314 class ValueTracker {
315 private:
316     std::map<std::string, std::vector<std::vector<Expr>>> values_history;
317     const size_t max_unique_values;
318 
319 public:
320     explicit ValueTracker(size_t max_unique_values = 2)
max_unique_values(max_unique_values)321         : max_unique_values(max_unique_values) {
322     }
323     void track_values(const std::string &name, const std::vector<Expr> &values);
324 };
325 
326 std::vector<Expr> parameter_constraints(const Parameter &p);
327 
328 template<typename T>
enum_to_string(const std::map<std::string,T> & enum_map,const T & t)329 HALIDE_NO_USER_CODE_INLINE std::string enum_to_string(const std::map<std::string, T> &enum_map, const T &t) {
330     for (auto key_value : enum_map) {
331         if (t == key_value.second) {
332             return key_value.first;
333         }
334     }
335     user_error << "Enumeration value not found.\n";
336     return "";
337 }
338 
/** Forward lookup in an enum map: return the value stored under the name s.
 * Trips a user_assert when s is not a key of enum_map. */
template<typename T>
T enum_from_string(const std::map<std::string, T> &enum_map, const std::string &s) {
    const typename std::map<std::string, T>::const_iterator it = enum_map.find(s);
    user_assert(it != enum_map.end()) << "Enumeration value not found: " << s << "\n";
    const T &found = it->second;
    return found;
}
345 
346 extern const std::map<std::string, Halide::Type> &get_halide_type_enum_map();
halide_type_to_enum_string(const Type & t)347 inline std::string halide_type_to_enum_string(const Type &t) {
348     return enum_to_string(get_halide_type_enum_map(), t);
349 }
350 
351 // Convert a Halide Type into a string representation of its C source.
352 // e.g., Int(32) -> "Halide::Int(32)"
353 std::string halide_type_to_c_source(const Type &t);
354 
// Convert a Halide Type into the name of the corresponding C type.
// e.g., Int(32) -> "int32_t"
357 std::string halide_type_to_c_type(const Type &t);
358 
359 /** generate_filter_main() is a convenient wrapper for GeneratorRegistry::create() +
360  * compile_to_files(); it can be trivially wrapped by a "real" main() to produce a
361  * command-line utility for ahead-of-time filter compilation. */
362 int generate_filter_main(int argc, char **argv, std::ostream &cerr);
363 
364 // select_type<> is to std::conditional as switch is to if:
365 // it allows a multiway compile-time type definition via the form
366 //
367 //    select_type<cond<condition1, type1>,
368 //                cond<condition2, type2>,
369 //                ....
370 //                cond<conditionN, typeN>>::type
371 //
372 // Note that the conditions are evaluated in order; the first evaluating to true
373 // is chosen.
374 //
375 // Note that if no conditions evaluate to true, the resulting type is illegal
376 // and will produce a compilation error. (You can provide a default by simply
377 // using cond<true, SomeType> as the final entry.)
/** One (condition, type) entry for select_type<>. */
template<bool B, typename T>
struct cond {
    using type = T;
    static constexpr bool value = B;
};

/** Recursive case: take First's type when its condition holds, otherwise
 * whatever the remaining entries select. */
template<typename First, typename... Rest>
struct select_type {
    using type = typename std::conditional<First::value,
                                           typename First::type,
                                           typename select_type<Rest...>::type>::type;
};

/** Terminal case: a single entry is left; a false condition yields void. */
template<typename Only>
struct select_type<Only> : std::conditional<Only::value, typename Only::type, void> {};
389 
390 class GeneratorBase;
391 class GeneratorParamInfo;
392 
393 class GeneratorParamBase {
394 public:
395     explicit GeneratorParamBase(const std::string &name);
396     virtual ~GeneratorParamBase();
397 
398     const std::string name;
399 
400     // overload the set() function to call the right virtual method based on type.
401     // This allows us to attempt to set a GeneratorParam via a
402     // plain C++ type, even if we don't know the specific templated
403     // subclass. Attempting to set the wrong type will assert.
404     // Notice that there is no typed setter for Enums, for obvious reasons;
405     // setting enums in an unknown type must fallback to using set_from_string.
406     //
407     // It's always a bit iffy to use macros for this, but IMHO it clarifies the situation here.
408 #define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
409     virtual void set(const TYPE &new_value) = 0;
410 
411     HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
412     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
413     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
414     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
415     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
416     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
417     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
418     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
419     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
420     HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
421     HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
422     HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
423     HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
424     HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
425     HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)
426 
427 #undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
428 
429     // Add overloads for string and char*
430     void set(const std::string &new_value) {
431         set_from_string(new_value);
432     }
433     void set(const char *new_value) {
434         set_from_string(std::string(new_value));
435     }
436 
437 protected:
438     friend class GeneratorBase;
439     friend class GeneratorParamInfo;
440     friend class StubEmitter;
441 
442     void check_value_readable() const;
443     void check_value_writable() const;
444 
445     // All GeneratorParams are settable from string.
446     virtual void set_from_string(const std::string &value_string) = 0;
447 
448     virtual std::string call_to_string(const std::string &v) const = 0;
449     virtual std::string get_c_type() const = 0;
450 
451     virtual std::string get_type_decls() const {
452         return "";
453     }
454 
455     virtual std::string get_default_value() const = 0;
456 
457     virtual bool is_synthetic_param() const {
458         return false;
459     }
460 
461     virtual bool is_looplevel_param() const {
462         return false;
463     }
464 
465     void fail_wrong_type(const char *type);
466 
467 private:
468     // No copy
469     GeneratorParamBase(const GeneratorParamBase &) = delete;
470     void operator=(const GeneratorParamBase &) = delete;
471     // No move
472     GeneratorParamBase(GeneratorParamBase &&) = delete;
473     void operator=(GeneratorParamBase &&) = delete;
474 
475     // Generator which owns this GeneratorParam. Note that this will be null
476     // initially; the GeneratorBase itself will set this field when it initially
477     // builds its info about params. However, since it (generally) isn't
478     // appropriate for GeneratorParam<> to be declared outside of a Generator,
479     // all reasonable non-testing code should expect this to be non-null.
480     GeneratorBase *generator{nullptr};
481 };
482 
483 // This is strictly some syntactic sugar to suppress certain compiler warnings.
template<typename FROM, typename TO>
struct Convert {
    // Destination is anything but bool: a plain explicit cast.
    template<typename TO2 = TO>
    inline static typename std::enable_if<!std::is_same<TO2, bool>::value, TO2>::type
    value(const FROM &from) {
        return static_cast<TO2>(from);
    }

    // Destination is bool: test against zero instead of casting.
    template<typename TO2 = TO>
    inline static typename std::enable_if<std::is_same<TO2, bool>::value, TO2>::type
    value(const FROM &from) {
        return from != 0;
    }
};
496 
497 template<typename T>
498 class GeneratorParamImpl : public GeneratorParamBase {
499 public:
500     using type = T;
501 
502     GeneratorParamImpl(const std::string &name, const T &value)
503         : GeneratorParamBase(name), value_(value) {
504     }
505 
506     T value() const {
507         this->check_value_readable();
508         return value_;
509     }
510 
511     operator T() const {
512         return this->value();
513     }
514 
515     operator Expr() const {
516         return make_const(type_of<T>(), this->value());
517     }
518 
519 #define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE)  \
520     void set(const TYPE &new_value) override {     \
521         typed_setter_impl<TYPE>(new_value, #TYPE); \
522     }
523 
524     HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
525     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
526     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
527     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
528     HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
529     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
530     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
531     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
532     HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
533     HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
534     HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
535     HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
536     HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
537     HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
538     HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)
539 
540 #undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
541 
542     // Overload for std::string.
543     void set(const std::string &new_value) {
544         check_value_writable();
545         value_ = new_value;
546     }
547 
548 protected:
549     virtual void set_impl(const T &new_value) {
550         check_value_writable();
551         value_ = new_value;
552     }
553 
554     // Needs to be protected to allow GeneratorParam<LoopLevel>::set() override
555     T value_;
556 
557 private:
558     // If FROM->T is not legal, fail
559     template<typename FROM, typename std::enable_if<
560                                 !std::is_convertible<FROM, T>::value>::type * = nullptr>
561     HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &, const char *msg) {
562         fail_wrong_type(msg);
563     }
564 
565     // If FROM and T are identical, just assign
566     template<typename FROM, typename std::enable_if<
567                                 std::is_same<FROM, T>::value>::type * = nullptr>
568     HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
569         check_value_writable();
570         value_ = value;
571     }
572 
573     // If both FROM->T and T->FROM are legal, ensure it's lossless
574     template<typename FROM, typename std::enable_if<
575                                 !std::is_same<FROM, T>::value &&
576                                 std::is_convertible<FROM, T>::value &&
577                                 std::is_convertible<T, FROM>::value>::type * = nullptr>
578     HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
579         check_value_writable();
580         const T t = Convert<FROM, T>::value(value);
581         const FROM value2 = Convert<T, FROM>::value(t);
582         if (value2 != value) {
583             fail_wrong_type(msg);
584         }
585         value_ = t;
586     }
587 
588     // If FROM->T is legal but T->FROM is not, just assign
589     template<typename FROM, typename std::enable_if<
590                                 !std::is_same<FROM, T>::value &&
591                                 std::is_convertible<FROM, T>::value &&
592                                 !std::is_convertible<T, FROM>::value>::type * = nullptr>
593     HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
594         check_value_writable();
595         value_ = value;
596     }
597 };
598 
599 // Stubs for type-specific implementations of GeneratorParam, to avoid
600 // many complex enable_if<> statements that were formerly spread through the
601 // implementation. Note that not all of these need to be templated classes,
602 // (e.g. for GeneratorParam_Target, T == Target always), but are declared
603 // that way for symmetry of declaration.
604 template<typename T>
605 class GeneratorParam_Target : public GeneratorParamImpl<T> {
606 public:
607     GeneratorParam_Target(const std::string &name, const T &value)
608         : GeneratorParamImpl<T>(name, value) {
609     }
610 
611     void set_from_string(const std::string &new_value_string) override {
612         this->set(Target(new_value_string));
613     }
614 
615     std::string get_default_value() const override {
616         return this->value().to_string();
617     }
618 
619     std::string call_to_string(const std::string &v) const override {
620         std::ostringstream oss;
621         oss << v << ".to_string()";
622         return oss.str();
623     }
624 
625     std::string get_c_type() const override {
626         return "Target";
627     }
628 };
629 
630 template<typename T>
631 class GeneratorParam_MachineParams : public GeneratorParamImpl<T> {
632 public:
633     GeneratorParam_MachineParams(const std::string &name, const T &value)
634         : GeneratorParamImpl<T>(name, value) {
635     }
636 
637     void set_from_string(const std::string &new_value_string) override {
638         this->set(MachineParams(new_value_string));
639     }
640 
641     std::string get_default_value() const override {
642         return this->value().to_string();
643     }
644 
645     std::string call_to_string(const std::string &v) const override {
646         std::ostringstream oss;
647         oss << v << ".to_string()";
648         return oss.str();
649     }
650 
651     std::string get_c_type() const override {
652         return "MachineParams";
653     }
654 };
655 
656 class GeneratorParam_LoopLevel : public GeneratorParamImpl<LoopLevel> {
657 public:
658     GeneratorParam_LoopLevel(const std::string &name, const LoopLevel &value)
659         : GeneratorParamImpl<LoopLevel>(name, value) {
660     }
661 
662     using GeneratorParamImpl<LoopLevel>::set;
663 
664     void set(const LoopLevel &value) override {
665         // Don't call check_value_writable(): It's OK to set a LoopLevel after generate().
666         // check_value_writable();
667 
668         // This looks odd, but is deliberate:
669 
670         // First, mutate the existing contents to match the value passed in,
671         // so that any existing usage of the LoopLevel now uses the newer value.
672         // (Strictly speaking, this is really only necessary if this method
673         // is called after generate(): before generate(), there is no usage
674         // to be concerned with.)
675         value_.set(value);
676 
677         // Then, reset the value itself so that it points to the same LoopLevelContents
678         // as the value passed in. (Strictly speaking, this is really only
679         // useful if this method is called before generate(): afterwards, it's
680         // too late to alter the code to refer to a different LoopLevelContents.)
681         value_ = value;
682     }
683 
684     void set_from_string(const std::string &new_value_string) override {
685         if (new_value_string == "root") {
686             this->set(LoopLevel::root());
687         } else if (new_value_string == "inlined") {
688             this->set(LoopLevel::inlined());
689         } else {
690             user_error << "Unable to parse " << this->name << ": " << new_value_string;
691         }
692     }
693 
694     std::string get_default_value() const override {
695         // This is dodgy but safe in this case: we want to
696         // see what the value of our LoopLevel is *right now*,
697         // so we make a copy and lock the copy so we can inspect it.
698         // (Note that ordinarily this is a bad idea, since LoopLevels
699         // can be mutated later on; however, this method is only
700         // called by the Generator infrastructure, on LoopLevels that
701         // will never be mutated, so this is really just an elaborate way
702         // to avoid runtime assertions.)
703         LoopLevel copy;
704         copy.set(this->value());
705         copy.lock();
706         if (copy.is_inlined()) {
707             return "LoopLevel::inlined()";
708         } else if (copy.is_root()) {
709             return "LoopLevel::root()";
710         } else {
711             internal_error;
712             return "";
713         }
714     }
715 
716     std::string call_to_string(const std::string &v) const override {
717         internal_error;
718         return std::string();
719     }
720 
721     std::string get_c_type() const override {
722         return "LoopLevel";
723     }
724 
725     bool is_looplevel_param() const override {
726         return true;
727     }
728 };
729 
730 template<typename T>
731 class GeneratorParam_Arithmetic : public GeneratorParamImpl<T> {
732 public:
733     GeneratorParam_Arithmetic(const std::string &name,
734                               const T &value,
735                               const T &min = std::numeric_limits<T>::lowest(),
736                               const T &max = std::numeric_limits<T>::max())
737         : GeneratorParamImpl<T>(name, value), min(min), max(max) {
738         // call set() to ensure value is clamped to min/max
739         this->set(value);
740     }
741 
742     void set_impl(const T &new_value) override {
743         user_assert(new_value >= min && new_value <= max) << "Value out of range: " << new_value;
744         GeneratorParamImpl<T>::set_impl(new_value);
745     }
746 
747     void set_from_string(const std::string &new_value_string) override {
748         std::istringstream iss(new_value_string);
749         T t;
750         // All one-byte ints int8 and uint8 should be parsed as integers, not chars --
751         // including 'char' itself. (Note that sizeof(bool) is often-but-not-always-1,
752         // so be sure to exclude that case.)
753         if (sizeof(T) == sizeof(char) && !std::is_same<T, bool>::value) {
754             int i;
755             iss >> i;
756             t = (T)i;
757         } else {
758             iss >> t;
759         }
760         user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << new_value_string;
761         this->set(t);
762     }
763 
764     std::string get_default_value() const override {
765         std::ostringstream oss;
766         oss << this->value();
767         if (std::is_same<T, float>::value) {
768             // If the constant has no decimal point ("1")
769             // we must append one before appending "f"
770             if (oss.str().find(".") == std::string::npos) {
771                 oss << ".";
772             }
773             oss << "f";
774         }
775         return oss.str();
776     }
777 
778     std::string call_to_string(const std::string &v) const override {
779         std::ostringstream oss;
780         oss << "std::to_string(" << v << ")";
781         return oss.str();
782     }
783 
784     std::string get_c_type() const override {
785         std::ostringstream oss;
786         if (std::is_same<T, float>::value) {
787             return "float";
788         } else if (std::is_same<T, double>::value) {
789             return "double";
790         } else if (std::is_integral<T>::value) {
791             if (std::is_unsigned<T>::value) {
792                 oss << "u";
793             }
794             oss << "int" << (sizeof(T) * 8) << "_t";
795             return oss.str();
796         } else {
797             user_error << "Unknown arithmetic type\n";
798             return "";
799         }
800     }
801 
802 private:
803     const T min, max;
804 };
805 
806 template<typename T>
807 class GeneratorParam_Bool : public GeneratorParam_Arithmetic<T> {
808 public:
809     GeneratorParam_Bool(const std::string &name, const T &value)
810         : GeneratorParam_Arithmetic<T>(name, value) {
811     }
812 
813     void set_from_string(const std::string &new_value_string) override {
814         bool v = false;
815         if (new_value_string == "true" || new_value_string == "True") {
816             v = true;
817         } else if (new_value_string == "false" || new_value_string == "False") {
818             v = false;
819         } else {
820             user_assert(false) << "Unable to parse bool: " << new_value_string;
821         }
822         this->set(v);
823     }
824 
825     std::string get_default_value() const override {
826         return this->value() ? "true" : "false";
827     }
828 
829     std::string call_to_string(const std::string &v) const override {
830         std::ostringstream oss;
831         oss << "std::string((" << v << ") ? \"true\" : \"false\")";
832         return oss.str();
833     }
834 
835     std::string get_c_type() const override {
836         return "bool";
837     }
838 };
839 
840 template<typename T>
841 class GeneratorParam_Enum : public GeneratorParamImpl<T> {
842 public:
843     GeneratorParam_Enum(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
844         : GeneratorParamImpl<T>(name, value), enum_map(enum_map) {
845     }
846 
847     // define a "set" that takes our specific enum (but don't hide the inherited virtual functions)
848     using GeneratorParamImpl<T>::set;
849 
850     template<typename T2 = T, typename std::enable_if<!std::is_same<T2, Type>::value>::type * = nullptr>
851     void set(const T &e) {
852         this->set_impl(e);
853     }
854 
855     void set_from_string(const std::string &new_value_string) override {
856         auto it = enum_map.find(new_value_string);
857         user_assert(it != enum_map.end()) << "Enumeration value not found: " << new_value_string;
858         this->set_impl(it->second);
859     }
860 
861     std::string call_to_string(const std::string &v) const override {
862         return "Enum_" + this->name + "_map().at(" + v + ")";
863     }
864 
865     std::string get_c_type() const override {
866         return "Enum_" + this->name;
867     }
868 
869     std::string get_default_value() const override {
870         return "Enum_" + this->name + "::" + enum_to_string(enum_map, this->value());
871     }
872 
873     std::string get_type_decls() const override {
874         std::ostringstream oss;
875         oss << "enum class Enum_" << this->name << " {\n";
876         for (auto key_value : enum_map) {
877             oss << "  " << key_value.first << ",\n";
878         }
879         oss << "};\n";
880         oss << "\n";
881 
882         // TODO: since we generate the enums, we could probably just use a vector (or array!) rather than a map,
883         // since we can ensure that the enum values are a nice tight range.
884         oss << "inline HALIDE_NO_USER_CODE_INLINE const std::map<Enum_" << this->name << ", std::string>& Enum_" << this->name << "_map() {\n";
885         oss << "  static const std::map<Enum_" << this->name << ", std::string> m = {\n";
886         for (auto key_value : enum_map) {
887             oss << "    { Enum_" << this->name << "::" << key_value.first << ", \"" << key_value.first << "\"},\n";
888         }
889         oss << "  };\n";
890         oss << "  return m;\n";
891         oss << "};\n";
892         return oss.str();
893     }
894 
895 private:
896     const std::map<std::string, T> enum_map;
897 };
898 
899 template<typename T>
900 class GeneratorParam_Type : public GeneratorParam_Enum<T> {
901 public:
902     GeneratorParam_Type(const std::string &name, const T &value)
903         : GeneratorParam_Enum<T>(name, value, get_halide_type_enum_map()) {
904     }
905 
906     std::string call_to_string(const std::string &v) const override {
907         return "Halide::Internal::halide_type_to_enum_string(" + v + ")";
908     }
909 
910     std::string get_c_type() const override {
911         return "Type";
912     }
913 
914     std::string get_default_value() const override {
915         return halide_type_to_c_source(this->value());
916     }
917 
918     std::string get_type_decls() const override {
919         return "";
920     }
921 };
922 
923 template<typename T>
924 class GeneratorParam_String : public Internal::GeneratorParamImpl<T> {
925 public:
926     GeneratorParam_String(const std::string &name, const std::string &value)
927         : GeneratorParamImpl<T>(name, value) {
928     }
929     void set_from_string(const std::string &new_value_string) override {
930         this->set(new_value_string);
931     }
932 
933     std::string get_default_value() const override {
934         return "\"" + this->value() + "\"";
935     }
936 
937     std::string call_to_string(const std::string &v) const override {
938         return v;
939     }
940 
941     std::string get_c_type() const override {
942         return "std::string";
943     }
944 };
945 
// Maps a GeneratorParam value type T to the class that implements it.
// Each cond<> pairs a type predicate with the implementation to use.
// NOTE(review): select_type is defined elsewhere; presumably it yields the
// first cond<> whose predicate is true, in which case the order matters --
// e.g. bool also satisfies std::is_arithmetic, so the bool entry must
// precede the arithmetic one. Confirm against select_type.
template<typename T>
using GeneratorParamImplBase =
    typename select_type<
        cond<std::is_same<T, Target>::value, GeneratorParam_Target<T>>,
        cond<std::is_same<T, MachineParams>::value, GeneratorParam_MachineParams<T>>,
        cond<std::is_same<T, LoopLevel>::value, GeneratorParam_LoopLevel>,
        cond<std::is_same<T, std::string>::value, GeneratorParam_String<T>>,
        cond<std::is_same<T, Type>::value, GeneratorParam_Type<T>>,
        cond<std::is_same<T, bool>::value, GeneratorParam_Bool<T>>,
        cond<std::is_arithmetic<T>::value, GeneratorParam_Arithmetic<T>>,
        cond<std::is_enum<T>::value, GeneratorParam_Enum<T>>>::type;
957 
958 }  // namespace Internal
959 
/** GeneratorParam is a templated class that can be used to modify the behavior
 * of the Generator at code-generation time. GeneratorParams are commonly
 * specified in build files (e.g. Makefile) to customize the behavior of
 * a given Generator, thus they have a very constrained set of types to allow
 * for efficient specification via command-line flags. A GeneratorParam can be:
 *   - any float or int type.
 *   - bool
 *   - enum
 *   - Halide::Target
 *   - Halide::Type
 *   - std::string
 * (The implementation selector also dispatches on Halide::MachineParams and
 * Halide::LoopLevel.)
 *
 * Please don't use std::string unless there's no way to do what you want with some
 * other type; in particular, don't use this if you can use enum instead.
 * All GeneratorParams have a default value. Arithmetic types can also
 * optionally specify min and max. Enum types must specify a string-to-value
 * map.
 *
 * Halide::Type is treated as though it were an enum, with the mappings:
 *
 *   "int8"     Halide::Int(8)
 *   "int16"    Halide::Int(16)
 *   "int32"    Halide::Int(32)
 *   "uint8"    Halide::UInt(8)
 *   "uint16"   Halide::UInt(16)
 *   "uint32"   Halide::UInt(32)
 *   "float32"  Halide::Float(32)
 *   "float64"  Halide::Float(64)
 *
 * No vector Types are currently supported by this mapping.
 *
 */
template<typename T>
class GeneratorParam : public Internal::GeneratorParamImplBase<T> {
public:
    // Name plus default value. This overload is disabled when T is
    // std::string, presumably so that it cannot collide with the
    // (name, std::string) constructor at the bottom of the class.
    template<typename T2 = T, typename std::enable_if<!std::is_same<T2, std::string>::value>::type * = nullptr>
    GeneratorParam(const std::string &name, const T &value)
        : Internal::GeneratorParamImplBase<T>(name, value) {
    }

    // Name, default value and inclusive [min, max] range. Only usable when
    // the selected implementation has a matching constructor (the arithmetic
    // implementation does).
    GeneratorParam(const std::string &name, const T &value, const T &min, const T &max)
        : Internal::GeneratorParamImplBase<T>(name, value, min, max) {
    }

    // Name, default value and string-to-value map. Only usable when the
    // selected implementation has a matching constructor (the enum
    // implementation does).
    GeneratorParam(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
        : Internal::GeneratorParamImplBase<T>(name, value, enum_map) {
    }

    // Name plus a value given as a string; forwarded unchanged to the
    // selected implementation.
    GeneratorParam(const std::string &name, const std::string &value)
        : Internal::GeneratorParamImplBase<T>(name, value) {
    }
};
1011 
1012 /** Addition between GeneratorParam<T> and any type that supports operator+ with T.
1013  * Returns type of underlying operator+. */
1014 // @{
1015 template<typename Other, typename T>
1016 auto operator+(const Other &a, const GeneratorParam<T> &b) -> decltype(a + (T)b) {
1017     return a + (T)b;
1018 }
1019 template<typename Other, typename T>
1020 auto operator+(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a + b) {
1021     return (T)a + b;
1022 }
1023 // @}
1024 
1025 /** Subtraction between GeneratorParam<T> and any type that supports operator- with T.
1026  * Returns type of underlying operator-. */
1027 // @{
1028 template<typename Other, typename T>
1029 auto operator-(const Other &a, const GeneratorParam<T> &b) -> decltype(a - (T)b) {
1030     return a - (T)b;
1031 }
1032 template<typename Other, typename T>
1033 auto operator-(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a - b) {
1034     return (T)a - b;
1035 }
1036 // @}
1037 
1038 /** Multiplication between GeneratorParam<T> and any type that supports operator* with T.
1039  * Returns type of underlying operator*. */
1040 // @{
1041 template<typename Other, typename T>
1042 auto operator*(const Other &a, const GeneratorParam<T> &b) -> decltype(a * (T)b) {
1043     return a * (T)b;
1044 }
1045 template<typename Other, typename T>
1046 auto operator*(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a * b) {
1047     return (T)a * b;
1048 }
1049 // @}
1050 
1051 /** Division between GeneratorParam<T> and any type that supports operator/ with T.
1052  * Returns type of underlying operator/. */
1053 // @{
1054 template<typename Other, typename T>
1055 auto operator/(const Other &a, const GeneratorParam<T> &b) -> decltype(a / (T)b) {
1056     return a / (T)b;
1057 }
1058 template<typename Other, typename T>
1059 auto operator/(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a / b) {
1060     return (T)a / b;
1061 }
1062 // @}
1063 
1064 /** Modulo between GeneratorParam<T> and any type that supports operator% with T.
1065  * Returns type of underlying operator%. */
1066 // @{
1067 template<typename Other, typename T>
1068 auto operator%(const Other &a, const GeneratorParam<T> &b) -> decltype(a % (T)b) {
1069     return a % (T)b;
1070 }
1071 template<typename Other, typename T>
1072 auto operator%(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a % b) {
1073     return (T)a % b;
1074 }
1075 // @}
1076 
1077 /** Greater than comparison between GeneratorParam<T> and any type that supports operator> with T.
1078  * Returns type of underlying operator>. */
1079 // @{
1080 template<typename Other, typename T>
1081 auto operator>(const Other &a, const GeneratorParam<T> &b) -> decltype(a > (T)b) {
1082     return a > (T)b;
1083 }
1084 template<typename Other, typename T>
1085 auto operator>(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a > b) {
1086     return (T)a > b;
1087 }
1088 // @}
1089 
1090 /** Less than comparison between GeneratorParam<T> and any type that supports operator< with T.
1091  * Returns type of underlying operator<. */
1092 // @{
1093 template<typename Other, typename T>
1094 auto operator<(const Other &a, const GeneratorParam<T> &b) -> decltype(a < (T)b) {
1095     return a < (T)b;
1096 }
1097 template<typename Other, typename T>
1098 auto operator<(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a < b) {
1099     return (T)a < b;
1100 }
1101 // @}
1102 
1103 /** Greater than or equal comparison between GeneratorParam<T> and any type that supports operator>= with T.
1104  * Returns type of underlying operator>=. */
1105 // @{
1106 template<typename Other, typename T>
1107 auto operator>=(const Other &a, const GeneratorParam<T> &b) -> decltype(a >= (T)b) {
1108     return a >= (T)b;
1109 }
1110 template<typename Other, typename T>
1111 auto operator>=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a >= b) {
1112     return (T)a >= b;
1113 }
1114 // @}
1115 
1116 /** Less than or equal comparison between GeneratorParam<T> and any type that supports operator<= with T.
1117  * Returns type of underlying operator<=. */
1118 // @{
1119 template<typename Other, typename T>
1120 auto operator<=(const Other &a, const GeneratorParam<T> &b) -> decltype(a <= (T)b) {
1121     return a <= (T)b;
1122 }
1123 template<typename Other, typename T>
1124 auto operator<=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a <= b) {
1125     return (T)a <= b;
1126 }
1127 // @}
1128 
1129 /** Equality comparison between GeneratorParam<T> and any type that supports operator== with T.
1130  * Returns type of underlying operator==. */
1131 // @{
1132 template<typename Other, typename T>
1133 auto operator==(const Other &a, const GeneratorParam<T> &b) -> decltype(a == (T)b) {
1134     return a == (T)b;
1135 }
1136 template<typename Other, typename T>
1137 auto operator==(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a == b) {
1138     return (T)a == b;
1139 }
1140 // @}
1141 
1142 /** Inequality comparison between between GeneratorParam<T> and any type that supports operator!= with T.
1143  * Returns type of underlying operator!=. */
1144 // @{
1145 template<typename Other, typename T>
1146 auto operator!=(const Other &a, const GeneratorParam<T> &b) -> decltype(a != (T)b) {
1147     return a != (T)b;
1148 }
1149 template<typename Other, typename T>
1150 auto operator!=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a != b) {
1151     return (T)a != b;
1152 }
1153 // @}
1154 
1155 /** Logical and between between GeneratorParam<T> and any type that supports operator&& with T.
1156  * Returns type of underlying operator&&. */
1157 // @{
1158 template<typename Other, typename T>
1159 auto operator&&(const Other &a, const GeneratorParam<T> &b) -> decltype(a && (T)b) {
1160     return a && (T)b;
1161 }
1162 template<typename Other, typename T>
1163 auto operator&&(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a && b) {
1164     return (T)a && b;
1165 }
1166 template<typename T>
1167 auto operator&&(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a && (T)b) {
1168     return (T)a && (T)b;
1169 }
1170 // @}
1171 
1172 /** Logical or between between GeneratorParam<T> and any type that supports operator|| with T.
1173  * Returns type of underlying operator||. */
1174 // @{
1175 template<typename Other, typename T>
1176 auto operator||(const Other &a, const GeneratorParam<T> &b) -> decltype(a || (T)b) {
1177     return a || (T)b;
1178 }
1179 template<typename Other, typename T>
1180 auto operator||(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a || b) {
1181     return (T)a || b;
1182 }
1183 template<typename T>
1184 auto operator||(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a || (T)b) {
1185     return (T)a || (T)b;
1186 }
1187 // @}
1188 
1189 /* min and max are tricky as the language support for these is in the std
1190  * namespace. In order to make this work, forwarding functions are used that
1191  * are declared in a namespace that has std::min and std::max in scope.
1192  */
1193 namespace Internal {
1194 namespace GeneratorMinMax {
1195 
1196 using std::max;
1197 using std::min;
1198 
1199 template<typename Other, typename T>
1200 auto min_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(min(a, (T)b)) {
1201     return min(a, (T)b);
1202 }
1203 template<typename Other, typename T>
1204 auto min_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(min((T)a, b)) {
1205     return min((T)a, b);
1206 }
1207 
1208 template<typename Other, typename T>
1209 auto max_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(max(a, (T)b)) {
1210     return max(a, (T)b);
1211 }
1212 template<typename Other, typename T>
1213 auto max_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(max((T)a, b)) {
1214     return max((T)a, b);
1215 }
1216 
1217 }  // namespace GeneratorMinMax
1218 }  // namespace Internal
1219 
/** Compute the minimum of a GeneratorParam<T> and a value of any type that
 * supports min with T. Both overloads simply forward to
 * Internal::GeneratorMinMax::min_forward(), which has std::min in scope, so
 * callers need not import std::min themselves. The return type is that of the
 * underlying min call. */
// @{
template<typename Other, typename T>
auto min(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
    return Internal::GeneratorMinMax::min_forward(a, b);
}
template<typename Other, typename T>
auto min(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
    return Internal::GeneratorMinMax::min_forward(a, b);
}
// @}
1232 
/** Compute the maximum of a GeneratorParam<T> and a value of any type that
 * supports max with T. Both overloads simply forward to
 * Internal::GeneratorMinMax::max_forward(), which has std::max in scope, so
 * callers need not import std::max themselves. The return type is that of the
 * underlying max call. */
// @{
template<typename Other, typename T>
auto max(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
    return Internal::GeneratorMinMax::max_forward(a, b);
}
template<typename Other, typename T>
auto max(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
    return Internal::GeneratorMinMax::max_forward(a, b);
}
// @}
1245 
1246 /** Not operator for GeneratorParam */
1247 template<typename T>
1248 auto operator!(const GeneratorParam<T> &a) -> decltype(!(T)a) {
1249     return !(T)a;
1250 }
1251 
1252 namespace Internal {
1253 
1254 template<typename T2>
1255 class GeneratorInput_Buffer;
1256 
// The kind of value carried by a Generator input or output. StubInput (below)
// tags an Expr as Scalar, a Func as Function, and a StubInputBuffer as Buffer.
enum class IOKind { Scalar,
                    Function,
                    Buffer };
1260 
1261 /**
1262  * StubInputBuffer is the placeholder that a Stub uses when it requires
1263  * a Buffer for an input (rather than merely a Func or Expr). It is constructed
1264  * to allow only two possible sorts of input:
1265  * -- Assignment of an Input<Buffer<>>, with compatible type and dimensions,
1266  * essentially allowing us to pipe a parameter from an enclosing Generator to an internal Stub.
1267  * -- Assignment of a Buffer<>, with compatible type and dimensions,
1268  * causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
1269  */
1270 template<typename T = void>
1271 class StubInputBuffer {
1272     friend class StubInput;
1273     template<typename T2>
1274     friend class GeneratorInput_Buffer;
1275 
1276     Parameter parameter_;
1277 
1278     HALIDE_NO_USER_CODE_INLINE explicit StubInputBuffer(const Parameter &p)
1279         : parameter_(p) {
1280         // Create an empty 1-element buffer with the right runtime typing and dimensions,
1281         // which we'll use only to pass to can_convert_from() to verify this
1282         // Parameter is compatible with our constraints.
1283         Buffer<> other(p.type(), nullptr, std::vector<int>(p.dimensions(), 1));
1284         internal_assert((Buffer<T>::can_convert_from(other)));
1285     }
1286 
1287     template<typename T2>
1288     HALIDE_NO_USER_CODE_INLINE static Parameter parameter_from_buffer(const Buffer<T2> &b) {
1289         user_assert((Buffer<T>::can_convert_from(b)));
1290         Parameter p(b.type(), true, b.dimensions());
1291         p.set_buffer(b);
1292         return p;
1293     }
1294 
1295 public:
1296     StubInputBuffer() = default;
1297 
1298     // *not* explicit -- this ctor should only be used when you want
1299     // to pass a literal Buffer<> for a Stub Input; this Buffer<> will be
1300     // compiled into the Generator's product, rather than becoming
1301     // a runtime Parameter.
1302     template<typename T2>
1303     StubInputBuffer(const Buffer<T2> &b)
1304         : parameter_(parameter_from_buffer(b)) {
1305     }
1306 };
1307 
1308 class StubOutputBufferBase {
1309 protected:
1310     Func f;
1311     std::shared_ptr<GeneratorBase> generator;
1312 
1313     void check_scheduled(const char *m) const;
1314     Target get_target() const;
1315 
1316     explicit StubOutputBufferBase(const Func &f, std::shared_ptr<GeneratorBase> generator)
1317         : f(f), generator(std::move(generator)) {
1318     }
1319     StubOutputBufferBase() = default;
1320 
1321 public:
1322     Realization realize(std::vector<int32_t> sizes) {
1323         check_scheduled("realize");
1324         return f.realize(std::move(sizes), get_target());
1325     }
1326 
1327     template<typename... Args>
1328     Realization realize(Args &&... args) {
1329         check_scheduled("realize");
1330         return f.realize(std::forward<Args>(args)..., get_target());
1331     }
1332 
1333     template<typename Dst>
1334     void realize(Dst dst) {
1335         check_scheduled("realize");
1336         f.realize(dst, get_target());
1337     }
1338 };
1339 
1340 /**
1341  * StubOutputBuffer is the placeholder that a Stub uses when it requires
1342  * a Buffer for an output (rather than merely a Func). It is constructed
1343  * to allow only two possible sorts of things:
1344  * -- Assignment to an Output<Buffer<>>, with compatible type and dimensions,
1345  * essentially allowing us to pipe a parameter from the result of a Stub to an
1346  * enclosing Generator
1347  * -- Realization into a Buffer<>; this is useful only in JIT compilation modes
1348  * (and shouldn't be usable otherwise)
1349  *
1350  * It is deliberate that StubOutputBuffer is not (easily) convertible to Func.
1351  */
template<typename T = void>
class StubOutputBuffer : public StubOutputBufferBase {
    // Only these friends may construct a StubOutputBuffer from a Func.
    template<typename T2>
    friend class GeneratorOutput_Buffer;
    friend class GeneratorStub;
    // Private: wraps f together with the generator, both passed on to
    // StubOutputBufferBase.
    explicit StubOutputBuffer(const Func &f, const std::shared_ptr<GeneratorBase> &generator)
        : StubOutputBufferBase(f, generator) {
    }

public:
    // Default-constructed instance; relies on the base class's defaults.
    StubOutputBuffer() = default;
};
1364 
1365 // This is a union-like class that allows for convenient initialization of Stub Inputs
1366 // via C++11 initializer-list syntax; it is only used in situations where the
1367 // downstream consumer will be able to explicitly check that each value is
1368 // of the expected/required kind.
1369 class StubInput {
1370     const IOKind kind_;
1371     // Exactly one of the following fields should be defined:
1372     const Parameter parameter_;
1373     const Func func_;
1374     const Expr expr_;
1375 
1376 public:
1377     // *not* explicit.
1378     template<typename T2>
1379     StubInput(const StubInputBuffer<T2> &b)
1380         : kind_(IOKind::Buffer), parameter_(b.parameter_), func_(), expr_() {
1381     }
1382     StubInput(const Func &f)
1383         : kind_(IOKind::Function), parameter_(), func_(f), expr_() {
1384     }
1385     StubInput(const Expr &e)
1386         : kind_(IOKind::Scalar), parameter_(), func_(), expr_(e) {
1387     }
1388 
1389 private:
1390     friend class GeneratorInputBase;
1391 
1392     IOKind kind() const {
1393         return kind_;
1394     }
1395 
1396     Parameter parameter() const {
1397         internal_assert(kind_ == IOKind::Buffer);
1398         return parameter_;
1399     }
1400 
1401     Func func() const {
1402         internal_assert(kind_ == IOKind::Function);
1403         return func_;
1404     }
1405 
1406     Expr expr() const {
1407         internal_assert(kind_ == IOKind::Scalar);
1408         return expr_;
1409     }
1410 };
1411 
1412 /** GIOBase is the base class for all GeneratorInput<> and GeneratorOutput<>
1413  * instantiations; it is not part of the public API and should never be
1414  * used directly by user code.
1415  *
1416  * Every GIOBase instance can be either a single value or an array-of-values;
1417  * each of these values can be an Expr or a Func. (Note that for an
1418  * array-of-values, the types/dimensions of all values in the array must match.)
1419  *
1420  * A GIOBase can have multiple Types, in which case it represents a Tuple.
1421  * (Note that Tuples are currently only supported for GeneratorOutput, but
1422  * it is likely that GeneratorInput will be extended to support Tuple as well.)
1423  *
1424  * The array-size, type(s), and dimensions can all be left "unspecified" at
1425  * creation time, in which case they may assume values provided by a Stub.
1426  * (It is important to note that attempting to use a GIOBase with unspecified
1427  * values will assert-fail; you must ensure that all unspecified values are
1428  * filled in prior to use.)
1429  */
class GIOBase {
public:
    // Array-size queries; the array size may be left unspecified at creation.
    bool array_size_defined() const;
    size_t array_size() const;
    virtual bool is_array() const;

    const std::string &name() const;
    IOKind kind() const;

    // Type queries; the type(s) may be left unspecified at creation.
    bool types_defined() const;
    const std::vector<Type> &types() const;
    Type type() const;

    // Dimension queries; the dimensions may be left unspecified at creation.
    bool dims_defined() const;
    int dims() const;

    // The values held by this input/output.
    const std::vector<Func> &funcs() const;
    const std::vector<Expr> &exprs() const;

    virtual ~GIOBase();

protected:
    // NOTE(review): array_size is a size_t here but is stored in the int
    // member array_size_, and at least one derived class passes -1 for
    // "unspecified" -- presumably relying on the round-trip conversion.
    GIOBase(size_t array_size,
            const std::string &name,
            IOKind kind,
            const std::vector<Type> &types,
            int dims);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    // Always 1 if is_array() == false;
    // -1 if is_array() == true but the size is unspecified.
    mutable int array_size_;  // always 1 if is_array() == false.
        // -1 if is_array() == true but unspecified.

    const std::string name_;
    const IOKind kind_;
    mutable std::vector<Type> types_;  // empty if type is unspecified
    mutable int dims_;                 // -1 if dim is unspecified

    // Exactly one of these will have nonzero length
    std::vector<Func> funcs_;
    std::vector<Expr> exprs_;

    // Generator which owns this Input or Output. Note that this will be null
    // initially; the GeneratorBase itself will set this field when it initially
    // builds its info about params. However, since it isn't
    // appropriate for Input<> or Output<> to be declared outside of a Generator,
    // all reasonable non-testing code should expect this to be non-null.
    GeneratorBase *generator{nullptr};

    // Name of the i-th element of an array-valued input/output.
    std::string array_name(size_t i) const;

    virtual void verify_internals();

    // Checks of a supplied array size / types / dimensions against this
    // object's own (the exact failure behavior is defined out of line).
    void check_matching_array_size(size_t size) const;
    void check_matching_types(const std::vector<Type> &t) const;
    void check_matching_dims(int d) const;

    // Typed access to the stored values; specialized for Expr and Func
    // outside the class.
    template<typename ElemType>
    const std::vector<ElemType> &get_values() const;

    void check_gio_access() const;

    virtual void check_value_writable() const = 0;

    // Implemented by derived classes to say whether this is an input or an
    // output (GeneratorInputBase returns "Input").
    virtual const char *input_or_output() const = 0;

private:
    template<typename T>
    friend class GeneratorParam_Synthetic;

    // No copy
    GIOBase(const GIOBase &) = delete;
    void operator=(const GIOBase &) = delete;
    // No move
    GIOBase(GIOBase &&) = delete;
    void operator=(GIOBase &&) = delete;
};
1508 
1509 template<>
1510 inline const std::vector<Expr> &GIOBase::get_values<Expr>() const {
1511     return exprs();
1512 }
1513 
1514 template<>
1515 inline const std::vector<Func> &GIOBase::get_values<Func>() const {
1516     return funcs();
1517 }
1518 
// Common base class of all Generator inputs: a GIOBase that additionally
// holds a list of Parameters.
class GeneratorInputBase : public GIOBase {
protected:
    // Constructor taking an explicit array size.
    GeneratorInputBase(size_t array_size,
                       const std::string &name,
                       IOKind kind,
                       const std::vector<Type> &t,
                       int d);

    // Constructor without an array size (used by GeneratorInputImpl for
    // non-array inputs).
    GeneratorInputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    std::vector<Parameter> parameters_;

    Parameter parameter() const;

    void init_internals();
    // Supply this input's values from a list of StubInputs.
    void set_inputs(const std::vector<StubInput> &inputs);

    virtual void set_def_min_max();
    virtual Expr get_def_expr() const;

    void verify_internals() override;

    friend class StubEmitter;

    // Name of the C/C++ type of this input; provided by derived classes.
    virtual std::string get_c_type() const = 0;

    void check_value_writable() const override;

    const char *input_or_output() const override {
        return "Input";
    }

    void set_estimate_impl(const Var &var, const Expr &min, const Expr &extent);
    void set_estimates_impl(const Region &estimates);

public:
    ~GeneratorInputBase() override;
};
1560 
1561 template<typename T, typename ValueType>
1562 class GeneratorInputImpl : public GeneratorInputBase {
1563 protected:
1564     using TBase = typename std::remove_all_extents<T>::type;
1565 
1566     bool is_array() const override {
1567         return std::is_array<T>::value;
1568     }
1569 
1570     template<typename T2 = T, typename std::enable_if<
1571                                   // Only allow T2 not-an-array
1572                                   !std::is_array<T2>::value>::type * = nullptr>
1573     GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1574         : GeneratorInputBase(name, kind, t, d) {
1575     }
1576 
1577     template<typename T2 = T, typename std::enable_if<
1578                                   // Only allow T2[kSomeConst]
1579                                   std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
1580     GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1581         : GeneratorInputBase(std::extent<T2, 0>::value, name, kind, t, d) {
1582     }
1583 
1584     template<typename T2 = T, typename std::enable_if<
1585                                   // Only allow T2[]
1586                                   std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
1587     GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1588         : GeneratorInputBase(-1, name, kind, t, d) {
1589     }
1590 
1591 public:
1592     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1593     size_t size() const {
1594         this->check_gio_access();
1595         return get_values<ValueType>().size();
1596     }
1597 
1598     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1599     const ValueType &operator[](size_t i) const {
1600         this->check_gio_access();
1601         return get_values<ValueType>()[i];
1602     }
1603 
1604     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1605     const ValueType &at(size_t i) const {
1606         this->check_gio_access();
1607         return get_values<ValueType>().at(i);
1608     }
1609 
1610     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1611     typename std::vector<ValueType>::const_iterator begin() const {
1612         this->check_gio_access();
1613         return get_values<ValueType>().begin();
1614     }
1615 
1616     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1617     typename std::vector<ValueType>::const_iterator end() const {
1618         this->check_gio_access();
1619         return get_values<ValueType>().end();
1620     }
1621 };
1622 
1623 // When forwarding methods to ImageParam, Func, etc., we must take
1624 // care with the return types: many of the methods return a reference-to-self
1625 // (e.g., ImageParam&); since we create temporaries for most of these forwards,
1626 // returning a ref will crater because it refers to a now-defunct section of the
// stack. Happily, simply removing the reference solves this, since all of the
1628 // types in question satisfy the property of copies referring to the same underlying
1629 // structure (returning references is just an optimization). Since this is verbose
1630 // and used in several places, we'll use a helper macro:
// HALIDE_FORWARD_METHOD(Class, Method) declares a non-const variadic member
// template named Method. It perfect-forwards its arguments to
// this->as<Class>().Method(...) and returns the result by value: the return
// type is whatever Class::Method yields for those arguments, with any
// reference stripped via std::remove_reference. The enclosing class must
// provide a member template as<>().
#define HALIDE_FORWARD_METHOD(Class, Method)                                                                                                         \
    template<typename... Args>                                                                                                                       \
    inline auto Method(Args &&... args)->typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        return this->template as<Class>().Method(std::forward<Args>(args)...);                                                                       \
    }
1636 
// HALIDE_FORWARD_METHOD_CONST(Class, Method) declares a const-qualified
// variadic member template named Method. It first calls
// this->check_gio_access(), then perfect-forwards its arguments to
// this->as<Class>().Method(...) and returns the result by value: the return
// type is whatever Class::Method yields for those arguments, with any
// reference stripped via std::remove_reference. The enclosing class must
// provide check_gio_access() and a member template as<>().
#define HALIDE_FORWARD_METHOD_CONST(Class, Method)                                                                  \
    template<typename... Args>                                                                                      \
    inline auto Method(Args &&... args) const->                                                                     \
        typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        this->check_gio_access();                                                                                   \
        return this->template as<Class>().Method(std::forward<Args>(args)...);                                      \
    }
1644 
1645 template<typename T>
1646 class GeneratorInput_Buffer : public GeneratorInputImpl<T, Func> {
1647 private:
1648     using Super = GeneratorInputImpl<T, Func>;
1649 
1650 protected:
1651     using TBase = typename Super::TBase;
1652 
1653     friend class ::Halide::Func;
1654     friend class ::Halide::Stage;
1655 
1656     std::string get_c_type() const override {
1657         if (TBase::has_static_halide_type) {
1658             return "Halide::Internal::StubInputBuffer<" +
1659                    halide_type_to_c_type(TBase::static_halide_type()) +
1660                    ">";
1661         } else {
1662             return "Halide::Internal::StubInputBuffer<>";
1663         }
1664     }
1665 
1666     template<typename T2>
1667     inline T2 as() const {
1668         return (T2) * this;
1669     }
1670 
1671 public:
1672     GeneratorInput_Buffer(const std::string &name)
1673         : Super(name, IOKind::Buffer,
1674                 TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{},
1675                 -1) {
1676     }
1677 
1678     GeneratorInput_Buffer(const std::string &name, const Type &t, int d = -1)
1679         : Super(name, IOKind::Buffer, {t}, d) {
1680         static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Input<Buffer<T>> if T is void or omitted.");
1681     }
1682 
1683     GeneratorInput_Buffer(const std::string &name, int d)
1684         : Super(name, IOKind::Buffer, TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{}, d) {
1685     }
1686 
1687     template<typename... Args>
1688     Expr operator()(Args &&... args) const {
1689         this->check_gio_access();
1690         return Func(*this)(std::forward<Args>(args)...);
1691     }
1692 
1693     Expr operator()(std::vector<Expr> args) const {
1694         this->check_gio_access();
1695         return Func(*this)(std::move(args));
1696     }
1697 
1698     template<typename T2>
1699     operator StubInputBuffer<T2>() const {
1700         user_assert(!this->is_array()) << "Cannot assign an array type to a non-array type for Input " << this->name();
1701         return StubInputBuffer<T2>(this->parameters_.at(0));
1702     }
1703 
1704     operator Func() const {
1705         this->check_gio_access();
1706         return this->funcs().at(0);
1707     }
1708 
1709     operator ExternFuncArgument() const {
1710         this->check_gio_access();
1711         return ExternFuncArgument(this->parameters_.at(0));
1712     }
1713 
1714     GeneratorInput_Buffer<T> &set_estimate(Var var, Expr min, Expr extent) {
1715         this->check_gio_access();
1716         this->set_estimate_impl(var, min, extent);
1717         return *this;
1718     }
1719 
1720     HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
1721     GeneratorInput_Buffer<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
1722         return set_estimate(var, min, extent);
1723     }
1724 
1725     GeneratorInput_Buffer<T> &set_estimates(const Region &estimates) {
1726         this->check_gio_access();
1727         this->set_estimates_impl(estimates);
1728         return *this;
1729     }
1730 
1731     Func in() {
1732         this->check_gio_access();
1733         return Func(*this).in();
1734     }
1735 
1736     Func in(const Func &other) {
1737         this->check_gio_access();
1738         return Func(*this).in(other);
1739     }
1740 
1741     Func in(const std::vector<Func> &others) {
1742         this->check_gio_access();
1743         return Func(*this).in(others);
1744     }
1745 
1746     operator ImageParam() const {
1747         this->check_gio_access();
1748         user_assert(!this->is_array()) << "Cannot convert an Input<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
1749         return ImageParam(this->parameters_.at(0), Func(*this));
1750     }
1751 
1752     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1753     size_t size() const {
1754         this->check_gio_access();
1755         return this->parameters_.size();
1756     }
1757 
1758     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1759     ImageParam operator[](size_t i) const {
1760         this->check_gio_access();
1761         return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1762     }
1763 
1764     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1765     ImageParam at(size_t i) const {
1766         this->check_gio_access();
1767         return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1768     }
1769 
1770     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1771     typename std::vector<ImageParam>::const_iterator begin() const {
1772         user_error << "Input<Buffer<>>::begin() is not supported.";
1773         return {};
1774     }
1775 
1776     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
1777     typename std::vector<ImageParam>::const_iterator end() const {
1778         user_error << "Input<Buffer<>>::end() is not supported.";
1779         return {};
1780     }
1781 
1782     /** Forward methods to the ImageParam. */
1783     // @{
1784     HALIDE_FORWARD_METHOD(ImageParam, dim)
1785     HALIDE_FORWARD_METHOD_CONST(ImageParam, dim)
1786     HALIDE_FORWARD_METHOD_CONST(ImageParam, host_alignment)
1787     HALIDE_FORWARD_METHOD(ImageParam, set_host_alignment)
1788     HALIDE_FORWARD_METHOD_CONST(ImageParam, dimensions)
1789     HALIDE_FORWARD_METHOD_CONST(ImageParam, left)
1790     HALIDE_FORWARD_METHOD_CONST(ImageParam, right)
1791     HALIDE_FORWARD_METHOD_CONST(ImageParam, top)
1792     HALIDE_FORWARD_METHOD_CONST(ImageParam, bottom)
1793     HALIDE_FORWARD_METHOD_CONST(ImageParam, width)
1794     HALIDE_FORWARD_METHOD_CONST(ImageParam, height)
1795     HALIDE_FORWARD_METHOD_CONST(ImageParam, channels)
1796     HALIDE_FORWARD_METHOD_CONST(ImageParam, trace_loads)
1797     HALIDE_FORWARD_METHOD_CONST(ImageParam, add_trace_tag)
1798     // }@
1799 };
1800 
1801 template<typename T>
1802 class GeneratorInput_Func : public GeneratorInputImpl<T, Func> {
1803 private:
1804     using Super = GeneratorInputImpl<T, Func>;
1805 
1806 protected:
1807     using TBase = typename Super::TBase;
1808 
1809     std::string get_c_type() const override {
1810         return "Func";
1811     }
1812 
1813     template<typename T2>
1814     inline T2 as() const {
1815         return (T2) * this;
1816     }
1817 
1818 public:
1819     GeneratorInput_Func(const std::string &name, const Type &t, int d)
1820         : Super(name, IOKind::Function, {t}, d) {
1821     }
1822 
1823     // unspecified type
1824     GeneratorInput_Func(const std::string &name, int d)
1825         : Super(name, IOKind::Function, {}, d) {
1826     }
1827 
1828     // unspecified dimension
1829     GeneratorInput_Func(const std::string &name, const Type &t)
1830         : Super(name, IOKind::Function, {t}, -1) {
1831     }
1832 
1833     // unspecified type & dimension
1834     GeneratorInput_Func(const std::string &name)
1835         : Super(name, IOKind::Function, {}, -1) {
1836     }
1837 
1838     GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t, int d)
1839         : Super(array_size, name, IOKind::Function, {t}, d) {
1840     }
1841 
1842     // unspecified type
1843     GeneratorInput_Func(size_t array_size, const std::string &name, int d)
1844         : Super(array_size, name, IOKind::Function, {}, d) {
1845     }
1846 
1847     // unspecified dimension
1848     GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t)
1849         : Super(array_size, name, IOKind::Function, {t}, -1) {
1850     }
1851 
1852     // unspecified type & dimension
1853     GeneratorInput_Func(size_t array_size, const std::string &name)
1854         : Super(array_size, name, IOKind::Function, {}, -1) {
1855     }
1856 
1857     template<typename... Args>
1858     Expr operator()(Args &&... args) const {
1859         this->check_gio_access();
1860         return this->funcs().at(0)(std::forward<Args>(args)...);
1861     }
1862 
1863     Expr operator()(const std::vector<Expr> &args) const {
1864         this->check_gio_access();
1865         return this->funcs().at(0)(args);
1866     }
1867 
1868     operator Func() const {
1869         this->check_gio_access();
1870         return this->funcs().at(0);
1871     }
1872 
1873     operator ExternFuncArgument() const {
1874         this->check_gio_access();
1875         return ExternFuncArgument(this->parameters_.at(0));
1876     }
1877 
1878     GeneratorInput_Func<T> &set_estimate(Var var, Expr min, Expr extent) {
1879         this->check_gio_access();
1880         this->set_estimate_impl(var, min, extent);
1881         return *this;
1882     }
1883 
1884     HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
1885     GeneratorInput_Func<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
1886         return set_estimate(var, min, extent);
1887     }
1888 
1889     GeneratorInput_Func<T> &set_estimates(const Region &estimates) {
1890         this->check_gio_access();
1891         this->set_estimates_impl(estimates);
1892         return *this;
1893     }
1894 
1895     Func in() {
1896         this->check_gio_access();
1897         return Func(*this).in();
1898     }
1899 
1900     Func in(const Func &other) {
1901         this->check_gio_access();
1902         return Func(*this).in(other);
1903     }
1904 
1905     Func in(const std::vector<Func> &others) {
1906         this->check_gio_access();
1907         return Func(*this).in(others);
1908     }
1909 
1910     /** Forward const methods to the underlying Func. (Non-const methods
1911      * aren't available for Input<Func>.) */
1912     // @{
1913     HALIDE_FORWARD_METHOD_CONST(Func, args)
1914     HALIDE_FORWARD_METHOD_CONST(Func, defined)
1915     HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
1916     HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
1917     HALIDE_FORWARD_METHOD_CONST(Func, output_types)
1918     HALIDE_FORWARD_METHOD_CONST(Func, outputs)
1919     HALIDE_FORWARD_METHOD_CONST(Func, rvars)
1920     HALIDE_FORWARD_METHOD_CONST(Func, update_args)
1921     HALIDE_FORWARD_METHOD_CONST(Func, update_value)
1922     HALIDE_FORWARD_METHOD_CONST(Func, update_values)
1923     HALIDE_FORWARD_METHOD_CONST(Func, value)
1924     HALIDE_FORWARD_METHOD_CONST(Func, values)
1925     // }@
1926 };
1927 
1928 template<typename T>
1929 class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {
1930 private:
1931     using Super = GeneratorInputImpl<T, Expr>;
1932 
1933 protected:
1934     using TBase = typename Super::TBase;
1935 
1936     const TBase def_{TBase()};
1937     const Expr def_expr_;
1938 
1939 protected:
1940     Expr get_def_expr() const override {
1941         return def_expr_;
1942     }
1943 
1944     void set_def_min_max() override {
1945         for (Parameter &p : this->parameters_) {
1946             p.set_scalar<TBase>(def_);
1947         }
1948     }
1949 
1950     std::string get_c_type() const override {
1951         return "Expr";
1952     }
1953 
1954     // Expr() doesn't accept a pointer type in its ctor; add a SFINAE adapter
1955     // so that pointer (aka handle) Inputs will get cast to uint64.
1956     template<typename TBase2 = TBase, typename std::enable_if<!std::is_pointer<TBase2>::value>::type * = nullptr>
1957     static Expr TBaseToExpr(const TBase2 &value) {
1958         return Expr(value);
1959     }
1960 
1961     template<typename TBase2 = TBase, typename std::enable_if<std::is_pointer<TBase2>::value>::type * = nullptr>
1962     static Expr TBaseToExpr(const TBase2 &value) {
1963         return Expr((uint64_t)value);
1964     }
1965 
1966 public:
1967     explicit GeneratorInput_Scalar(const std::string &name)
1968         : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
1969     }
1970 
1971     GeneratorInput_Scalar(const std::string &name, const TBase &def)
1972         : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
1973     }
1974 
1975     GeneratorInput_Scalar(size_t array_size,
1976                           const std::string &name)
1977         : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
1978     }
1979 
1980     GeneratorInput_Scalar(size_t array_size,
1981                           const std::string &name,
1982                           const TBase &def)
1983         : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
1984     }
1985 
1986     /** You can use this Input as an expression in a halide
1987      * function definition */
1988     operator Expr() const {
1989         this->check_gio_access();
1990         return this->exprs().at(0);
1991     }
1992 
1993     /** Using an Input as the argument to an external stage treats it
1994      * as an Expr */
1995     operator ExternFuncArgument() const {
1996         this->check_gio_access();
1997         return ExternFuncArgument(this->exprs().at(0));
1998     }
1999 
2000     template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>::type * = nullptr>
2001     void set_estimate(const TBase &value) {
2002         this->check_gio_access();
2003         user_assert(value == nullptr) << "nullptr is the only valid estimate for Input<PointerType>";
2004         Expr e = reinterpret(type_of<T2>(), cast<uint64_t>(0));
2005         for (Parameter &p : this->parameters_) {
2006             p.set_estimate(e);
2007         }
2008     }
2009 
2010     template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value && !std::is_pointer<T2>::value>::type * = nullptr>
2011     void set_estimate(const TBase &value) {
2012         this->check_gio_access();
2013         Expr e = Expr(value);
2014         if (std::is_same<T2, bool>::value) {
2015             e = cast<bool>(e);
2016         }
2017         for (Parameter &p : this->parameters_) {
2018             p.set_estimate(e);
2019         }
2020     }
2021 
2022     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2023     void set_estimate(size_t index, const TBase &value) {
2024         this->check_gio_access();
2025         Expr e = Expr(value);
2026         if (std::is_same<T2, bool>::value) {
2027             e = cast<bool>(e);
2028         }
2029         this->parameters_.at(index).set_estimate(e);
2030     }
2031 };
2032 
2033 template<typename T>
2034 class GeneratorInput_Arithmetic : public GeneratorInput_Scalar<T> {
2035 private:
2036     using Super = GeneratorInput_Scalar<T>;
2037 
2038 protected:
2039     using TBase = typename Super::TBase;
2040 
2041     const Expr min_, max_;
2042 
2043 protected:
2044     void set_def_min_max() override {
2045         Super::set_def_min_max();
2046         // Don't set min/max for bool
2047         if (!std::is_same<TBase, bool>::value) {
2048             for (Parameter &p : this->parameters_) {
2049                 if (min_.defined()) p.set_min_value(min_);
2050                 if (max_.defined()) p.set_max_value(max_);
2051             }
2052         }
2053     }
2054 
2055 public:
2056     explicit GeneratorInput_Arithmetic(const std::string &name)
2057         : Super(name), min_(Expr()), max_(Expr()) {
2058     }
2059 
2060     GeneratorInput_Arithmetic(const std::string &name,
2061                               const TBase &def)
2062         : Super(name, def), min_(Expr()), max_(Expr()) {
2063     }
2064 
2065     GeneratorInput_Arithmetic(size_t array_size,
2066                               const std::string &name)
2067         : Super(array_size, name), min_(Expr()), max_(Expr()) {
2068     }
2069 
2070     GeneratorInput_Arithmetic(size_t array_size,
2071                               const std::string &name,
2072                               const TBase &def)
2073         : Super(array_size, name, def), min_(Expr()), max_(Expr()) {
2074     }
2075 
2076     GeneratorInput_Arithmetic(const std::string &name,
2077                               const TBase &def,
2078                               const TBase &min,
2079                               const TBase &max)
2080         : Super(name, def), min_(min), max_(max) {
2081     }
2082 
2083     GeneratorInput_Arithmetic(size_t array_size,
2084                               const std::string &name,
2085                               const TBase &def,
2086                               const TBase &min,
2087                               const TBase &max)
2088         : Super(array_size, name, def), min_(min), max_(max) {
2089     }
2090 };
2091 
/** Maps any type to void; used below to turn "this expression is
 * well-formed" into a partial-specialization match. */
template<typename>
struct type_sink {
    using type = void;
};

/** Detects whether T2::static_halide_type() is a valid expression.
 * The primary template answers false. */
template<typename T2, typename = void>
struct has_static_halide_type_method : std::false_type {
};

/** Chosen only when decltype(T2::static_halide_type()) is well-formed;
 * answers true. */
template<typename T2>
struct has_static_halide_type_method<T2, typename type_sink<decltype(T2::static_halide_type())>::type>
    : std::true_type {
};
2100 
// Selects the implementation class for an Input<T> from TBase (T with its
// array extents removed). The conditions are listed in this order:
//   - TBase has a static static_halide_type() method -> GeneratorInput_Buffer<T>
//   - TBase is Func                                  -> GeneratorInput_Func<T>
//   - TBase is an arithmetic type                    -> GeneratorInput_Arithmetic<T>
//   - TBase is any other scalar type                 -> GeneratorInput_Scalar<T>
// NOTE(review): this assumes select_type yields the type of the first cond
// whose condition is true -- confirm against select_type's definition.
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorInputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorInput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorInput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorInput_Arithmetic<T>>,
        cond<std::is_scalar<TBase>::value, GeneratorInput_Scalar<T>>>::type;
2108 
2109 }  // namespace Internal
2110 
2111 template<typename T>
2112 class GeneratorInput : public Internal::GeneratorInputImplBase<T> {
2113 private:
2114     using Super = Internal::GeneratorInputImplBase<T>;
2115 
2116 protected:
2117     using TBase = typename Super::TBase;
2118 
2119     // Trick to avoid ambiguous ctor between Func-with-dim and int-with-default-value;
2120     // since we can't use std::enable_if on ctors, define the argument to be one that
2121     // can only be properly resolved for TBase=Func.
2122     struct Unused;
2123     using IntIfNonScalar =
2124         typename Internal::select_type<
2125             Internal::cond<Internal::has_static_halide_type_method<TBase>::value, int>,
2126             Internal::cond<std::is_same<TBase, Func>::value, int>,
2127             Internal::cond<true, Unused>>::type;
2128 
2129 public:
2130     explicit GeneratorInput(const std::string &name)
2131         : Super(name) {
2132     }
2133 
2134     GeneratorInput(const std::string &name, const TBase &def)
2135         : Super(name, def) {
2136     }
2137 
2138     GeneratorInput(size_t array_size, const std::string &name, const TBase &def)
2139         : Super(array_size, name, def) {
2140     }
2141 
2142     GeneratorInput(const std::string &name,
2143                    const TBase &def, const TBase &min, const TBase &max)
2144         : Super(name, def, min, max) {
2145     }
2146 
2147     GeneratorInput(size_t array_size, const std::string &name,
2148                    const TBase &def, const TBase &min, const TBase &max)
2149         : Super(array_size, name, def, min, max) {
2150     }
2151 
2152     GeneratorInput(const std::string &name, const Type &t, int d)
2153         : Super(name, t, d) {
2154     }
2155 
2156     GeneratorInput(const std::string &name, const Type &t)
2157         : Super(name, t) {
2158     }
2159 
2160     // Avoid ambiguity between Func-with-dim and int-with-default
2161     GeneratorInput(const std::string &name, IntIfNonScalar d)
2162         : Super(name, d) {
2163     }
2164 
2165     GeneratorInput(size_t array_size, const std::string &name, const Type &t, int d)
2166         : Super(array_size, name, t, d) {
2167     }
2168 
2169     GeneratorInput(size_t array_size, const std::string &name, const Type &t)
2170         : Super(array_size, name, t) {
2171     }
2172 
2173     // Avoid ambiguity between Func-with-dim and int-with-default
2174     //template <typename T2 = T, typename std::enable_if<std::is_same<TBase, Func>::value>::type * = nullptr>
2175     GeneratorInput(size_t array_size, const std::string &name, IntIfNonScalar d)
2176         : Super(array_size, name, d) {
2177     }
2178 
2179     GeneratorInput(size_t array_size, const std::string &name)
2180         : Super(array_size, name) {
2181     }
2182 };
2183 
2184 namespace Internal {
2185 
/** Non-template base class for Generator outputs. It forwards a large set
 * of Func methods (mostly scheduling calls) to the single underlying Func
 * and declares the constructors and hooks used by the typed Output classes. */
class GeneratorOutputBase : public GIOBase {
protected:
    /** Returns the single underlying Func; only T2 = Func is accepted.
     * Asserts internally that this is not a Scalar output and that exprs_
     * is empty, and reports a user error unless exactly one Func is held. */
    template<typename T2, typename std::enable_if<std::is_same<T2, Func>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE T2 as() const {
        static_assert(std::is_same<T2, Func>::value, "Only Func allowed here");
        internal_assert(kind() != IOKind::Scalar);
        internal_assert(exprs_.empty());
        user_assert(funcs_.size() == 1) << "Use [] to access individual Funcs in Output<Func[]>";
        return funcs_[0];
    }

public:
    /** Deprecated spelling of set_estimate(); forwards to the underlying
     * Func's set_estimate() and returns *this. */
    HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
    GeneratorOutputBase &estimate(const Var &var, const Expr &min, const Expr &extent) {
        this->as<Func>().set_estimate(var, min, extent);
        return *this;
    }

    /** Forward schedule-related methods to the underlying Func. */
    // @{
    HALIDE_FORWARD_METHOD(Func, add_trace_tag)
    HALIDE_FORWARD_METHOD(Func, align_bounds)
    HALIDE_FORWARD_METHOD(Func, align_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, args)
    HALIDE_FORWARD_METHOD(Func, bound)
    HALIDE_FORWARD_METHOD(Func, bound_extent)
    HALIDE_FORWARD_METHOD(Func, compute_at)
    HALIDE_FORWARD_METHOD(Func, compute_inline)
    HALIDE_FORWARD_METHOD(Func, compute_root)
    HALIDE_FORWARD_METHOD(Func, compute_with)
    HALIDE_FORWARD_METHOD(Func, copy_to_device)
    HALIDE_FORWARD_METHOD(Func, copy_to_host)
    HALIDE_FORWARD_METHOD(Func, define_extern)
    HALIDE_FORWARD_METHOD_CONST(Func, defined)
    HALIDE_FORWARD_METHOD(Func, fold_storage)
    HALIDE_FORWARD_METHOD(Func, fuse)
    HALIDE_FORWARD_METHOD(Func, glsl)
    HALIDE_FORWARD_METHOD(Func, gpu)
    HALIDE_FORWARD_METHOD(Func, gpu_blocks)
    HALIDE_FORWARD_METHOD(Func, gpu_single_thread)
    HALIDE_FORWARD_METHOD(Func, gpu_threads)
    HALIDE_FORWARD_METHOD(Func, gpu_tile)
    HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
    HALIDE_FORWARD_METHOD(Func, hexagon)
    HALIDE_FORWARD_METHOD(Func, in)
    HALIDE_FORWARD_METHOD(Func, memoize)
    HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
    HALIDE_FORWARD_METHOD_CONST(Func, output_types)
    HALIDE_FORWARD_METHOD_CONST(Func, outputs)
    HALIDE_FORWARD_METHOD(Func, parallel)
    HALIDE_FORWARD_METHOD(Func, prefetch)
    HALIDE_FORWARD_METHOD(Func, print_loop_nest)
    HALIDE_FORWARD_METHOD(Func, rename)
    HALIDE_FORWARD_METHOD(Func, reorder)
    HALIDE_FORWARD_METHOD(Func, reorder_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, rvars)
    HALIDE_FORWARD_METHOD(Func, serial)
    HALIDE_FORWARD_METHOD(Func, set_estimate)
    HALIDE_FORWARD_METHOD(Func, shader)
    HALIDE_FORWARD_METHOD(Func, specialize)
    HALIDE_FORWARD_METHOD(Func, specialize_fail)
    HALIDE_FORWARD_METHOD(Func, split)
    HALIDE_FORWARD_METHOD(Func, store_at)
    HALIDE_FORWARD_METHOD(Func, store_root)
    HALIDE_FORWARD_METHOD(Func, tile)
    HALIDE_FORWARD_METHOD(Func, trace_stores)
    HALIDE_FORWARD_METHOD(Func, unroll)
    HALIDE_FORWARD_METHOD(Func, update)
    HALIDE_FORWARD_METHOD_CONST(Func, update_args)
    HALIDE_FORWARD_METHOD_CONST(Func, update_value)
    HALIDE_FORWARD_METHOD_CONST(Func, update_values)
    HALIDE_FORWARD_METHOD_CONST(Func, value)
    HALIDE_FORWARD_METHOD_CONST(Func, values)
    HALIDE_FORWARD_METHOD(Func, vectorize)
    // }@

// NOTE(review): the forwarding macros used above are named
// HALIDE_FORWARD_METHOD / HALIDE_FORWARD_METHOD_CONST; the two names
// undefined below do not match them, so these #undefs look stale and the
// forwarding macros presumably stay defined past this point -- confirm
// whether later code in this file still relies on them before renaming.
#undef HALIDE_OUTPUT_FORWARD
#undef HALIDE_OUTPUT_FORWARD_CONST

protected:
    /** Construct with an explicit array size. */
    GeneratorOutputBase(size_t array_size,
                        const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    /** Construct without an array size. */
    GeneratorOutputBase(const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    friend class GeneratorBase;
    friend class StubEmitter;

    // Defined out of line; not visible in this header.
    void init_internals();
    void resize(size_t size);

    /** C++ type name for this output; derived classes may override. */
    virtual std::string get_c_type() const {
        return "Func";
    }

    void check_value_writable() const override;

    /** Label distinguishing outputs from inputs. */
    const char *input_or_output() const override {
        return "Output";
    }

public:
    ~GeneratorOutputBase() override;
};
2296 
2297 template<typename T>
2298 class GeneratorOutputImpl : public GeneratorOutputBase {
2299 protected:
2300     using TBase = typename std::remove_all_extents<T>::type;
2301     using ValueType = Func;
2302 
2303     bool is_array() const override {
2304         return std::is_array<T>::value;
2305     }
2306 
2307     template<typename T2 = T, typename std::enable_if<
2308                                   // Only allow T2 not-an-array
2309                                   !std::is_array<T2>::value>::type * = nullptr>
2310     GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
2311         : GeneratorOutputBase(name, kind, t, d) {
2312     }
2313 
2314     template<typename T2 = T, typename std::enable_if<
2315                                   // Only allow T2[kSomeConst]
2316                                   std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
2317     GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
2318         : GeneratorOutputBase(std::extent<T2, 0>::value, name, kind, t, d) {
2319     }
2320 
2321     template<typename T2 = T, typename std::enable_if<
2322                                   // Only allow T2[]
2323                                   std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
2324     GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
2325         : GeneratorOutputBase(-1, name, kind, t, d) {
2326     }
2327 
2328 public:
2329     template<typename... Args, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2330     FuncRef operator()(Args &&... args) const {
2331         this->check_gio_access();
2332         return get_values<ValueType>().at(0)(std::forward<Args>(args)...);
2333     }
2334 
2335     template<typename ExprOrVar, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2336     FuncRef operator()(std::vector<ExprOrVar> args) const {
2337         this->check_gio_access();
2338         return get_values<ValueType>().at(0)(args);
2339     }
2340 
2341     template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2342     operator Func() const {
2343         this->check_gio_access();
2344         return get_values<ValueType>().at(0);
2345     }
2346 
2347     template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2348     operator Stage() const {
2349         this->check_gio_access();
2350         return get_values<ValueType>().at(0);
2351     }
2352 
2353     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2354     size_t size() const {
2355         this->check_gio_access();
2356         return get_values<ValueType>().size();
2357     }
2358 
2359     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2360     const ValueType &operator[](size_t i) const {
2361         this->check_gio_access();
2362         return get_values<ValueType>()[i];
2363     }
2364 
2365     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2366     const ValueType &at(size_t i) const {
2367         this->check_gio_access();
2368         return get_values<ValueType>().at(i);
2369     }
2370 
2371     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2372     typename std::vector<ValueType>::const_iterator begin() const {
2373         this->check_gio_access();
2374         return get_values<ValueType>().begin();
2375     }
2376 
2377     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2378     typename std::vector<ValueType>::const_iterator end() const {
2379         this->check_gio_access();
2380         return get_values<ValueType>().end();
2381     }
2382 
2383     template<typename T2 = T, typename std::enable_if<
2384                                   // Only allow T2[]
2385                                   std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
2386     void resize(size_t size) {
2387         this->check_gio_access();
2388         GeneratorOutputBase::resize(size);
2389     }
2390 };
2391 
/** Implementation base for outputs whose element type provides a static Halide
 * type (selected via GeneratorOutputImplBase). Supports assignment from a Buffer,
 * a StubOutputBuffer, or a Func, and forwards a set of methods to the
 * OutputImageParam of the underlying Func. */
template<typename T>
class GeneratorOutput_Buffer : public GeneratorOutputImpl<T> {
private:
    using Super = GeneratorOutputImpl<T>;

    // Shared implementation for assignment from a Func (or from a stub's Func):
    // validates the Func against any declared types/dimensions of this Output,
    // then stores it in the single Func slot, which must not be defined yet.
    HALIDE_NO_USER_CODE_INLINE void assign_from_func(const Func &f) {
        this->check_value_writable();

        internal_assert(f.defined());

        // If this Output declared its types, the Func must match in tuple size
        // and in each element type.
        if (this->types_defined()) {
            const auto &my_types = this->types();
            user_assert(my_types.size() == f.output_types().size())
                << "Cannot assign Func \"" << f.name()
                << "\" to Output \"" << this->name() << "\"\n"
                << "Output " << this->name()
                << " is declared to have " << my_types.size() << " tuple elements"
                << " but Func " << f.name()
                << " has " << f.output_types().size() << " tuple elements.\n";
            for (size_t i = 0; i < my_types.size(); i++) {
                user_assert(my_types[i] == f.output_types().at(i))
                    << "Cannot assign Func \"" << f.name()
                    << "\" to Output \"" << this->name() << "\"\n"
                    << (my_types.size() > 1 ? "In tuple element " + std::to_string(i) + ", " : "")
                    << "Output " << this->name()
                    << " has declared type " << my_types[i]
                    << " but Func " << f.name()
                    << " has type " << f.output_types().at(i) << "\n";
            }
        }
        // If this Output declared its dimensionality, the Func must match it.
        if (this->dims_defined()) {
            user_assert(f.dimensions() == this->dims())
                << "Cannot assign Func \"" << f.name()
                << "\" to Output \"" << this->name() << "\"\n"
                << "Output " << this->name()
                << " has declared dimensionality " << this->dims()
                << " but Func " << f.name()
                << " has dimensionality " << f.dimensions() << "\n";
        }

        // Exactly one Func slot is expected, and it must still be undefined.
        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined());
        this->funcs_[0] = f;
    }

protected:
    using TBase = typename Super::TBase;

    // Compute the type list passed to the base class: when TBase has a static
    // Halide type, the caller must not also pass types, and that static type is
    // used; otherwise the caller-supplied list is returned unchanged.
    static std::vector<Type> my_types(const std::vector<Type> &t) {
        if (TBase::has_static_halide_type) {
            user_assert(t.empty()) << "Cannot pass a Type argument for an Output<Buffer> with a non-void static type\n";
            return std::vector<Type>{TBase::static_halide_type()};
        }
        return t;
    }

protected:
    // Construct a single (non-array) buffer output.
    GeneratorOutput_Buffer(const std::string &name, const std::vector<Type> &t = {}, int d = -1)
        : Super(name, IOKind::Buffer, my_types(t), d) {
    }

    // Construct an array-of-buffers output with the given array size.
    GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector<Type> &t = {}, int d = -1)
        : Super(array_size, name, IOKind::Buffer, my_types(t), d) {
    }

    // Return the C++ type spelling of the StubOutputBuffer for this output;
    // the element type is included only when TBase has a static Halide type.
    HALIDE_NO_USER_CODE_INLINE std::string get_c_type() const override {
        if (TBase::has_static_halide_type) {
            return "Halide::Internal::StubOutputBuffer<" +
                   halide_type_to_c_type(TBase::static_halide_type()) +
                   ">";
        } else {
            return "Halide::Internal::StubOutputBuffer<>";
        }
    }

    // Convert *this to T2 by cast (for any T2 other than Func).
    // NOTE(review): presumably this is the helper the HALIDE_FORWARD_METHOD*
    // macros below rely on; their definitions are not visible here - confirm.
    template<typename T2, typename std::enable_if<!std::is_same<T2, Func>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE T2 as() const {
        return (T2) * this;
    }

public:
    // Allow assignment from a Buffer<> to an Output<Buffer<>>;
    // this allows us to use a statically-compiled buffer inside a Generator
    // to assign to an output.
    // TODO: This used to take the buffer as a const ref. This no longer works as
    // using it in a Pipeline might change the dev field so it is currently
    // not considered const. We should consider how this really ought to work.
    template<typename T2>
    HALIDE_NO_USER_CODE_INLINE GeneratorOutput_Buffer<T> &operator=(Buffer<T2> &buffer) {
        this->check_gio_access();
        this->check_value_writable();

        user_assert(T::can_convert_from(buffer))
            << "Cannot assign to the Output \"" << this->name()
            << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";

        // Check the buffer against any declared type / dimensionality.
        if (this->types_defined()) {
            user_assert(Type(buffer.type()) == this->type())
                << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n";
        }
        if (this->dims_defined()) {
            user_assert(buffer.dimensions() == this->dims())
                << "Output " << this->name() << " should have dim=" << this->dims() << " but saw dim=" << buffer.dimensions() << "\n";
        }

        // Define the (single, still-undefined) Func as a pure copy of the buffer.
        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined());
        this->funcs_.at(0)(_) = buffer(_);

        return *this;
    }

    // Allow assignment from a StubOutputBuffer to an Output<Buffer>;
    // this allows us to pipeline the results of a Stub to the results
    // of the enclosing Generator.
    template<typename T2>
    GeneratorOutput_Buffer<T> &operator=(const StubOutputBuffer<T2> &stub_output_buffer) {
        this->check_gio_access();
        assign_from_func(stub_output_buffer.f);
        return *this;
    }

    // Allow assignment from a Func to an Output<Buffer>;
    // this allows us to use helper functions that return a plain Func
    // to simply set the output(s) without needing a wrapper Func.
    GeneratorOutput_Buffer<T> &operator=(const Func &f) {
        this->check_gio_access();
        assign_from_func(f);
        return *this;
    }

    // Convert to the OutputImageParam of the single underlying Func.
    // Not permitted for array outputs (a user error is raised).
    operator OutputImageParam() const {
        this->check_gio_access();
        user_assert(!this->is_array()) << "Cannot convert an Output<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        return this->funcs_.at(0).output_buffer();
    }

    // 'perfect forwarding' won't work with initializer lists,
    // so hand-roll our own forwarding method for set_estimates,
    // rather than using HALIDE_FORWARD_METHOD.
    GeneratorOutput_Buffer<T> &set_estimates(const Region &estimates) {
        this->as<OutputImageParam>().set_estimates(estimates);
        return *this;
    }

    /** Forward methods to the OutputImageParam. */
    // @{
    HALIDE_FORWARD_METHOD(OutputImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, host_alignment)
    HALIDE_FORWARD_METHOD(OutputImageParam, set_host_alignment)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dimensions)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, left)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, right)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, top)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, bottom)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, width)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, height)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, channels)
    // @}
};
2554 
2555 template<typename T>
2556 class GeneratorOutput_Func : public GeneratorOutputImpl<T> {
2557 private:
2558     using Super = GeneratorOutputImpl<T>;
2559 
2560     HALIDE_NO_USER_CODE_INLINE Func &get_assignable_func_ref(size_t i) {
2561         internal_assert(this->exprs_.empty() && this->funcs_.size() > i);
2562         return this->funcs_.at(i);
2563     }
2564 
2565 protected:
2566     using TBase = typename Super::TBase;
2567 
2568 protected:
2569     GeneratorOutput_Func(const std::string &name)
2570         : Super(name, IOKind::Function, std::vector<Type>{}, -1) {
2571     }
2572 
2573     GeneratorOutput_Func(const std::string &name, const std::vector<Type> &t, int d = -1)
2574         : Super(name, IOKind::Function, t, d) {
2575     }
2576 
2577     GeneratorOutput_Func(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
2578         : Super(array_size, name, IOKind::Function, t, d) {
2579     }
2580 
2581 public:
2582     // Allow Output<Func> = Func
2583     template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
2584     GeneratorOutput_Func<T> &operator=(const Func &f) {
2585         this->check_gio_access();
2586         this->check_value_writable();
2587 
2588         // Don't bother verifying the Func type, dimensions, etc., here:
2589         // That's done later, when we produce the pipeline.
2590         get_assignable_func_ref(0) = f;
2591         return *this;
2592     }
2593 
2594     // Allow Output<Func[]> = Func
2595     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2596     Func &operator[](size_t i) {
2597         this->check_gio_access();
2598         this->check_value_writable();
2599         return get_assignable_func_ref(i);
2600     }
2601 
2602     // Allow Func = Output<Func[]>
2603     template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
2604     const Func &operator[](size_t i) const {
2605         this->check_gio_access();
2606         return Super::operator[](i);
2607     }
2608 
2609     GeneratorOutput_Func<T> &set_estimate(const Var &var, const Expr &min, const Expr &extent) {
2610         this->check_gio_access();
2611         internal_assert(this->exprs_.empty() && !this->funcs_.empty());
2612         for (Func &f : this->funcs_) {
2613             f.set_estimate(var, min, extent);
2614         }
2615         return *this;
2616     }
2617 
2618     HALIDE_ATTRIBUTE_DEPRECATED("Use set_estimate() instead")
2619     GeneratorOutput_Func<T> &estimate(const Var &var, const Expr &min, const Expr &extent) {
2620         return set_estimate(var, min, extent);
2621     }
2622 
2623     GeneratorOutput_Func<T> &set_estimates(const Region &estimates) {
2624         this->check_gio_access();
2625         internal_assert(this->exprs_.empty() && !this->funcs_.empty());
2626         for (Func &f : this->funcs_) {
2627             f.set_estimates(estimates);
2628         }
2629         return *this;
2630     }
2631 };
2632 
2633 template<typename T>
2634 class GeneratorOutput_Arithmetic : public GeneratorOutputImpl<T> {
2635 private:
2636     using Super = GeneratorOutputImpl<T>;
2637 
2638 protected:
2639     using TBase = typename Super::TBase;
2640 
2641 protected:
2642     explicit GeneratorOutput_Arithmetic(const std::string &name)
2643         : Super(name, IOKind::Function, {type_of<TBase>()}, 0) {
2644     }
2645 
2646     GeneratorOutput_Arithmetic(size_t array_size, const std::string &name)
2647         : Super(array_size, name, IOKind::Function, {type_of<TBase>()}, 0) {
2648     }
2649 };
2650 
// Select the implementation base class for GeneratorOutput<T>, based on the
// element type TBase (T with all array extents removed):
//   - types providing a static-halide-type method -> GeneratorOutput_Buffer<T>
//   - Func                                        -> GeneratorOutput_Func<T>
//   - arithmetic types                            -> GeneratorOutput_Arithmetic<T>
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorOutputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorOutput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorOutput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorOutput_Arithmetic<T>>>::type;
2657 
2658 }  // namespace Internal
2659 
2660 template<typename T>
2661 class GeneratorOutput : public Internal::GeneratorOutputImplBase<T> {
2662 private:
2663     using Super = Internal::GeneratorOutputImplBase<T>;
2664 
2665 protected:
2666     using TBase = typename Super::TBase;
2667 
2668 public:
2669     explicit GeneratorOutput(const std::string &name)
2670         : Super(name) {
2671     }
2672 
2673     explicit GeneratorOutput(const char *name)
2674         : GeneratorOutput(std::string(name)) {
2675     }
2676 
2677     GeneratorOutput(size_t array_size, const std::string &name)
2678         : Super(array_size, name) {
2679     }
2680 
2681     GeneratorOutput(const std::string &name, int d)
2682         : Super(name, {}, d) {
2683     }
2684 
2685     GeneratorOutput(const std::string &name, const Type &t, int d)
2686         : Super(name, {t}, d) {
2687     }
2688 
2689     GeneratorOutput(const std::string &name, const std::vector<Type> &t, int d)
2690         : Super(name, t, d) {
2691     }
2692 
2693     GeneratorOutput(size_t array_size, const std::string &name, int d)
2694         : Super(array_size, name, {}, d) {
2695     }
2696 
2697     GeneratorOutput(size_t array_size, const std::string &name, const Type &t, int d)
2698         : Super(array_size, name, {t}, d) {
2699     }
2700 
2701     GeneratorOutput(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
2702         : Super(array_size, name, t, d) {
2703     }
2704 
2705     // TODO: This used to take the buffer as a const ref. This no longer works as
2706     // using it in a Pipeline might change the dev field so it is currently
2707     // not considered const. We should consider how this really ought to work.
2708     template<typename T2>
2709     GeneratorOutput<T> &operator=(Buffer<T2> &buffer) {
2710         Super::operator=(buffer);
2711         return *this;
2712     }
2713 
2714     template<typename T2>
2715     GeneratorOutput<T> &operator=(const Internal::StubOutputBuffer<T2> &stub_output_buffer) {
2716         Super::operator=(stub_output_buffer);
2717         return *this;
2718     }
2719 
2720     GeneratorOutput<T> &operator=(const Func &f) {
2721         Super::operator=(f);
2722         return *this;
2723     }
2724 };
2725 
2726 namespace Internal {
2727 
2728 template<typename T>
2729 T parse_scalar(const std::string &value) {
2730     std::istringstream iss(value);
2731     T t;
2732     iss >> t;
2733     user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << value;
2734     return t;
2735 }
2736 
// Parse a string into a list of Halide Types (declaration only; defined elsewhere).
// NOTE(review): the accepted string format is not visible in this header - see the definition.
std::vector<Type> parse_halide_type_list(const std::string &types);
2738 
// Identifies which property of an Input/Output a GeneratorParam_Synthetic sets.
enum class SyntheticParamType { Type,        // sets the type list (types_)
                                Dim,         // sets the dimensionality (dims_)
                                ArraySize }; // sets the array size (array_size_)
2742 
2743 // This is a type of GeneratorParam used internally to create 'synthetic' params
2744 // (e.g. image.type, image.dim); it is not possible for user code to instantiate it.
2745 template<typename T>
2746 class GeneratorParam_Synthetic : public GeneratorParamImpl<T> {
2747 public:
2748     void set_from_string(const std::string &new_value_string) override {
2749         // If error_msg is not empty, this is unsettable:
2750         // display error_msg as a user error.
2751         if (!error_msg.empty()) {
2752             user_error << error_msg;
2753         }
2754         set_from_string_impl<T>(new_value_string);
2755     }
2756 
2757     std::string get_default_value() const override {
2758         internal_error;
2759         return std::string();
2760     }
2761 
2762     std::string call_to_string(const std::string &v) const override {
2763         internal_error;
2764         return std::string();
2765     }
2766 
2767     std::string get_c_type() const override {
2768         internal_error;
2769         return std::string();
2770     }
2771 
2772     bool is_synthetic_param() const override {
2773         return true;
2774     }
2775 
2776 private:
2777     friend class GeneratorParamInfo;
2778 
2779     static std::unique_ptr<Internal::GeneratorParamBase> make(
2780         GeneratorBase *generator,
2781         const std::string &generator_name,
2782         const std::string &gpname,
2783         GIOBase &gio,
2784         SyntheticParamType which,
2785         bool defined) {
2786         std::string error_msg = defined ? "Cannot set the GeneratorParam " + gpname + " for " + generator_name + " because the value is explicitly specified in the C++ source." : "";
2787         return std::unique_ptr<GeneratorParam_Synthetic<T>>(
2788             new GeneratorParam_Synthetic<T>(gpname, gio, which, error_msg));
2789     }
2790 
2791     GeneratorParam_Synthetic(const std::string &name, GIOBase &gio, SyntheticParamType which, const std::string &error_msg = "")
2792         : GeneratorParamImpl<T>(name, T()), gio(gio), which(which), error_msg(error_msg) {
2793     }
2794 
2795     template<typename T2 = T, typename std::enable_if<std::is_same<T2, ::Halide::Type>::value>::type * = nullptr>
2796     void set_from_string_impl(const std::string &new_value_string) {
2797         internal_assert(which == SyntheticParamType::Type);
2798         gio.types_ = parse_halide_type_list(new_value_string);
2799     }
2800 
2801     template<typename T2 = T, typename std::enable_if<std::is_integral<T2>::value>::type * = nullptr>
2802     void set_from_string_impl(const std::string &new_value_string) {
2803         if (which == SyntheticParamType::Dim) {
2804             gio.dims_ = parse_scalar<T2>(new_value_string);
2805         } else if (which == SyntheticParamType::ArraySize) {
2806             gio.array_size_ = parse_scalar<T2>(new_value_string);
2807         } else {
2808             internal_error;
2809         }
2810     }
2811 
2812     GIOBase &gio;
2813     const SyntheticParamType which;
2814     const std::string error_msg;
2815 };
2816 
// Forward declaration; the class is defined elsewhere.
class GeneratorStub;
2818 
2819 }  // namespace Internal
2820 
2821 /** GeneratorContext is a base class that is used when using Generators (or Stubs) directly;
2822  * it is used to allow the outer context (typically, either a Generator or "top-level" code)
2823  * to specify certain information to the inner context to ensure that inner and outer
2824  * Generators are compiled in a compatible way.
2825  *
2826  * If you are using this at "top level" (e.g. with the JIT), you can construct a GeneratorContext
2827  * with a Target:
2828  * \code
2829  *   auto my_stub = MyStub(
2830  *       GeneratorContext(get_target_from_environment()),
2831  *       // inputs
2832  *       { ... },
2833  *       // generator params
2834  *       { ... }
2835  *   );
2836  * \endcode
2837  *
2838  * Note that all Generators inherit from GeneratorContext, so if you are using a Stub
2839  * from within a Generator, you can just pass 'this' for the GeneratorContext:
2840  * \code
2841  *  struct SomeGen : Generator<SomeGen> {
2842  *   void generate() {
2843  *     ...
2844  *     auto my_stub = MyStub(
2845  *       this,  // GeneratorContext
2846  *       // inputs
2847  *       { ... },
2848  *       // generator params
2849  *       { ... }
2850  *     );
2851  *     ...
2852  *   }
2853  *  };
2854  * \endcode
2855  */
2856 class GeneratorContext {
2857 public:
2858     using ExternsMap = std::map<std::string, ExternalCode>;
2859 
2860     explicit GeneratorContext(const Target &t,
2861                               bool auto_schedule = false,
2862                               const MachineParams &machine_params = MachineParams::generic());
2863     virtual ~GeneratorContext();
2864 
2865     inline Target get_target() const {
2866         return target;
2867     }
2868     inline bool get_auto_schedule() const {
2869         return auto_schedule;
2870     }
2871     inline MachineParams get_machine_params() const {
2872         return machine_params;
2873     }
2874 
2875     /** Generators can register ExternalCode objects onto
2876      * themselves. The Generator infrastructure will arrange to have
2877      * this ExternalCode appended to the Module that is finally
2878      * compiled using the Generator. This allows encapsulating
2879      * functionality that depends on external libraries or handwritten
2880      * code for various targets. The name argument should match the
2881      * name of the ExternalCode block and is used to ensure the same
2882      * code block is not duplicated in the output. Halide does not do
2883      * anything other than to compare names for equality. To guarantee
2884      * uniqueness in public code, we suggest using a Java style
2885      * inverted domain name followed by organization specific
2886      * naming. E.g.:
2887      *     com.yoyodyne.overthruster.0719acd19b66df2a9d8d628a8fefba911a0ab2b7
2888      *
2889      * See test/generator/external_code_generator.cpp for example use. */
2890     inline std::shared_ptr<ExternsMap> get_externs_map() const {
2891         return externs_map;
2892     }
2893 
2894     template<typename T>
2895     inline std::unique_ptr<T> create() const {
2896         return T::create(*this);
2897     }
2898 
2899     template<typename T, typename... Args>
2900     inline std::unique_ptr<T> apply(const Args &... args) const {
2901         auto t = this->create<T>();
2902         t->apply(args...);
2903         return t;
2904     }
2905 
2906 protected:
2907     GeneratorParam<Target> target;
2908     GeneratorParam<bool> auto_schedule;
2909     GeneratorParam<MachineParams> machine_params;
2910     std::shared_ptr<ExternsMap> externs_map;
2911     std::shared_ptr<Internal::ValueTracker> value_tracker;
2912 
2913     GeneratorContext()
2914         : GeneratorContext(Target()) {
2915     }
2916 
2917     virtual void init_from_context(const Halide::GeneratorContext &context);
2918 
2919     inline std::shared_ptr<Internal::ValueTracker> get_value_tracker() const {
2920         return value_tracker;
2921     }
2922 
2923     // No copy
2924     GeneratorContext(const GeneratorContext &) = delete;
2925     void operator=(const GeneratorContext &) = delete;
2926     // No move
2927     GeneratorContext(GeneratorContext &&) = delete;
2928     void operator=(GeneratorContext &&) = delete;
2929 };
2930 
2931 class NamesInterface {
2932     // Names in this class are only intended for use in derived classes.
2933 protected:
2934     // Import a consistent list of Halide names that can be used in
2935     // Halide generators without qualification.
2936     using Expr = Halide::Expr;
2937     using ExternFuncArgument = Halide::ExternFuncArgument;
2938     using Func = Halide::Func;
2939     using GeneratorContext = Halide::GeneratorContext;
2940     using ImageParam = Halide::ImageParam;
2941     using LoopLevel = Halide::LoopLevel;
2942     using MemoryType = Halide::MemoryType;
2943     using NameMangling = Halide::NameMangling;
2944     using Pipeline = Halide::Pipeline;
2945     using PrefetchBoundStrategy = Halide::PrefetchBoundStrategy;
2946     using RDom = Halide::RDom;
2947     using RVar = Halide::RVar;
2948     using TailStrategy = Halide::TailStrategy;
2949     using Target = Halide::Target;
2950     using Tuple = Halide::Tuple;
2951     using Type = Halide::Type;
2952     using Var = Halide::Var;
2953     template<typename T>
2954     static Expr cast(Expr e) {
2955         return Halide::cast<T>(e);
2956     }
2957     static inline Expr cast(Halide::Type t, Expr e) {
2958         return Halide::cast(t, std::move(e));
2959     }
2960     template<typename T>
2961     using GeneratorParam = Halide::GeneratorParam<T>;
2962     template<typename T = void>
2963     using Buffer = Halide::Buffer<T>;
2964     template<typename T>
2965     using Param = Halide::Param<T>;
2966     static inline Type Bool(int lanes = 1) {
2967         return Halide::Bool(lanes);
2968     }
2969     static inline Type Float(int bits, int lanes = 1) {
2970         return Halide::Float(bits, lanes);
2971     }
2972     static inline Type Int(int bits, int lanes = 1) {
2973         return Halide::Int(bits, lanes);
2974     }
2975     static inline Type UInt(int bits, int lanes = 1) {
2976         return Halide::UInt(bits, lanes);
2977     }
2978 };
2979 
2980 namespace Internal {
2981 
2982 template<typename... Args>
2983 struct NoRealizations : std::false_type {};
2984 
2985 template<>
2986 struct NoRealizations<> : std::true_type {};
2987 
2988 template<typename T, typename... Args>
2989 struct NoRealizations<T, Args...> {
2990     static const bool value = !std::is_convertible<T, Realization>::value && NoRealizations<Args...>::value;
2991 };
2992 
// Forward declarations; the classes are defined elsewhere.
class GeneratorStub;
class SimpleGeneratorFactory;

// A callable that creates a Generator for the given GeneratorContext.
// Note that these functions must never return null:
// if they cannot return a valid Generator, they must assert-fail.
using GeneratorFactory = std::function<std::unique_ptr<GeneratorBase>(const GeneratorContext &)>;
2999 
// Holds either a string or a LoopLevel. Each converting constructor sets one
// of the two members and leaves the other default-constructed. The
// constructors are intentionally implicit so that either kind of value can be
// passed where a StringOrLoopLevel is expected.
struct StringOrLoopLevel {
    std::string string_value;
    LoopLevel loop_level;

    StringOrLoopLevel() = default;
    // Construct from a C string (stored in string_value).
    /*not-explicit*/ StringOrLoopLevel(const char *s)
        : string_value(s) {
    }
    // Construct from a std::string (stored in string_value).
    /*not-explicit*/ StringOrLoopLevel(const std::string &s)
        : string_value(s) {
    }
    // Construct from a LoopLevel (stored in loop_level).
    /*not-explicit*/ StringOrLoopLevel(const LoopLevel &loop_level)
        : loop_level(loop_level) {
    }
};
// Maps a name to its value, given as either a string or a LoopLevel.
// NOTE(review): presumably keyed by GeneratorParam name (see
// GeneratorBase::set_generator_param_values) - confirm against the definition.
using GeneratorParamsMap = std::map<std::string, StringOrLoopLevel>;
3016 
// Holds ordered lists of pointers to a Generator's GeneratorParams, Inputs and
// Outputs, and owns any synthetic params and dynamically-added inputs/outputs.
// GeneratorBase is a friend and may access the private members directly.
class GeneratorParamInfo {
    // names used across all params, inputs, and outputs.
    std::set<std::string> names;

    // Ordered-list of non-null ptrs to GeneratorParam<> fields.
    std::vector<Internal::GeneratorParamBase *> filter_generator_params;

    // Ordered-list of non-null ptrs to Input<> fields.
    std::vector<Internal::GeneratorInputBase *> filter_inputs;

    // Ordered-list of non-null ptrs to Output<> fields; empty if old-style Generator.
    std::vector<Internal::GeneratorOutputBase *> filter_outputs;

    // list of synthetic GP's that we dynamically created; this list only exists to simplify
    // lifetime management, and shouldn't be accessed directly outside of our ctor/dtor,
    // regardless of friend access.
    std::vector<std::unique_ptr<Internal::GeneratorParamBase>> owned_synthetic_params;

    // list of dynamically-added inputs and outputs, here only for lifetime management.
    std::vector<std::unique_ptr<Internal::GIOBase>> owned_extras;

public:
    friend class GeneratorBase;

    // Defined elsewhere.
    // NOTE(review): presumably `size` is the size of the generator object, used
    // to locate its fields - confirm against the definition.
    GeneratorParamInfo(GeneratorBase *generator, const size_t size);

    // Read-only views of the ordered pointer lists above.
    const std::vector<Internal::GeneratorParamBase *> &generator_params() const {
        return filter_generator_params;
    }
    const std::vector<Internal::GeneratorInputBase *> &inputs() const {
        return filter_inputs;
    }
    const std::vector<Internal::GeneratorOutputBase *> &outputs() const {
        return filter_outputs;
    }
};
3053 
3054 class GeneratorBase : public NamesInterface, public GeneratorContext {
3055 public:
3056     ~GeneratorBase() override;
3057 
3058     void set_generator_param_values(const GeneratorParamsMap &params);
3059 
3060     /** Given a data type, return an estimate of the "natural" vector size
3061      * for that data type when compiling for the current target. */
3062     int natural_vector_size(Halide::Type t) const {
3063         return get_target().natural_vector_size(t);
3064     }
3065 
3066     /** Given a data type, return an estimate of the "natural" vector size
3067      * for that data type when compiling for the current target. */
3068     template<typename data_t>
3069     int natural_vector_size() const {
3070         return get_target().natural_vector_size<data_t>();
3071     }
3072 
3073     void emit_cpp_stub(const std::string &stub_file_path);
3074 
3075     // Call build() and produce a Module for the result.
3076     // If function_name is empty, generator_name() will be used for the function.
3077     Module build_module(const std::string &function_name = "",
3078                         const LinkageType linkage_type = LinkageType::ExternalPlusMetadata);
3079 
3080     /**
3081      * Build a module that is suitable for using for gradient descent calculation in TensorFlow or PyTorch.
3082      *
3083      * Essentially:
3084      *   - A new Pipeline is synthesized from the current Generator (according to the rules below)
3085      *   - The new Pipeline is autoscheduled (if autoscheduling is requested, but it would be odd not to do so)
3086      *   - The Pipeline is compiled to a Module and returned
3087      *
3088      * The new Pipeline is adjoint to the original; it has:
3089      *   - All the same inputs as the original, in the same order
3090      *   - Followed by one grad-input for each original output
3091      *   - Followed by one output for each unique pairing of original-output + original-input.
3092      *     (For the common case of just one original-output, this amounts to being one output for each original-input.)
3093      */
3094     Module build_gradient_module(const std::string &function_name);
3095 
3096     /**
3097      * set_inputs is a variadic wrapper around set_inputs_vector, which makes usage much simpler
3098      * in many cases, as it constructs the relevant entries for the vector for you, which
3099      * is often a bit unintuitive at present. The arguments are passed in Input<>-declaration-order,
3100      * and the types must be compatible. Array inputs are passed as std::vector<> of the relevant type.
3101      *
3102      * Note: at present, scalar input types must match *exactly*, i.e., for Input<uint8_t>, you
3103      * must pass an argument that is actually uint8_t; an argument that is int-that-will-fit-in-uint8
3104      * will assert-fail at Halide compile time.
3105      */
3106     template<typename... Args>
3107     void set_inputs(const Args &... args) {
3108         // set_inputs_vector() checks this too, but checking it here allows build_inputs() to avoid out-of-range checks.
3109         GeneratorParamInfo &pi = this->param_info();
3110         user_assert(sizeof...(args) == pi.inputs().size())
3111             << "Expected exactly " << pi.inputs().size()
3112             << " inputs but got " << sizeof...(args) << "\n";
3113         set_inputs_vector(build_inputs(std::forward_as_tuple<const Args &...>(args...), make_index_sequence<sizeof...(Args)>{}));
3114     }
3115 
3116     Realization realize(std::vector<int32_t> sizes) {
3117         this->check_scheduled("realize");
3118         return get_pipeline().realize(std::move(sizes), get_target());
3119     }
3120 
3121     // Only enable if none of the args are Realization; otherwise we can incorrectly
3122     // select this method instead of the Realization-as-outparam variant
3123     template<typename... Args, typename std::enable_if<NoRealizations<Args...>::value>::type * = nullptr>
3124     Realization realize(Args &&... args) {
3125         this->check_scheduled("realize");
3126         return get_pipeline().realize(std::forward<Args>(args)..., get_target());
3127     }
3128 
3129     void realize(Realization r) {
3130         this->check_scheduled("realize");
3131         get_pipeline().realize(r, get_target());
3132     }
3133 
    // Return the Pipeline that has been built by the generate() method.
    // This method can only be used from a Generator that has a generate()
    // method (vs a build() method), and currently can only be called from
    // the schedule() method. (This may be relaxed in the future to allow
    // calling from generate() as long as all Outputs have been defined.)
    //
    // The realize(), add_requirement() and trace_pipeline() helpers in this
    // class are all implemented on top of this method.
    Pipeline get_pipeline();
3140 
3141     // Create Input<Buffer> or Input<Func> with dynamic type
3142     template<typename T,
3143              typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
3144     GeneratorInput<T> *add_input(const std::string &name, const Type &t, int dimensions) {
3145         check_exact_phase(GeneratorBase::ConfigureCalled);
3146         auto *p = new GeneratorInput<T>(name, t, dimensions);
3147         p->generator = this;
3148         param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3149         param_info_ptr->filter_inputs.push_back(p);
3150         return p;
3151     }
3152 
3153     // Create a Input<Buffer> or Input<Func> with compile-time type
3154     template<typename T,
3155              typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
3156     GeneratorInput<T> *add_input(const std::string &name, int dimensions) {
3157         check_exact_phase(GeneratorBase::ConfigureCalled);
3158         auto *p = new GeneratorInput<T>(name, dimensions);
3159         p->generator = this;
3160         param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3161         param_info_ptr->filter_inputs.push_back(p);
3162         return p;
3163     }
3164 
3165     // Create Input<scalar>
3166     template<typename T,
3167              typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3168     GeneratorInput<T> *add_input(const std::string &name) {
3169         check_exact_phase(GeneratorBase::ConfigureCalled);
3170         auto *p = new GeneratorInput<T>(name);
3171         p->generator = this;
3172         param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3173         param_info_ptr->filter_inputs.push_back(p);
3174         return p;
3175     }
3176 
3177     // Create Output<Buffer> or Output<Func> with dynamic type
3178     template<typename T,
3179              typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
3180     GeneratorOutput<T> *add_output(const std::string &name, const Type &t, int dimensions) {
3181         check_exact_phase(GeneratorBase::ConfigureCalled);
3182         auto *p = new GeneratorOutput<T>(name, t, dimensions);
3183         p->generator = this;
3184         param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3185         param_info_ptr->filter_outputs.push_back(p);
3186         return p;
3187     }
3188 
3189     // Create a Output<Buffer> or Output<Func> with compile-time type
3190     template<typename T,
3191              typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
3192     GeneratorOutput<T> *add_output(const std::string &name, int dimensions) {
3193         check_exact_phase(GeneratorBase::ConfigureCalled);
3194         auto *p = new GeneratorOutput<T>(name, dimensions);
3195         p->generator = this;
3196         param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3197         param_info_ptr->filter_outputs.push_back(p);
3198         return p;
3199     }
3200 
3201     template<typename... Args>
3202     HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&... args) {
3203         get_pipeline().add_requirement(condition, std::forward<Args>(args)...);
3204     }
3205 
3206     void trace_pipeline() {
3207         get_pipeline().trace_pipeline();
3208     }
3209 
3210 protected:
    // Generator<T> constructs its base with sizeof(T) and the introspection
    // helper obtained for T.
    GeneratorBase(size_t size, const void *introspection_helper);
    // Called by Generator<T>::create(context, registered_name, stub_name),
    // which is intended for use by the HALIDE_REGISTER_GENERATOR() macro.
    void set_generator_names(const std::string &registered_name, const std::string &stub_name);

    void init_from_context(const Halide::GeneratorContext &context) override;

    // Hooks implemented by Generator<T>, which forwards them to the user's
    // build() / configure() / generate() / schedule() methods where present.
    virtual Pipeline build_pipeline() = 0;
    virtual void call_configure() = 0;
    virtual void call_generate() = 0;
    virtual void call_schedule() = 0;

    void track_parameter_values(bool include_outputs);

    // Generator<T> calls each pre_X()/post_X() pair immediately before and
    // after invoking the user's corresponding X() method. (For configure and
    // schedule, the pair is also called back-to-back when the user's class has
    // generate() but no such method.)
    void pre_build();
    void post_build();
    void pre_configure();
    void post_configure();
    void pre_generate();
    void post_generate();
    void pre_schedule();
    void post_schedule();
3231 
    // Shorthand aliases so that Generator subclasses can declare their fields
    // as Input<T> / Output<T> rather than GeneratorInput<T> / GeneratorOutput<T>.
    template<typename T>
    using Input = GeneratorInput<T>;

    template<typename T>
    using Output = GeneratorOutput<T>;
3237 
    // A Generator's creation and usage must go in a certain phase to ensure correctness;
    // the state machine here is advanced and checked at various points to ensure
    // this is the case.
    enum Phase {
        // Generator has just come into being.
        Created,

        // Generator has had its configure() method called. (For Generators without
        // a configure() method, this phase will be skipped and will advance
        // directly to InputsSet.)
        ConfigureCalled,

        // All Input<>/Param<> fields have been set. (Applicable only in JIT mode;
        // in AOT mode, this can be skipped, going Created->GenerateCalled directly.)
        InputsSet,

        // Generator has had its generate() method called. (For Generators with
        // a build() method instead of generate(), this phase will be skipped
        // and will advance directly to ScheduleCalled.)
        GenerateCalled,

        // Generator has had its schedule() method (if any) called.
        ScheduleCalled,
    } phase{Created};  // Every Generator starts out in the Created phase.

    // Phase-checking helper; add_input()/add_output() call it with
    // ConfigureCalled before adding anything.
    void check_exact_phase(Phase expected_phase) const;
    // NOTE(review): presumably checks that `phase` is at least `expected_phase`
    // -- confirm against the implementation.
    void check_min_phase(Phase expected_phase) const;
    // NOTE(review): presumably moves the state machine to `new_phase`
    // -- confirm against the implementation.
    void advance_phase(Phase new_phase);
3266 
3267 private:
    // Test entry point and collaborating classes that are granted access to
    // GeneratorBase's private members.
    friend void ::Halide::Internal::generator_test();
    friend class GeneratorParamBase;
    friend class GIOBase;
    friend class GeneratorInputBase;
    friend class GeneratorOutputBase;
    friend class GeneratorParamInfo;
    friend class GeneratorStub;
    friend class SimpleGeneratorFactory;
    friend class StubOutputBufferBase;

    // NOTE(review): presumably the `size` passed to the constructor (sizeof(T)
    // for Generator<T>) -- confirm against the implementation.
    const size_t size;

    // Lazily-allocated-and-inited struct with info about our various Params.
    // Do not access directly: use the param_info() getter.
    // (add_input()/add_output() are the visible exception: they append to
    // owned_extras / filter_inputs / filter_outputs through this pointer.)
    std::unique_ptr<GeneratorParamInfo> param_info_ptr;

    mutable std::shared_ptr<ExternsMap> externs_map;

    bool inputs_set{false};
    // Names recorded via set_generator_names() -- presumably; confirm against
    // the implementation.
    std::string generator_registered_name, generator_stub_name;
    Pipeline pipeline;

    // Return our GeneratorParamInfo.
    GeneratorParamInfo &param_info();

    Internal::GeneratorOutputBase *find_output_by_name(const std::string &name);

    // Called by the realize() overloads (with "realize" as `m`) before they
    // touch the pipeline.
    void check_scheduled(const char *m) const;

    void build_params(bool force = false);
3298 
    // Provide private, unimplemented, wrong-result-type methods here
    // so that Generators don't attempt to call the global methods
    // of the same name by accident: use the get_target() method instead.
    // (Inside a Generator these names resolve to the private members below
    // rather than to the global functions, so such a call fails to compile.)
    void get_host_target();
    void get_jit_target_from_environment();
    void get_target_from_environment();
3305 
    // Return the Output<Func> or Output<Buffer> with the given name,
    // which must be a singular (non-array) Func or Buffer output.
    // If no such name exists (or it names an array output), assert; this method
    // never returns an undefined Func.
    Func get_output(const std::string &n);

    // Return the Output<Func[]> with the given name, which must be an
    // array-of-Func output. If no such name exists (or is non-array), assert;
    // this method never returns undefined Funcs.
    std::vector<Func> get_array_output(const std::string &n);

    // Takes one inner vector per declared Input<>, in declaration order
    // (set_inputs() builds its argument this way).
    void set_inputs_vector(const std::vector<std::vector<StubInput>> &inputs);

    // Validation helpers used by the build_input() overloads below; they are
    // expected to assert-fail when `in` does not have the required shape/kind.
    static void check_input_is_singular(Internal::GeneratorInputBase *in);
    static void check_input_is_array(Internal::GeneratorInputBase *in);
    static void check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind);
3321 
3322     // Allow Buffer<> if:
3323     // -- we are assigning it to an Input<Buffer<>> (with compatible type and dimensions),
3324     // causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
3325     // -- we are assigningit to an Input<Func>, in which case we just Func-wrap the Buffer<>.
3326     template<typename T>
3327     std::vector<StubInput> build_input(size_t i, const Buffer<T> &arg) {
3328         auto *in = param_info().inputs().at(i);
3329         check_input_is_singular(in);
3330         const auto k = in->kind();
3331         if (k == Internal::IOKind::Buffer) {
3332             Halide::Buffer<> b = arg;
3333             StubInputBuffer<> sib(b);
3334             StubInput si(sib);
3335             return {si};
3336         } else if (k == Internal::IOKind::Function) {
3337             Halide::Func f(arg.name() + "_im");
3338             f(Halide::_) = arg(Halide::_);
3339             StubInput si(f);
3340             return {si};
3341         } else {
3342             check_input_kind(in, Internal::IOKind::Buffer);  // just to trigger assertion
3343             return {};
3344         }
3345     }
3346 
3347     // Allow Input<Buffer<>> if:
3348     // -- we are assigning it to another Input<Buffer<>> (with compatible type and dimensions),
3349     // allowing us to simply pipe a parameter from an enclosing Generator to the Invoker.
3350     // -- we are assigningit to an Input<Func>, in which case we just Func-wrap the Input<Buffer<>>.
3351     template<typename T>
3352     std::vector<StubInput> build_input(size_t i, const GeneratorInput<Buffer<T>> &arg) {
3353         auto *in = param_info().inputs().at(i);
3354         check_input_is_singular(in);
3355         const auto k = in->kind();
3356         if (k == Internal::IOKind::Buffer) {
3357             StubInputBuffer<> sib = arg;
3358             StubInput si(sib);
3359             return {si};
3360         } else if (k == Internal::IOKind::Function) {
3361             Halide::Func f = arg.funcs().at(0);
3362             StubInput si(f);
3363             return {si};
3364         } else {
3365             check_input_kind(in, Internal::IOKind::Buffer);  // just to trigger assertion
3366             return {};
3367         }
3368     }
3369 
3370     // Allow Func iff we are assigning it to an Input<Func> (with compatible type and dimensions).
3371     std::vector<StubInput> build_input(size_t i, const Func &arg) {
3372         auto *in = param_info().inputs().at(i);
3373         check_input_kind(in, Internal::IOKind::Function);
3374         check_input_is_singular(in);
3375         const Halide::Func &f = arg;
3376         StubInput si(f);
3377         return {si};
3378     }
3379 
3380     // Allow vector<Func> iff we are assigning it to an Input<Func[]> (with compatible type and dimensions).
3381     std::vector<StubInput> build_input(size_t i, const std::vector<Func> &arg) {
3382         auto *in = param_info().inputs().at(i);
3383         check_input_kind(in, Internal::IOKind::Function);
3384         check_input_is_array(in);
3385         // My kingdom for a list comprehension...
3386         std::vector<StubInput> siv;
3387         siv.reserve(arg.size());
3388         for (const auto &f : arg) {
3389             siv.emplace_back(f);
3390         }
3391         return siv;
3392     }
3393 
3394     // Expr must be Input<Scalar>.
3395     std::vector<StubInput> build_input(size_t i, const Expr &arg) {
3396         auto *in = param_info().inputs().at(i);
3397         check_input_kind(in, Internal::IOKind::Scalar);
3398         check_input_is_singular(in);
3399         StubInput si(arg);
3400         return {si};
3401     }
3402 
3403     // (Array form)
3404     std::vector<StubInput> build_input(size_t i, const std::vector<Expr> &arg) {
3405         auto *in = param_info().inputs().at(i);
3406         check_input_kind(in, Internal::IOKind::Scalar);
3407         check_input_is_array(in);
3408         std::vector<StubInput> siv;
3409         siv.reserve(arg.size());
3410         for (const auto &value : arg) {
3411             siv.emplace_back(value);
3412         }
3413         return siv;
3414     }
3415 
3416     // Any other type must be convertible to Expr and must be associated with an Input<Scalar>.
3417     // Use is_arithmetic since some Expr conversions are explicit.
3418     template<typename T,
3419              typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3420     std::vector<StubInput> build_input(size_t i, const T &arg) {
3421         auto *in = param_info().inputs().at(i);
3422         check_input_kind(in, Internal::IOKind::Scalar);
3423         check_input_is_singular(in);
3424         // We must use an explicit Expr() ctor to preserve the type
3425         Expr e(arg);
3426         StubInput si(e);
3427         return {si};
3428     }
3429 
3430     // (Array form)
3431     template<typename T,
3432              typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
3433     std::vector<StubInput> build_input(size_t i, const std::vector<T> &arg) {
3434         auto *in = param_info().inputs().at(i);
3435         check_input_kind(in, Internal::IOKind::Scalar);
3436         check_input_is_array(in);
3437         std::vector<StubInput> siv;
3438         siv.reserve(arg.size());
3439         for (const auto &value : arg) {
3440             // We must use an explicit Expr() ctor to preserve the type;
3441             // otherwise, implicit conversions can downgrade (e.g.) float -> int
3442             Expr e(value);
3443             siv.emplace_back(e);
3444         }
3445         return siv;
3446     }
3447 
3448     template<typename... Args, size_t... Indices>
3449     std::vector<std::vector<StubInput>> build_inputs(const std::tuple<const Args &...> &t, index_sequence<Indices...>) {
3450         return {build_input(Indices, std::get<Indices>(t))...};
3451     }
3452 
    // No copy: both the copy constructor and copy assignment are deleted.
    GeneratorBase(const GeneratorBase &) = delete;
    void operator=(const GeneratorBase &) = delete;
    // No move: both the move constructor and move assignment are deleted.
    GeneratorBase(GeneratorBase &&that) = delete;
    void operator=(GeneratorBase &&that) = delete;
3459 };
3460 
// Registry that maps Generator names to the factories that create them.
// All public entry points are static.
// NOTE(review): they presumably operate on the single instance returned by
// get_registry(), with `mutex` guarding `factories` -- confirm against the
// implementation.
class GeneratorRegistry {
public:
    // Add / remove the factory registered under `name`.
    static void register_factory(const std::string &name, GeneratorFactory generator_factory);
    static void unregister_factory(const std::string &name);
    // NOTE(review): presumably returns the names of all registered factories
    // -- confirm against the implementation.
    static std::vector<std::string> enumerate();
    // Note that this method will never return null:
    // if it cannot return a valid Generator, it should assert-fail.
    static std::unique_ptr<GeneratorBase> create(const std::string &name,
                                                 const Halide::GeneratorContext &context);

private:
    using GeneratorFactoryMap = std::map<const std::string, GeneratorFactory>;

    GeneratorFactoryMap factories;
    std::mutex mutex;

    static GeneratorRegistry &get_registry();

    // Not constructible or copyable from outside the class.
    GeneratorRegistry() = default;
    GeneratorRegistry(const GeneratorRegistry &) = delete;
    void operator=(const GeneratorRegistry &) = delete;
};
3483 
3484 }  // namespace Internal
3485 
3486 template<class T>
3487 class Generator : public Internal::GeneratorBase {
3488 protected:
3489     Generator()
3490         : Internal::GeneratorBase(sizeof(T),
3491                                   Internal::Introspection::get_introspection_helper<T>()) {
3492     }
3493 
3494 public:
3495     static std::unique_ptr<T> create(const Halide::GeneratorContext &context) {
3496         // We must have an object of type T (not merely GeneratorBase) to call a protected method,
3497         // because CRTP is a weird beast.
3498         auto g = std::unique_ptr<T>(new T());
3499         g->init_from_context(context);
3500         return g;
3501     }
3502 
3503     // This is public but intended only for use by the HALIDE_REGISTER_GENERATOR() macro.
3504     static std::unique_ptr<T> create(const Halide::GeneratorContext &context,
3505                                      const std::string &registered_name,
3506                                      const std::string &stub_name) {
3507         auto g = create(context);
3508         g->set_generator_names(registered_name, stub_name);
3509         return g;
3510     }
3511 
3512     using Internal::GeneratorBase::apply;
3513     using Internal::GeneratorBase::create;
3514 
3515     template<typename... Args>
3516     void apply(const Args &... args) {
3517 #ifndef _MSC_VER
3518         // VS2015 apparently has some SFINAE issues, so this can inappropriately
3519         // trigger there. (We'll still fail when generate() is called, just
3520         // with a less-helpful error message.)
3521         static_assert(has_generate_method<T>::value, "apply() is not supported for old-style Generators.");
3522 #endif
3523         call_configure();
3524         set_inputs(args...);
3525         call_generate();
3526         call_schedule();
3527     }
3528 
3529 private:
3530     // std::is_member_function_pointer will fail if there is no member of that name,
3531     // so we use a little SFINAE to detect if there are method-shaped members.
3532     template<typename>
3533     struct type_sink { typedef void type; };
3534 
3535     template<typename T2, typename = void>
3536     struct has_configure_method : std::false_type {};
3537 
3538     template<typename T2>
3539     struct has_configure_method<T2, typename type_sink<decltype(std::declval<T2>().configure())>::type> : std::true_type {};
3540 
3541     template<typename T2, typename = void>
3542     struct has_generate_method : std::false_type {};
3543 
3544     template<typename T2>
3545     struct has_generate_method<T2, typename type_sink<decltype(std::declval<T2>().generate())>::type> : std::true_type {};
3546 
3547     template<typename T2, typename = void>
3548     struct has_schedule_method : std::false_type {};
3549 
3550     template<typename T2>
3551     struct has_schedule_method<T2, typename type_sink<decltype(std::declval<T2>().schedule())>::type> : std::true_type {};
3552 
3553     template<typename T2 = T,
3554              typename std::enable_if<!has_generate_method<T2>::value>::type * = nullptr>
3555 
3556     // Implementations for build_pipeline_impl(), specialized on whether we
3557     // have build() or generate()/schedule() methods.
3558 
3559     // MSVC apparently has some weirdness with the usual sfinae tricks
3560     // for detecting method-shaped things, so we can't actually use
3561     // the helpers above outside of static_assert. Instead we make as
3562     // many overloads as we can exist, and then use C++'s preference
3563     // for treating a 0 as an int rather than a double to choose one
3564     // of them.
3565     Pipeline build_pipeline_impl(double) {
3566         static_assert(!has_configure_method<T2>::value, "The configure() method is ignored if you define a build() method; use generate() instead.");
3567         static_assert(!has_schedule_method<T2>::value, "The schedule() method is ignored if you define a build() method; use generate() instead.");
3568         pre_build();
3569         Pipeline p = ((T *)this)->build();
3570         post_build();
3571         return p;
3572     }
3573 
3574     template<typename T2 = T,
3575              typename = decltype(std::declval<T2>().generate())>
3576     Pipeline build_pipeline_impl(int) {
3577         // No: configure() must be called prior to this
3578         // (and in fact, prior to calling set_inputs).
3579         //
3580         // ((T *)this)->call_configure_impl(0, 0);
3581 
3582         ((T *)this)->call_generate_impl(0);
3583         ((T *)this)->call_schedule_impl(0, 0);
3584         return get_pipeline();
3585     }
3586 
3587     // Implementations for call_configure_impl(), specialized on whether we
3588     // have build() or configure()/generate()/schedule() methods.
3589 
3590     void call_configure_impl(double, double) {
3591         // Called as a side effect for build()-method Generators; quietly do nothing.
3592     }
3593 
3594     template<typename T2 = T,
3595              typename = decltype(std::declval<T2>().generate())>
3596     void call_configure_impl(double, int) {
3597         // Generator has a generate() method but no configure() method. This is ok. Just advance the phase.
3598         pre_configure();
3599         static_assert(!has_configure_method<T2>::value, "Did not expect a configure method here.");
3600         post_configure();
3601     }
3602 
3603     template<typename T2 = T,
3604              typename = decltype(std::declval<T2>().generate()),
3605              typename = decltype(std::declval<T2>().configure())>
3606     void call_configure_impl(int, int) {
3607         T *t = (T *)this;
3608         static_assert(std::is_void<decltype(t->configure())>::value, "configure() must return void");
3609         pre_configure();
3610         t->configure();
3611         post_configure();
3612     }
3613 
3614     // Implementations for call_generate_impl(), specialized on whether we
3615     // have build() or configure()/generate()/schedule() methods.
3616 
3617     void call_generate_impl(double) {
3618         user_error << "Unimplemented";
3619     }
3620 
3621     template<typename T2 = T,
3622              typename = decltype(std::declval<T2>().generate())>
3623     void call_generate_impl(int) {
3624         T *t = (T *)this;
3625         static_assert(std::is_void<decltype(t->generate())>::value, "generate() must return void");
3626         pre_generate();
3627         t->generate();
3628         post_generate();
3629     }
3630 
3631     // Implementations for call_schedule_impl(), specialized on whether we
3632     // have build() or configure()generate()/schedule() methods.
3633 
3634     void call_schedule_impl(double, double) {
3635         user_error << "Unimplemented";
3636     }
3637 
3638     template<typename T2 = T,
3639              typename = decltype(std::declval<T2>().generate())>
3640     void call_schedule_impl(double, int) {
3641         // Generator has a generate() method but no schedule() method. This is ok. Just advance the phase.
3642         pre_schedule();
3643         post_schedule();
3644     }
3645 
3646     template<typename T2 = T,
3647              typename = decltype(std::declval<T2>().generate()),
3648              typename = decltype(std::declval<T2>().schedule())>
3649     void call_schedule_impl(int, int) {
3650         T *t = (T *)this;
3651         static_assert(std::is_void<decltype(t->schedule())>::value, "schedule() must return void");
3652         pre_schedule();
3653         t->schedule();
3654         post_schedule();
3655     }
3656 
3657 protected:
3658     Pipeline build_pipeline() override {
3659         return this->build_pipeline_impl(0);
3660     }
3661 
3662     void call_configure() override {
3663         this->call_configure_impl(0, 0);
3664     }
3665 
3666     void call_generate() override {
3667         this->call_generate_impl(0);
3668     }
3669 
3670     void call_schedule() override {
3671         this->call_schedule_impl(0, 0);
3672     }
3673 
3674 private:
3675     friend void ::Halide::Internal::generator_test();
3676     friend class Internal::SimpleGeneratorFactory;
3677     friend void ::Halide::Internal::generator_test();
3678     friend class ::Halide::GeneratorContext;
3679 
3680     // No copy
3681     Generator(const Generator &) = delete;
3682     void operator=(const Generator &) = delete;
3683     // No move
3684     Generator(Generator &&that) = delete;
3685     void operator=(Generator &&that) = delete;
3686 };
3687 
3688 namespace Internal {
3689 
3690 class RegisterGenerator {
3691 public:
3692     RegisterGenerator(const char *registered_name, GeneratorFactory generator_factory) {
3693         Internal::GeneratorRegistry::register_factory(registered_name, std::move(generator_factory));
3694     }
3695 };
3696 
3697 class GeneratorStub : public NamesInterface {
3698 public:
3699     GeneratorStub(const GeneratorContext &context,
3700                   const GeneratorFactory &generator_factory);
3701 
3702     GeneratorStub(const GeneratorContext &context,
3703                   const GeneratorFactory &generator_factory,
3704                   const GeneratorParamsMap &generator_params,
3705                   const std::vector<std::vector<Internal::StubInput>> &inputs);
3706     std::vector<std::vector<Func>> generate(const GeneratorParamsMap &generator_params,
3707                                             const std::vector<std::vector<Internal::StubInput>> &inputs);
3708 
3709     // Output(s)
3710     // TODO: identify vars used
3711     Func get_output(const std::string &n) const {
3712         return generator->get_output(n);
3713     }
3714 
3715     template<typename T2>
3716     T2 get_output_buffer(const std::string &n) const {
3717         return T2(get_output(n), generator);
3718     }
3719 
3720     template<typename T2>
3721     std::vector<T2> get_array_output_buffer(const std::string &n) const {
3722         auto v = generator->get_array_output(n);
3723         std::vector<T2> result;
3724         for (auto &o : v) {
3725             result.push_back(T2(o, generator));
3726         }
3727         return result;
3728     }
3729 
3730     std::vector<Func> get_array_output(const std::string &n) const {
3731         return generator->get_array_output(n);
3732     }
3733 
3734     static std::vector<StubInput> to_stub_input_vector(const Expr &e) {
3735         return {StubInput(e)};
3736     }
3737 
3738     static std::vector<StubInput> to_stub_input_vector(const Func &f) {
3739         return {StubInput(f)};
3740     }
3741 
3742     template<typename T = void>
3743     static std::vector<StubInput> to_stub_input_vector(const StubInputBuffer<T> &b) {
3744         return {StubInput(b)};
3745     }
3746 
3747     template<typename T>
3748     static std::vector<StubInput> to_stub_input_vector(const std::vector<T> &v) {
3749         std::vector<StubInput> r;
3750         std::copy(v.begin(), v.end(), std::back_inserter(r));
3751         return r;
3752     }
3753 
3754     struct Names {
3755         std::vector<std::string> generator_params, inputs, outputs;
3756     };
3757     Names get_names() const;
3758 
3759     std::shared_ptr<GeneratorBase> generator;
3760 };
3761 
3762 }  // namespace Internal
3763 
3764 }  // namespace Halide
3765 
// Define this namespace at global scope so that anonymous namespaces won't
// defeat our static_assert check; define a dummy type inside so we can
// check for type aliasing injected by anonymous namespace usage.
// (The registration macros compare ::halide_register_generator::halide_global_ns
// against the unqualified halide_register_generator::halide_global_ns.)
namespace halide_register_generator {
struct halide_global_ns;
}  // namespace halide_register_generator
3772 
// Shared implementation of HALIDE_REGISTER_GENERATOR(). For the given
// Generator class it:
//  - declares and defines halide_register_generator::<GEN_REGISTRY_NAME>_ns::factory(),
//    which calls GEN_CLASS_NAME::create() with the stringized registry name
//    and stub name;
//  - defines a static RegisterGenerator object, whose constructor registers
//    that factory under the stringized registry name;
//  - static_asserts that the macro was expanded at global scope, by checking
//    that halide_register_generator::halide_global_ns still names the
//    global-scope type.
#define _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)                               \
    namespace halide_register_generator {                                                                                           \
    struct halide_global_ns;                                                                                                        \
    namespace GEN_REGISTRY_NAME##_ns {                                                                                              \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context);                          \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) {                         \
            return GEN_CLASS_NAME::create(context, #GEN_REGISTRY_NAME, #FULLY_QUALIFIED_STUB_NAME);                                 \
        }                                                                                                                           \
    }                                                                                                                               \
    static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
    }                                                                                                                               \
    static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value,  \
                  "HALIDE_REGISTER_GENERATOR must be used at global scope");
3786 
// Two-argument form: the registry name doubles as the stub name.
#define _HALIDE_REGISTER_GENERATOR2(GEN_CLASS_NAME, GEN_REGISTRY_NAME) \
    _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, GEN_REGISTRY_NAME)

// Three-argument form: the stub name is given explicitly.
#define _HALIDE_REGISTER_GENERATOR3(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
    _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)

// MSVC has a broken implementation of variadic macros: it expands __VA_ARGS__
// as a single token in argument lists (rather than multiple tokens).
// Jump through some hoops to work around this.
//
// Argument counting: `3, 2, 1, 0` is appended to the caller's arguments and
// the fourth item of the combined list is selected, which yields the number
// of arguments passed (for one to three arguments).
#define __HALIDE_REGISTER_ARGCOUNT_IMPL(_1, _2, _3, COUNT, ...) \
    COUNT

#define _HALIDE_REGISTER_ARGCOUNT_IMPL(ARGS) \
    __HALIDE_REGISTER_ARGCOUNT_IMPL ARGS

#define _HALIDE_REGISTER_ARGCOUNT(...) \
    _HALIDE_REGISTER_ARGCOUNT_IMPL((__VA_ARGS__, 3, 2, 1, 0))

// Paste the (fully expanded) count onto _HALIDE_REGISTER_GENERATOR to pick
// the macro for that arity; the extra levels of indirection let COUNT be
// expanded before the ## paste happens.
#define ___HALIDE_REGISTER_CHOOSER(COUNT) \
    _HALIDE_REGISTER_GENERATOR##COUNT

#define __HALIDE_REGISTER_CHOOSER(COUNT) \
    ___HALIDE_REGISTER_CHOOSER(COUNT)

#define _HALIDE_REGISTER_CHOOSER(COUNT) \
    __HALIDE_REGISTER_CHOOSER(COUNT)

// Juxtapose the chosen macro name (A) with the parenthesized argument list (B)
// so that the result is then expanded as a macro invocation.
#define _HALIDE_REGISTER_GENERATOR_PASTE(A, B) \
    A B

// Public entry point. Accepts either
//   HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME)
// or
//   HALIDE_REGISTER_GENERATOR(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)
// and dispatches to _HALIDE_REGISTER_GENERATOR2 / _HALIDE_REGISTER_GENERATOR3
// based on the number of arguments.
#define HALIDE_REGISTER_GENERATOR(...) \
    _HALIDE_REGISTER_GENERATOR_PASTE(_HALIDE_REGISTER_CHOOSER(_HALIDE_REGISTER_ARGCOUNT(__VA_ARGS__)), (__VA_ARGS__))
3819 
// HALIDE_REGISTER_GENERATOR_ALIAS() can be used to create an alias-with-a-particular-set-of-param-values
3821 // for a given Generator in the build system. Normally, you wouldn't want to do this;
3822 // however, some existing Halide clients have build systems that make it challenging to
3823 // specify GeneratorParams inside the build system, and this allows a somewhat simpler
3824 // customization route for them. It's highly recommended you don't use this for new code.
3825 //
3826 // The final argument is really an initializer-list of GeneratorParams, in the form
3827 // of an initializer-list for map<string, string>:
3828 //
3829 //    { { "gp-name", "gp-value"} [, { "gp2-name", "gp2-value" }] }
3830 //
// It is specified as a variadic macro argument to allow for the fact that the embedded commas
3832 // would otherwise confuse the preprocessor; since (in this case) all we're going to do is
3833 // pass it thru as-is, this is fine (and even MSVC's 'broken' __VA_ARGS__ should be OK here).
//
// Expansion outline: inside namespace halide_register_generator, the macro declares the
// factory of ORIGINAL_REGISTRY_NAME, defines a factory for GEN_REGISTRY_NAME that calls it
// and then applies the given param values, and registers that new factory under the
// stringized GEN_REGISTRY_NAME. The trailing static_assert rejects use outside global scope.
#define HALIDE_REGISTER_GENERATOR_ALIAS(GEN_REGISTRY_NAME, ORIGINAL_REGISTRY_NAME, ...)                     \
    namespace halide_register_generator {                                                                   \
    struct halide_global_ns;                                                                                \
    /* Declaration only: the factory being aliased must be defined elsewhere. */                            \
    namespace ORIGINAL_REGISTRY_NAME##_ns {                                                                 \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context);  \
    }                                                                                                       \
    namespace GEN_REGISTRY_NAME##_ns {                                                                      \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context);  \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) { \
            /* Build the original Generator, then apply the alias GeneratorParam values. */                 \
            auto aliased = ORIGINAL_REGISTRY_NAME##_ns::factory(context);                                   \
            aliased->set_generator_param_values(__VA_ARGS__);                                               \
            return aliased;                                                                                 \
        }                                                                                                   \
    }                                                                                                       \
    static auto reg_##GEN_REGISTRY_NAME =                                                                   \
        Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory);           \
    }                                                                                                       \
    static_assert(std::is_same<::halide_register_generator::halide_global_ns,                               \
                               halide_register_generator::halide_global_ns>::value,                         \
                  "HALIDE_REGISTER_GENERATOR_ALIAS must be used at global scope");
3852 
3853 #endif  // HALIDE_GENERATOR_H_
3854