1 #include <cmath>
2 #include <fstream>
3 #include <unordered_map>
4 #include <utility>
5 
6 #include "BoundaryConditions.h"
7 #include "CompilerLogger.h"
8 #include "Derivative.h"
9 #include "Generator.h"
10 #include "IRPrinter.h"
11 #include "Module.h"
12 #include "Simplify.h"
13 
14 namespace Halide {
15 
// Construct a context from explicit settings. The target, auto_schedule and
// machine_params members are initialized with their fixed string names and
// the supplied values; the externs map and the value tracker are freshly
// allocated via std::make_shared.
GeneratorContext::GeneratorContext(const Target &t, bool auto_schedule,
                                   const MachineParams &machine_params)
    : target("target", t),
      auto_schedule("auto_schedule", auto_schedule),
      machine_params("machine_params", machine_params),
      externs_map(std::make_shared<ExternsMap>()),
      value_tracker(std::make_shared<Internal::ValueTracker>()) {
}
24 
~GeneratorContext()25 GeneratorContext::~GeneratorContext() {
26     // nothing
27 }
28 
init_from_context(const Halide::GeneratorContext & context)29 void GeneratorContext::init_from_context(const Halide::GeneratorContext &context) {
30     target.set(context.get_target());
31     auto_schedule.set(context.get_auto_schedule());
32     machine_params.set(context.get_machine_params());
33     value_tracker = context.get_value_tracker();
34     externs_map = context.get_externs_map();
35 }
36 
37 namespace Internal {
38 
39 namespace {
40 
// The helpers below implement name validation for Generators and Params;
// see is_valid_name() for the actual rules.
// (NOTE: gcc didn't add proper std::regex support until v4.9;
// we don't yet require this, hence the hand-rolled replacement.)
44 
// True iff c is an ASCII letter ('A'..'Z' or 'a'..'z').
bool is_alpha(char c) {
    const bool is_upper = ('A' <= c) && (c <= 'Z');
    const bool is_lower = ('a' <= c) && (c <= 'z');
    return is_upper || is_lower;
}
48 
49 // Note that this includes '_'
is_alnum(char c)50 bool is_alnum(char c) {
51     return is_alpha(c) || (c == '_') || (c >= '0' && c <= '9');
52 }
53 
54 // Basically, a valid C identifier, except:
55 //
56 // -- initial _ is forbidden (rather than merely "reserved")
57 // -- two underscores in a row is also forbidden
is_valid_name(const std::string & n)58 bool is_valid_name(const std::string &n) {
59     if (n.empty()) return false;
60     if (!is_alpha(n[0])) return false;
61     for (size_t i = 1; i < n.size(); ++i) {
62         if (!is_alnum(n[i])) return false;
63         if (n[i] == '_' && n[i - 1] == '_') return false;
64     }
65     return true;
66 }
67 
compute_base_path(const std::string & output_dir,const std::string & function_name,const std::string & file_base_name)68 std::string compute_base_path(const std::string &output_dir,
69                               const std::string &function_name,
70                               const std::string &file_base_name) {
71     std::vector<std::string> namespaces;
72     std::string simple_name = extract_namespaces(function_name, namespaces);
73     std::string base_path = output_dir + "/" + (file_base_name.empty() ? simple_name : file_base_name);
74     return base_path;
75 }
76 
compute_output_files(const Target & target,const std::string & base_path,const std::set<Output> & outputs)77 std::map<Output, std::string> compute_output_files(const Target &target,
78                                                    const std::string &base_path,
79                                                    const std::set<Output> &outputs) {
80     std::map<Output, const OutputInfo> output_info = get_output_info(target);
81 
82     std::map<Output, std::string> output_files;
83     for (auto o : outputs) {
84         output_files[o] = base_path + output_info.at(o).extension;
85     }
86     return output_files;
87 }
88 
to_argument(const Internal::Parameter & param,const Expr & default_value)89 Argument to_argument(const Internal::Parameter &param, const Expr &default_value) {
90     ArgumentEstimates argument_estimates = param.get_argument_estimates();
91     argument_estimates.scalar_def = default_value;
92     return Argument(param.name(),
93                     param.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
94                     param.type(), param.dimensions(), argument_estimates);
95 }
96 
make_param_func(const Parameter & p,const std::string & name)97 Func make_param_func(const Parameter &p, const std::string &name) {
98     internal_assert(p.is_buffer());
99     Func f(name + "_im");
100     auto b = p.buffer();
101     if (b.defined()) {
102         // If the Parameter has an explicit BufferPtr set, bind directly to it
103         f(_) = b(_);
104     } else {
105         std::vector<Var> args;
106         std::vector<Expr> args_expr;
107         for (int i = 0; i < p.dimensions(); ++i) {
108             Var v = Var::implicit(i);
109             args.push_back(v);
110             args_expr.push_back(v);
111         }
112         f(args) = Internal::Call::make(p, args_expr);
113     }
114     return f;
115 }
116 
117 }  // namespace
118 
parse_halide_type_list(const std::string & types)119 std::vector<Type> parse_halide_type_list(const std::string &types) {
120     const auto &e = get_halide_type_enum_map();
121     std::vector<Type> result;
122     for (auto t : split_string(types, ",")) {
123         auto it = e.find(t);
124         user_assert(it != e.end()) << "Type not found: " << t;
125         result.push_back(it->second);
126     }
127     return result;
128 }
129 
track_values(const std::string & name,const std::vector<Expr> & values)130 void ValueTracker::track_values(const std::string &name, const std::vector<Expr> &values) {
131     std::vector<std::vector<Expr>> &history = values_history[name];
132     if (history.empty()) {
133         for (size_t i = 0; i < values.size(); ++i) {
134             history.push_back({values[i]});
135         }
136         return;
137     }
138 
139     internal_assert(history.size() == values.size())
140         << "Expected values of size " << history.size()
141         << " but saw size " << values.size()
142         << " for name " << name << "\n";
143 
144     // For each item, see if we have a new unique value
145     for (size_t i = 0; i < values.size(); ++i) {
146         Expr oldval = history[i].back();
147         Expr newval = values[i];
148         if (oldval.defined() && newval.defined()) {
149             if (can_prove(newval == oldval)) {
150                 continue;
151             }
152         } else if (!oldval.defined() && !newval.defined()) {
153             // Expr::operator== doesn't work with undefined
154             // values, but they are equal for our purposes here.
155             continue;
156         }
157         history[i].push_back(newval);
158         // If we exceed max_unique_values, fail immediately.
159         // TODO: could be useful to log all the entries that
160         // overflow max_unique_values before failing.
161         // TODO: this could be more helpful about labeling the values
162         // that have multiple setttings.
163         if (history[i].size() > max_unique_values) {
164             std::ostringstream o;
165             o << "Saw too many unique values in ValueTracker[" + std::to_string(i) + "]; "
166               << "expected a maximum of " << max_unique_values << ":\n";
167             for (auto e : history[i]) {
168                 o << "    " << e << "\n";
169             }
170             user_error << o.str();
171         }
172     }
173 }
174 
parameter_constraints(const Parameter & p)175 std::vector<Expr> parameter_constraints(const Parameter &p) {
176     internal_assert(p.defined());
177     std::vector<Expr> values;
178     values.emplace_back(p.host_alignment());
179     if (p.is_buffer()) {
180         for (int i = 0; i < p.dimensions(); ++i) {
181             values.push_back(p.min_constraint(i));
182             values.push_back(p.extent_constraint(i));
183             values.push_back(p.stride_constraint(i));
184         }
185     } else {
186         values.push_back(p.min_value());
187         values.push_back(p.max_value());
188     }
189     return values;
190 }
191 
192 class StubEmitter {
193 public:
StubEmitter(std::ostream & dest,const std::string & generator_registered_name,const std::string & generator_stub_name,const std::vector<Internal::GeneratorParamBase * > & generator_params,const std::vector<Internal::GeneratorInputBase * > & inputs,const std::vector<Internal::GeneratorOutputBase * > & outputs)194     StubEmitter(std::ostream &dest,
195                 const std::string &generator_registered_name,
196                 const std::string &generator_stub_name,
197                 const std::vector<Internal::GeneratorParamBase *> &generator_params,
198                 const std::vector<Internal::GeneratorInputBase *> &inputs,
199                 const std::vector<Internal::GeneratorOutputBase *> &outputs)
200         : stream(dest),
201           generator_registered_name(generator_registered_name),
202           generator_stub_name(generator_stub_name),
203           generator_params(select_generator_params(generator_params)),
204           inputs(inputs),
205           outputs(outputs) {
206         namespaces = split_string(generator_stub_name, "::");
207         internal_assert(!namespaces.empty());
208         if (namespaces[0].empty()) {
209             // We have a name like ::foo::bar::baz; omit the first empty ns.
210             namespaces.erase(namespaces.begin());
211             internal_assert(namespaces.size() >= 2);
212         }
213         class_name = namespaces.back();
214         namespaces.pop_back();
215     }
216 
217     void emit();
218 
219 private:
220     std::ostream &stream;
221     const std::string generator_registered_name;
222     const std::string generator_stub_name;
223     std::string class_name;
224     std::vector<std::string> namespaces;
225     const std::vector<Internal::GeneratorParamBase *> generator_params;
226     const std::vector<Internal::GeneratorInputBase *> inputs;
227     const std::vector<Internal::GeneratorOutputBase *> outputs;
228     int indent_level{0};
229 
select_generator_params(const std::vector<Internal::GeneratorParamBase * > & in)230     std::vector<Internal::GeneratorParamBase *> select_generator_params(const std::vector<Internal::GeneratorParamBase *> &in) {
231         std::vector<Internal::GeneratorParamBase *> out;
232         for (auto p : in) {
233             // These are always propagated specially.
234             if (p->name == "target" ||
235                 p->name == "auto_schedule" ||
236                 p->name == "machine_params") continue;
237             if (p->is_synthetic_param()) continue;
238             out.push_back(p);
239         }
240         return out;
241     }
242 
243     /** Emit spaces according to the current indentation level */
get_indent() const244     Indentation get_indent() const {
245         return Indentation{indent_level};
246     }
247 
248     void emit_inputs_struct();
249     void emit_generator_params_struct();
250 };
251 
emit_generator_params_struct()252 void StubEmitter::emit_generator_params_struct() {
253     const auto &v = generator_params;
254     std::string name = "GeneratorParams";
255     stream << get_indent() << "struct " << name << " final {\n";
256     indent_level++;
257     if (!v.empty()) {
258         for (auto p : v) {
259             stream << get_indent() << p->get_c_type() << " " << p->name << "{ " << p->get_default_value() << " };\n";
260         }
261         stream << "\n";
262     }
263 
264     stream << get_indent() << name << "() {}\n";
265     stream << "\n";
266 
267     if (!v.empty()) {
268         stream << get_indent() << name << "(\n";
269         indent_level++;
270         std::string comma = "";
271         for (auto p : v) {
272             stream << get_indent() << comma << p->get_c_type() << " " << p->name << "\n";
273             comma = ", ";
274         }
275         indent_level--;
276         stream << get_indent() << ") : \n";
277         indent_level++;
278         comma = "";
279         for (auto p : v) {
280             stream << get_indent() << comma << p->name << "(" << p->name << ")\n";
281             comma = ", ";
282         }
283         indent_level--;
284         stream << get_indent() << "{\n";
285         stream << get_indent() << "}\n";
286         stream << "\n";
287     }
288 
289     stream << get_indent() << "inline HALIDE_NO_USER_CODE_INLINE Halide::Internal::GeneratorParamsMap to_generator_params_map() const {\n";
290     indent_level++;
291     stream << get_indent() << "return {\n";
292     indent_level++;
293     std::string comma = "";
294     for (auto p : v) {
295         stream << get_indent() << comma << "{\"" << p->name << "\", ";
296         if (p->is_looplevel_param()) {
297             stream << p->name << "}\n";
298         } else {
299             stream << p->call_to_string(p->name) << "}\n";
300         }
301         comma = ", ";
302     }
303     indent_level--;
304     stream << get_indent() << "};\n";
305     indent_level--;
306     stream << get_indent() << "}\n";
307 
308     indent_level--;
309     stream << get_indent() << "};\n";
310     stream << "\n";
311 }
312 
emit_inputs_struct()313 void StubEmitter::emit_inputs_struct() {
314     struct InInfo {
315         std::string c_type;
316         std::string name;
317     };
318     std::vector<InInfo> in_info;
319     for (auto input : inputs) {
320         std::string c_type = input->get_c_type();
321         if (input->is_array()) {
322             c_type = "std::vector<" + c_type + ">";
323         }
324         in_info.push_back({c_type, input->name()});
325     }
326 
327     const std::string name = "Inputs";
328     stream << get_indent() << "struct " << name << " final {\n";
329     indent_level++;
330     for (auto in : in_info) {
331         stream << get_indent() << in.c_type << " " << in.name << ";\n";
332     }
333     stream << "\n";
334 
335     stream << get_indent() << name << "() {}\n";
336     stream << "\n";
337     if (!in_info.empty()) {
338         stream << get_indent() << name << "(\n";
339         indent_level++;
340         std::string comma = "";
341         for (auto in : in_info) {
342             stream << get_indent() << comma << "const " << in.c_type << "& " << in.name << "\n";
343             comma = ", ";
344         }
345         indent_level--;
346         stream << get_indent() << ") : \n";
347         indent_level++;
348         comma = "";
349         for (auto in : in_info) {
350             stream << get_indent() << comma << in.name << "(" << in.name << ")\n";
351             comma = ", ";
352         }
353         indent_level--;
354         stream << get_indent() << "{\n";
355         stream << get_indent() << "}\n";
356 
357         indent_level--;
358     }
359     stream << get_indent() << "};\n";
360     stream << "\n";
361 }
362 
emit()363 void StubEmitter::emit() {
364     if (outputs.empty()) {
365         // The generator can't support a real stub. Instead, generate an (essentially)
366         // empty .stub.h file, so that build systems like Bazel will still get the output file
367         // they expected. Note that we deliberately don't emit an ifndef header guard,
368         // since we can't reliably assume that the generator_name will be globally unique;
369         // on the other hand, since this file is just a couple of comments, it's
370         // really not an issue if it's included multiple times.
371         stream << "/* MACHINE-GENERATED - DO NOT EDIT */\n";
372         stream << "/* The Generator named " << generator_registered_name << " uses ImageParam or Param, thus cannot have a Stub generated. */\n";
373         return;
374     }
375 
376     struct OutputInfo {
377         std::string name;
378         std::string ctype;
379         std::string getter;
380     };
381     bool all_outputs_are_func = true;
382     std::vector<OutputInfo> out_info;
383     for (auto output : outputs) {
384         std::string c_type = output->get_c_type();
385         std::string getter;
386         const bool is_func = (c_type == "Func");
387         if (output->is_array()) {
388             getter = is_func ? "get_array_output" : "get_array_output_buffer<" + c_type + ">";
389         } else {
390             getter = is_func ? "get_output" : "get_output_buffer<" + c_type + ">";
391         }
392         out_info.push_back({output->name(),
393                             output->is_array() ? "std::vector<" + c_type + ">" : c_type,
394                             getter + "(\"" + output->name() + "\")"});
395         if (c_type != "Func") {
396             all_outputs_are_func = false;
397         }
398     }
399 
400     std::ostringstream guard;
401     guard << "HALIDE_STUB";
402     for (const auto &ns : namespaces) {
403         guard << "_" << ns;
404     }
405     guard << "_" << class_name;
406 
407     stream << get_indent() << "#ifndef " << guard.str() << "\n";
408     stream << get_indent() << "#define " << guard.str() << "\n";
409     stream << "\n";
410 
411     stream << get_indent() << "/* MACHINE-GENERATED - DO NOT EDIT */\n";
412     stream << "\n";
413 
414     stream << get_indent() << "#include <cassert>\n";
415     stream << get_indent() << "#include <map>\n";
416     stream << get_indent() << "#include <memory>\n";
417     stream << get_indent() << "#include <string>\n";
418     stream << get_indent() << "#include <utility>\n";
419     stream << get_indent() << "#include <vector>\n";
420     stream << "\n";
421     stream << get_indent() << "#include \"Halide.h\"\n";
422     stream << "\n";
423 
424     stream << "namespace halide_register_generator {\n";
425     stream << "namespace " << generator_registered_name << "_ns {\n";
426     stream << "extern std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext& context);\n";
427     stream << "}  // namespace halide_register_generator\n";
428     stream << "}  // namespace " << generator_registered_name << "\n";
429     stream << "\n";
430 
431     for (const auto &ns : namespaces) {
432         stream << get_indent() << "namespace " << ns << " {\n";
433     }
434     stream << "\n";
435 
436     for (auto *p : generator_params) {
437         std::string decl = p->get_type_decls();
438         if (decl.empty()) continue;
439         stream << decl << "\n";
440     }
441 
442     stream << get_indent() << "class " << class_name << " final : public Halide::NamesInterface {\n";
443     stream << get_indent() << "public:\n";
444     indent_level++;
445 
446     emit_inputs_struct();
447     emit_generator_params_struct();
448 
449     stream << get_indent() << "struct Outputs final {\n";
450     indent_level++;
451     stream << get_indent() << "// Outputs\n";
452     for (const auto &out : out_info) {
453         stream << get_indent() << out.ctype << " " << out.name << ";\n";
454     }
455 
456     stream << "\n";
457     stream << get_indent() << "// The Target used\n";
458     stream << get_indent() << "Target target;\n";
459 
460     if (out_info.size() == 1) {
461         stream << "\n";
462         if (all_outputs_are_func) {
463             std::string name = out_info.at(0).name;
464             auto output = outputs[0];
465             if (output->is_array()) {
466                 stream << get_indent() << "operator std::vector<Halide::Func>() const {\n";
467                 indent_level++;
468                 stream << get_indent() << "return " << name << ";\n";
469                 indent_level--;
470                 stream << get_indent() << "}\n";
471 
472                 stream << get_indent() << "Halide::Func operator[](size_t i) const {\n";
473                 indent_level++;
474                 stream << get_indent() << "return " << name << "[i];\n";
475                 indent_level--;
476                 stream << get_indent() << "}\n";
477 
478                 stream << get_indent() << "Halide::Func at(size_t i) const {\n";
479                 indent_level++;
480                 stream << get_indent() << "return " << name << ".at(i);\n";
481                 indent_level--;
482                 stream << get_indent() << "}\n";
483 
484                 stream << get_indent() << "// operator operator()() overloads omitted because the sole Output is array-of-Func.\n";
485             } else {
486                 // If there is exactly one output, add overloads
487                 // for operator Func and operator().
488                 stream << get_indent() << "operator Halide::Func() const {\n";
489                 indent_level++;
490                 stream << get_indent() << "return " << name << ";\n";
491                 indent_level--;
492                 stream << get_indent() << "}\n";
493 
494                 stream << "\n";
495                 stream << get_indent() << "template <typename... Args>\n";
496                 stream << get_indent() << "Halide::FuncRef operator()(Args&&... args) const {\n";
497                 indent_level++;
498                 stream << get_indent() << "return " << name << "(std::forward<Args>(args)...);\n";
499                 indent_level--;
500                 stream << get_indent() << "}\n";
501 
502                 stream << "\n";
503                 stream << get_indent() << "template <typename ExprOrVar>\n";
504                 stream << get_indent() << "Halide::FuncRef operator()(std::vector<ExprOrVar> args) const {\n";
505                 indent_level++;
506                 stream << get_indent() << "return " << name << "()(args);\n";
507                 indent_level--;
508                 stream << get_indent() << "}\n";
509             }
510         } else {
511             stream << get_indent() << "// operator Func() and operator()() overloads omitted because the sole Output is not Func.\n";
512         }
513     }
514 
515     stream << "\n";
516     if (all_outputs_are_func) {
517         stream << get_indent() << "Halide::Pipeline get_pipeline() const {\n";
518         indent_level++;
519         stream << get_indent() << "return Halide::Pipeline(std::vector<Halide::Func>{\n";
520         indent_level++;
521         int commas = (int)out_info.size() - 1;
522         for (const auto &out : out_info) {
523             stream << get_indent() << out.name << (commas-- ? "," : "") << "\n";
524         }
525         indent_level--;
526         stream << get_indent() << "});\n";
527         indent_level--;
528         stream << get_indent() << "}\n";
529 
530         stream << "\n";
531         stream << get_indent() << "Halide::Realization realize(std::vector<int32_t> sizes) {\n";
532         indent_level++;
533         stream << get_indent() << "return get_pipeline().realize(sizes, target);\n";
534         indent_level--;
535         stream << get_indent() << "}\n";
536 
537         stream << "\n";
538         stream << get_indent() << "template <typename... Args, typename std::enable_if<Halide::Internal::NoRealizations<Args...>::value>::type * = nullptr>\n";
539         stream << get_indent() << "Halide::Realization realize(Args&&... args) {\n";
540         indent_level++;
541         stream << get_indent() << "return get_pipeline().realize(std::forward<Args>(args)..., target);\n";
542         indent_level--;
543         stream << get_indent() << "}\n";
544 
545         stream << "\n";
546         stream << get_indent() << "void realize(Halide::Realization r) {\n";
547         indent_level++;
548         stream << get_indent() << "get_pipeline().realize(r, target);\n";
549         indent_level--;
550         stream << get_indent() << "}\n";
551     } else {
552         stream << get_indent() << "// get_pipeline() and realize() overloads omitted because some Outputs are not Func.\n";
553     }
554 
555     indent_level--;
556     stream << get_indent() << "};\n";
557     stream << "\n";
558 
559     stream << get_indent() << "HALIDE_NO_USER_CODE_INLINE static Outputs generate(\n";
560     indent_level++;
561     stream << get_indent() << "const GeneratorContext& context,\n";
562     stream << get_indent() << "const Inputs& inputs,\n";
563     stream << get_indent() << "const GeneratorParams& generator_params = GeneratorParams()\n";
564     indent_level--;
565     stream << get_indent() << ")\n";
566     stream << get_indent() << "{\n";
567     indent_level++;
568     stream << get_indent() << "using Stub = Halide::Internal::GeneratorStub;\n";
569     stream << get_indent() << "Stub stub(\n";
570     indent_level++;
571     stream << get_indent() << "context,\n";
572     stream << get_indent() << "halide_register_generator::" << generator_registered_name << "_ns::factory,\n";
573     stream << get_indent() << "generator_params.to_generator_params_map(),\n";
574     stream << get_indent() << "{\n";
575     indent_level++;
576     for (size_t i = 0; i < inputs.size(); ++i) {
577         stream << get_indent() << "Stub::to_stub_input_vector(inputs." << inputs[i]->name() << ")";
578         stream << ",\n";
579     }
580     indent_level--;
581     stream << get_indent() << "}\n";
582     indent_level--;
583     stream << get_indent() << ");\n";
584 
585     stream << get_indent() << "return {\n";
586     indent_level++;
587     for (const auto &out : out_info) {
588         stream << get_indent() << "stub." << out.getter << ",\n";
589     }
590     stream << get_indent() << "stub.generator->get_target()\n";
591     indent_level--;
592     stream << get_indent() << "};\n";
593     indent_level--;
594     stream << get_indent() << "}\n";
595     stream << "\n";
596 
597     stream << get_indent() << "// overload to allow GeneratorContext-pointer\n";
598     stream << get_indent() << "inline static Outputs generate(\n";
599     indent_level++;
600     stream << get_indent() << "const GeneratorContext* context,\n";
601     stream << get_indent() << "const Inputs& inputs,\n";
602     stream << get_indent() << "const GeneratorParams& generator_params = GeneratorParams()\n";
603     indent_level--;
604     stream << get_indent() << ")\n";
605     stream << get_indent() << "{\n";
606     indent_level++;
607     stream << get_indent() << "return generate(*context, inputs, generator_params);\n";
608     indent_level--;
609     stream << get_indent() << "}\n";
610     stream << "\n";
611 
612     stream << get_indent() << "// overload to allow Target instead of GeneratorContext.\n";
613     stream << get_indent() << "inline static Outputs generate(\n";
614     indent_level++;
615     stream << get_indent() << "const Target& target,\n";
616     stream << get_indent() << "const Inputs& inputs,\n";
617     stream << get_indent() << "const GeneratorParams& generator_params = GeneratorParams()\n";
618     indent_level--;
619     stream << get_indent() << ")\n";
620     stream << get_indent() << "{\n";
621     indent_level++;
622     stream << get_indent() << "return generate(Halide::GeneratorContext(target), inputs, generator_params);\n";
623     indent_level--;
624     stream << get_indent() << "}\n";
625     stream << "\n";
626 
627     stream << get_indent() << class_name << "() = delete;\n";
628 
629     indent_level--;
630     stream << get_indent() << "};\n";
631     stream << "\n";
632 
633     for (int i = (int)namespaces.size() - 1; i >= 0; --i) {
634         stream << get_indent() << "}  // namespace " << namespaces[i] << "\n";
635     }
636     stream << "\n";
637 
638     stream << get_indent() << "#endif  // " << guard.str() << "\n";
639 }
640 
// Construct a stub by invoking generator_factory with `context` and storing
// the result in `generator`. No generator params or inputs are applied here.
GeneratorStub::GeneratorStub(const GeneratorContext &context,
                             const GeneratorFactory &generator_factory)
    : generator(generator_factory(context)) {
}
645 
// Construct a stub via the two-argument constructor, then immediately call
// generate() with the given generator params and inputs. The value returned
// by generate() is discarded.
GeneratorStub::GeneratorStub(const GeneratorContext &context,
                             const GeneratorFactory &generator_factory,
                             const GeneratorParamsMap &generator_params,
                             const std::vector<std::vector<Internal::StubInput>> &inputs)
    : GeneratorStub(context, generator_factory) {
    generate(generator_params, inputs);
}
653 
654 // Return a vector of all Outputs of this Generator; non-array outputs are returned
655 // as a vector-of-size-1. This method is primarily useful for code that needs
656 // to iterate through the outputs of unknown, arbitrary Generators (e.g.,
657 // the Python bindings).
generate(const GeneratorParamsMap & generator_params,const std::vector<std::vector<Internal::StubInput>> & inputs)658 std::vector<std::vector<Func>> GeneratorStub::generate(const GeneratorParamsMap &generator_params,
659                                                        const std::vector<std::vector<Internal::StubInput>> &inputs) {
660     generator->set_generator_param_values(generator_params);
661     generator->call_configure();
662     generator->set_inputs_vector(inputs);
663     Pipeline p = generator->build_pipeline();
664 
665     std::vector<std::vector<Func>> v;
666     GeneratorParamInfo &pi = generator->param_info();
667     if (!pi.outputs().empty()) {
668         for (auto *output : pi.outputs()) {
669             const std::string &name = output->name();
670             if (output->is_array()) {
671                 v.push_back(get_array_output(name));
672             } else {
673                 v.push_back(std::vector<Func>{get_output(name)});
674             }
675         }
676     } else {
677         // Generators with build() method can't have Output<>, hence can't have array outputs
678         for (auto output : p.outputs()) {
679             v.push_back(std::vector<Func>{output});
680         }
681     }
682     return v;
683 }
684 
get_names() const685 GeneratorStub::Names GeneratorStub::get_names() const {
686     auto &pi = generator->param_info();
687     Names names;
688     for (auto *o : pi.generator_params()) {
689         names.generator_params.push_back(o->name);
690     }
691     for (auto *o : pi.inputs()) {
692         names.inputs.push_back(o->name());
693     }
694     for (auto *o : pi.outputs()) {
695         names.outputs.push_back(o->name());
696     }
697     return names;
698 }
699 
get_halide_type_enum_map()700 const std::map<std::string, Type> &get_halide_type_enum_map() {
701     static const std::map<std::string, Type> halide_type_enum_map{
702         {"bool", Bool()},
703         {"int8", Int(8)},
704         {"int16", Int(16)},
705         {"int32", Int(32)},
706         {"uint8", UInt(8)},
707         {"uint16", UInt(16)},
708         {"uint32", UInt(32)},
709         {"float16", Float(16)},
710         {"float32", Float(32)},
711         {"float64", Float(64)}};
712     return halide_type_enum_map;
713 }
714 
halide_type_to_c_source(const Type & t)715 std::string halide_type_to_c_source(const Type &t) {
716     static const std::map<halide_type_code_t, std::string> m = {
717         {halide_type_int, "Int"},
718         {halide_type_uint, "UInt"},
719         {halide_type_float, "Float"},
720         {halide_type_handle, "Handle"},
721     };
722     std::ostringstream oss;
723     oss << "Halide::" << m.at(t.code()) << "(" << t.bits() << +")";
724     return oss.str();
725 }
726 
halide_type_to_c_type(const Type & t)727 std::string halide_type_to_c_type(const Type &t) {
728     auto encode = [](const Type &t) -> int { return t.code() << 16 | t.bits(); };
729     static const std::map<int, std::string> m = {
730         {encode(Int(8)), "int8_t"},
731         {encode(Int(16)), "int16_t"},
732         {encode(Int(32)), "int32_t"},
733         {encode(Int(64)), "int64_t"},
734         {encode(UInt(1)), "bool"},
735         {encode(UInt(8)), "uint8_t"},
736         {encode(UInt(16)), "uint16_t"},
737         {encode(UInt(32)), "uint32_t"},
738         {encode(UInt(64)), "uint64_t"},
739         {encode(BFloat(16)), "uint16_t"},  // TODO: see Issues #3709, #3967
740         {encode(Float(16)), "uint16_t"},   // TODO: see Issues #3709, #3967
741         {encode(Float(32)), "float"},
742         {encode(Float(64)), "double"},
743         {encode(Handle(64)), "void*"}};
744     internal_assert(m.count(encode(t))) << t << " " << encode(t);
745     return m.at(encode(t));
746 }
747 
generate_filter_main_inner(int argc,char ** argv,std::ostream & cerr)748 int generate_filter_main_inner(int argc, char **argv, std::ostream &cerr) {
749     const char kUsage[] =
750         "gengen \n"
751         "  [-g GENERATOR_NAME] [-f FUNCTION_NAME] [-o OUTPUT_DIR] [-r RUNTIME_NAME] [-d 1|0]\n"
752         "  [-e EMIT_OPTIONS] [-n FILE_BASE_NAME] [-p PLUGIN_NAME] [-s AUTOSCHEDULER_NAME]\n"
753         "       target=target-string[,target-string...] [generator_arg=value [...]]\n"
754         "\n"
755         " -d  Build a module that is suitable for using for gradient descent calculationn\n"
756         "     in TensorFlow or PyTorch. See Generator::build_gradient_module() documentation.\n"
757         "\n"
758         " -e  A comma separated list of files to emit. Accepted values are:\n"
759         "     [assembly, bitcode, c_header, c_source, cpp_stub, featurization,\n"
760         "      llvm_assembly, object, python_extension, pytorch_wrapper, registration,\n"
761         "      schedule, static_library, stmt, stmt_html, compiler_log].\n"
762         "     If omitted, default value is [c_header, static_library, registration].\n"
763         "\n"
764         " -p  A comma-separated list of shared libraries that will be loaded before the\n"
765         "     generator is run. Useful for custom auto-schedulers. The generator must\n"
766         "     either be linked against a shared libHalide or compiled with -rdynamic\n"
767         "     so that references in the shared library to libHalide can resolve.\n"
768         "     (Note that this does not change the default autoscheduler; use the -s flag\n"
769         "     to set that value.)"
770         "\n"
771         " -r   The name of a standalone runtime to generate. Only honors EMIT_OPTIONS 'o'\n"
772         "     and 'static_library'. When multiple targets are specified, it picks a\n"
773         "     runtime that is compatible with all of the targets, or fails if it cannot\n"
774         "     find one. Flags across all of the targets that do not affect runtime code\n"
775         "     generation, such as `no_asserts` and `no_runtime`, are ignored.\n"
776         "\n"
777         " -s  The name of an autoscheduler to set as the default.\n";
778 
779     std::map<std::string, std::string> flags_info = {
780         {"-d", "0"},
781         {"-e", ""},
782         {"-f", ""},
783         {"-g", ""},
784         {"-n", ""},
785         {"-o", ""},
786         {"-p", ""},
787         {"-r", ""},
788         {"-s", ""},
789     };
790     GeneratorParamsMap generator_args;
791 
792     for (int i = 1; i < argc; ++i) {
793         if (argv[i][0] != '-') {
794             std::vector<std::string> v = split_string(argv[i], "=");
795             if (v.size() != 2 || v[0].empty() || v[1].empty()) {
796                 cerr << kUsage;
797                 return 1;
798             }
799             generator_args[v[0]] = v[1];
800             continue;
801         }
802         auto it = flags_info.find(argv[i]);
803         if (it != flags_info.end()) {
804             if (i + 1 >= argc) {
805                 cerr << kUsage;
806                 return 1;
807             }
808             it->second = argv[i + 1];
809             ++i;
810             continue;
811         }
812         cerr << "Unknown flag: " << argv[i] << "\n";
813         cerr << kUsage;
814         return 1;
815     }
816 
817     // It's possible that in the future loaded plugins might change
818     // how arguments are parsed, so we handle those first.
819     for (const auto &lib : split_string(flags_info["-p"], ",")) {
820         if (!lib.empty()) {
821             load_plugin(lib);
822         }
823     }
824 
825     if (flags_info["-d"] != "1" && flags_info["-d"] != "0") {
826         cerr << "-d must be 0 or 1\n";
827         cerr << kUsage;
828         return 1;
829     }
830     const int build_gradient_module = flags_info["-d"] == "1";
831 
832     std::string autoscheduler_name = flags_info["-s"];
833     if (!autoscheduler_name.empty()) {
834         Pipeline::set_default_autoscheduler_name(autoscheduler_name);
835     }
836 
837     std::string runtime_name = flags_info["-r"];
838 
839     std::vector<std::string> generator_names = GeneratorRegistry::enumerate();
840     if (generator_names.empty() && runtime_name.empty()) {
841         cerr << "No generators have been registered and not compiling a standalone runtime\n";
842         cerr << kUsage;
843         return 1;
844     }
845 
846     std::string generator_name = flags_info["-g"];
847     if (generator_name.empty() && runtime_name.empty()) {
848         // Require either -g or -r to be specified:
849         // no longer infer the name when only one Generator is registered
850         cerr << "Either -g <name> or -r must be specified; available Generators are:\n";
851         if (!generator_names.empty()) {
852             for (const auto &name : generator_names) {
853                 cerr << "    " << name << "\n";
854             }
855         } else {
856             cerr << "    <none>\n";
857         }
858         return 1;
859     }
860 
861     std::string function_name = flags_info["-f"];
862     if (function_name.empty()) {
863         // If -f isn't specified, assume function name = generator name.
864         function_name = generator_name;
865     }
866     std::string output_dir = flags_info["-o"];
867     if (output_dir.empty()) {
868         cerr << "-o must always be specified.\n";
869         cerr << kUsage;
870         return 1;
871     }
872 
873     // It's ok to omit "target=" if we are generating *only* a cpp_stub
874     const std::vector<std::string> emit_flags = split_string(flags_info["-e"], ",");
875     const bool stub_only = (emit_flags.size() == 1 && emit_flags[0] == "cpp_stub");
876     if (!stub_only) {
877         if (generator_args.find("target") == generator_args.end()) {
878             cerr << "Target missing\n";
879             cerr << kUsage;
880             return 1;
881         }
882     }
883 
884     // it's OK for file_base_name to be empty: filename will be based on function name
885     std::string file_base_name = flags_info["-n"];
886 
887     auto target_strings = split_string(generator_args["target"].string_value, ",");
888     std::vector<Target> targets;
889     for (const auto &s : target_strings) {
890         targets.emplace_back(s);
891     }
892 
893     // extensions won't vary across multitarget output
894     std::map<Output, const OutputInfo> output_info = get_output_info(targets[0]);
895 
896     std::set<Output> outputs;
897     if (emit_flags.empty() || (emit_flags.size() == 1 && emit_flags[0].empty())) {
898         // If omitted or empty, assume .a and .h and registration.cpp
899         outputs.insert(Output::c_header);
900         outputs.insert(Output::registration);
901         outputs.insert(Output::static_library);
902     } else {
903         // Build a reverse lookup table. Allow some legacy aliases on the command line,
904         // to allow legacy build systems to work more easily.
905         std::map<std::string, Output> output_name_to_enum = {
906             {"cpp", Output::c_source},
907             {"h", Output::c_header},
908             {"html", Output::stmt_html},
909             {"o", Output::object},
910             {"py.c", Output::python_extension},
911         };
912         for (const auto &it : output_info) {
913             output_name_to_enum[it.second.name] = it.first;
914         }
915 
916         for (std::string opt : emit_flags) {
917             auto it = output_name_to_enum.find(opt);
918             if (it == output_name_to_enum.end()) {
919                 cerr << "Unrecognized emit option: " << opt << " is not one of [";
920                 auto end = output_info.cend();
921                 auto last = std::prev(end);
922                 for (auto iter = output_info.cbegin(); iter != end; ++iter) {
923                     cerr << iter->second.name;
924                     if (iter != last) {
925                         cerr << " ";
926                     }
927                 }
928                 cerr << "], ignoring.\n";
929                 cerr << kUsage;
930                 return 1;
931             }
932             outputs.insert(it->second);
933         }
934     }
935 
936     // Allow quick-n-dirty use of compiler logging via HL_DEBUG_COMPILER_LOGGER env var
937     const bool do_compiler_logging = outputs.count(Output::compiler_log) ||
938                                      (get_env_variable("HL_DEBUG_COMPILER_LOGGER") == "1");
939 
940     const bool obfuscate_compiler_logging = get_env_variable("HL_OBFUSCATE_COMPILER_LOGGER") == "1";
941 
942     const CompilerLoggerFactory no_compiler_logger_factory =
943         [](const std::string &, const Target &) -> std::unique_ptr<CompilerLogger> {
944         return nullptr;
945     };
946 
947     const CompilerLoggerFactory json_compiler_logger_factory =
948         [&](const std::string &function_name, const Target &target) -> std::unique_ptr<CompilerLogger> {
949         // rebuild generator_args from the map so that they are always canonical
950         std::string generator_args_string;
951         std::string sep;
952         for (const auto &it : generator_args) {
953             if (it.first == "target") continue;
954             std::string quote = it.second.string_value.find(" ") != std::string::npos ? "\\\"" : "";
955             generator_args_string += sep + it.first + "=" + quote + it.second.string_value + quote;
956             sep = " ";
957         }
958         std::unique_ptr<JSONCompilerLogger> t(new JSONCompilerLogger(
959             obfuscate_compiler_logging ? "" : generator_name,
960             obfuscate_compiler_logging ? "" : function_name,
961             obfuscate_compiler_logging ? "" : autoscheduler_name,
962             obfuscate_compiler_logging ? Target() : target,
963             obfuscate_compiler_logging ? "" : generator_args_string,
964             obfuscate_compiler_logging));
965         return t;
966     };
967 
968     const CompilerLoggerFactory compiler_logger_factory = do_compiler_logging ?
969                                                               json_compiler_logger_factory :
970                                                               no_compiler_logger_factory;
971 
972     if (!runtime_name.empty()) {
973         std::string base_path = compute_base_path(output_dir, runtime_name, "");
974 
975         Target gcd_target = targets[0];
976         for (size_t i = 1; i < targets.size(); i++) {
977             if (!gcd_target.get_runtime_compatible_target(targets[i], gcd_target)) {
978                 user_error << "Failed to find compatible runtime target for "
979                            << gcd_target.to_string()
980                            << " and "
981                            << targets[i].to_string() << "\n";
982             }
983         }
984 
985         if (targets.size() > 1) {
986             debug(1) << "Building runtime for computed target: " << gcd_target.to_string() << "\n";
987         }
988 
989         auto output_files = compute_output_files(gcd_target, base_path, outputs);
990         // Runtime doesn't get to participate in the CompilerLogger party
991         compile_standalone_runtime(output_files, gcd_target);
992     }
993 
994     if (!generator_name.empty()) {
995         std::string base_path = compute_base_path(output_dir, function_name, file_base_name);
996         debug(1) << "Generator " << generator_name << " has base_path " << base_path << "\n";
997         if (outputs.count(Output::cpp_stub)) {
998             // When generating cpp_stub, we ignore all generator args passed in, and supply a fake Target.
999             // (CompilerLogger is never enabled for cpp_stub, for now anyway.)
1000             auto gen = GeneratorRegistry::create(generator_name, GeneratorContext(Target()));
1001             auto stub_file_path = base_path + output_info[Output::cpp_stub].extension;
1002             gen->emit_cpp_stub(stub_file_path);
1003         }
1004 
1005         // Don't bother with this if we're just emitting a cpp_stub.
1006         if (!stub_only) {
1007             auto output_files = compute_output_files(targets[0], base_path, outputs);
1008             auto module_factory = [&generator_name, &generator_args, build_gradient_module](const std::string &name, const Target &target) -> Module {
1009                 auto sub_generator_args = generator_args;
1010                 sub_generator_args.erase("target");
1011                 // Must re-create each time since each instance will have a different Target.
1012                 auto gen = GeneratorRegistry::create(generator_name, GeneratorContext(target));
1013                 gen->set_generator_param_values(sub_generator_args);
1014                 return build_gradient_module ? gen->build_gradient_module(name) : gen->build_module(name);
1015             };
1016             compile_multitarget(function_name, output_files, targets, target_strings, module_factory, compiler_logger_factory);
1017         }
1018     }
1019 
1020     return 0;
1021 }
1022 
1023 #ifdef HALIDE_WITH_EXCEPTIONS
generate_filter_main(int argc,char ** argv,std::ostream & cerr)1024 int generate_filter_main(int argc, char **argv, std::ostream &cerr) {
1025     try {
1026         return generate_filter_main_inner(argc, argv, cerr);
1027     } catch (std::runtime_error &err) {
1028         cerr << "Unhandled exception: " << err.what() << "\n";
1029         return -1;
1030     }
1031 }
1032 #else
generate_filter_main(int argc,char ** argv,std::ostream & cerr)1033 int generate_filter_main(int argc, char **argv, std::ostream &cerr) {
1034     return generate_filter_main_inner(argc, argv, cerr);
1035 }
1036 #endif
1037 
// Stores the param's name and registers this object with the
// ObjectInstanceRegistry under the GeneratorParam kind, with a size of 0 and
// no introspection helper. The destructor performs the matching unregister.
GeneratorParamBase::GeneratorParamBase(const std::string &name)
    : name(name) {
    ObjectInstanceRegistry::register_instance(this, 0, ObjectInstanceRegistry::GeneratorParam,
                                              this, nullptr);
}
1043 
// Removes this object from the ObjectInstanceRegistry (it was added by the
// constructor).
GeneratorParamBase::~GeneratorParamBase() {
    ObjectInstanceRegistry::unregister_instance(this);
}
1047 
check_value_readable() const1048 void GeneratorParamBase::check_value_readable() const {
1049     // These are always readable.
1050     if (name == "target") return;
1051     if (name == "auto_schedule") return;
1052     if (name == "machine_params") return;
1053     user_assert(generator && generator->phase >= GeneratorBase::ConfigureCalled)
1054         << "The GeneratorParam \"" << name << "\" cannot be read before build() or configure()/generate() is called.\n";
1055 }
1056 
check_value_writable() const1057 void GeneratorParamBase::check_value_writable() const {
1058     // Allow writing when no Generator is set, to avoid having to special-case ctor initing code
1059     if (!generator) return;
1060     user_assert(generator->phase < GeneratorBase::GenerateCalled) << "The GeneratorParam \"" << name << "\" cannot be written after build() or generate() is called.\n";
1061 }
1062 
// Raises a user_error stating that this GeneratorParam cannot be set with a
// value of the given type.
// type: human-readable name of the rejected value type, used in the message.
void GeneratorParamBase::fail_wrong_type(const char *type) {
    user_error << "The GeneratorParam \"" << name << "\" cannot be set with a value of type " << type << ".\n";
}
1066 
1067 /* static */
get_registry()1068 GeneratorRegistry &GeneratorRegistry::get_registry() {
1069     static GeneratorRegistry *registry = new GeneratorRegistry;
1070     return *registry;
1071 }
1072 
1073 /* static */
register_factory(const std::string & name,GeneratorFactory generator_factory)1074 void GeneratorRegistry::register_factory(const std::string &name,
1075                                          GeneratorFactory generator_factory) {
1076     user_assert(is_valid_name(name)) << "Invalid Generator name: " << name;
1077     GeneratorRegistry &registry = get_registry();
1078     std::lock_guard<std::mutex> lock(registry.mutex);
1079     internal_assert(registry.factories.find(name) == registry.factories.end())
1080         << "Duplicate Generator name: " << name;
1081     registry.factories[name] = std::move(generator_factory);
1082 }
1083 
1084 /* static */
unregister_factory(const std::string & name)1085 void GeneratorRegistry::unregister_factory(const std::string &name) {
1086     GeneratorRegistry &registry = get_registry();
1087     std::lock_guard<std::mutex> lock(registry.mutex);
1088     internal_assert(registry.factories.find(name) != registry.factories.end())
1089         << "Generator not found: " << name;
1090     registry.factories.erase(name);
1091 }
1092 
1093 /* static */
create(const std::string & name,const GeneratorContext & context)1094 std::unique_ptr<GeneratorBase> GeneratorRegistry::create(const std::string &name,
1095                                                          const GeneratorContext &context) {
1096     GeneratorRegistry &registry = get_registry();
1097     std::lock_guard<std::mutex> lock(registry.mutex);
1098     auto it = registry.factories.find(name);
1099     if (it == registry.factories.end()) {
1100         std::ostringstream o;
1101         o << "Generator not found: " << name << "\n";
1102         o << "Did you mean:\n";
1103         for (const auto &n : registry.factories) {
1104             o << "    " << n.first << "\n";
1105         }
1106         user_error << o.str();
1107     }
1108     std::unique_ptr<GeneratorBase> g = it->second(context);
1109     internal_assert(g != nullptr);
1110     return g;
1111 }
1112 
1113 /* static */
enumerate()1114 std::vector<std::string> GeneratorRegistry::enumerate() {
1115     GeneratorRegistry &registry = get_registry();
1116     std::lock_guard<std::mutex> lock(registry.mutex);
1117     std::vector<std::string> result;
1118     for (const auto &i : registry.factories) {
1119         result.push_back(i.first);
1120     }
1121     return result;
1122 }
1123 
// Stores `size` and registers this object with the ObjectInstanceRegistry
// under the Generator kind, forwarding `size` and `introspection_helper`.
// The same `size` is later passed to GeneratorParamInfo by
// init_from_context(). The destructor performs the matching unregister.
GeneratorBase::GeneratorBase(size_t size, const void *introspection_helper)
    : size(size) {
    ObjectInstanceRegistry::register_instance(this, size, ObjectInstanceRegistry::Generator, this, introspection_helper);
}
1128 
// Removes this object from the ObjectInstanceRegistry (it was added by the
// constructor).
GeneratorBase::~GeneratorBase() {
    ObjectInstanceRegistry::unregister_instance(this);
}
1132 
GeneratorParamInfo(GeneratorBase * generator,const size_t size)1133 GeneratorParamInfo::GeneratorParamInfo(GeneratorBase *generator, const size_t size) {
1134     std::vector<void *> vf = ObjectInstanceRegistry::instances_in_range(
1135         generator, size, ObjectInstanceRegistry::FilterParam);
1136     user_assert(vf.empty()) << "ImageParam and Param<> are no longer allowed in Generators; use Input<> instead.";
1137 
1138     const auto add_synthetic_params = [this, generator](GIOBase *gio) {
1139         const std::string &n = gio->name();
1140         const std::string &gn = generator->generator_registered_name;
1141 
1142         if (gio->kind() != IOKind::Scalar) {
1143             owned_synthetic_params.push_back(GeneratorParam_Synthetic<Type>::make(generator, gn, n + ".type", *gio, SyntheticParamType::Type, gio->types_defined()));
1144             filter_generator_params.push_back(owned_synthetic_params.back().get());
1145 
1146             owned_synthetic_params.push_back(GeneratorParam_Synthetic<int>::make(generator, gn, n + ".dim", *gio, SyntheticParamType::Dim, gio->dims_defined()));
1147             filter_generator_params.push_back(owned_synthetic_params.back().get());
1148         }
1149         if (gio->is_array()) {
1150             owned_synthetic_params.push_back(GeneratorParam_Synthetic<size_t>::make(generator, gn, n + ".size", *gio, SyntheticParamType::ArraySize, gio->array_size_defined()));
1151             filter_generator_params.push_back(owned_synthetic_params.back().get());
1152         }
1153     };
1154 
1155     std::vector<void *> vi = ObjectInstanceRegistry::instances_in_range(
1156         generator, size, ObjectInstanceRegistry::GeneratorInput);
1157     for (auto v : vi) {
1158         auto input = static_cast<Internal::GeneratorInputBase *>(v);
1159         internal_assert(input != nullptr);
1160         user_assert(is_valid_name(input->name())) << "Invalid Input name: (" << input->name() << ")\n";
1161         user_assert(!names.count(input->name())) << "Duplicate Input name: " << input->name();
1162         names.insert(input->name());
1163         internal_assert(input->generator == nullptr || input->generator == generator);
1164         input->generator = generator;
1165         filter_inputs.push_back(input);
1166         add_synthetic_params(input);
1167     }
1168 
1169     std::vector<void *> vo = ObjectInstanceRegistry::instances_in_range(
1170         generator, size, ObjectInstanceRegistry::GeneratorOutput);
1171     for (auto v : vo) {
1172         auto output = static_cast<Internal::GeneratorOutputBase *>(v);
1173         internal_assert(output != nullptr);
1174         user_assert(is_valid_name(output->name())) << "Invalid Output name: (" << output->name() << ")\n";
1175         user_assert(!names.count(output->name())) << "Duplicate Output name: " << output->name();
1176         names.insert(output->name());
1177         internal_assert(output->generator == nullptr || output->generator == generator);
1178         output->generator = generator;
1179         filter_outputs.push_back(output);
1180         add_synthetic_params(output);
1181     }
1182 
1183     std::vector<void *> vg = ObjectInstanceRegistry::instances_in_range(
1184         generator, size, ObjectInstanceRegistry::GeneratorParam);
1185     for (auto v : vg) {
1186         auto param = static_cast<GeneratorParamBase *>(v);
1187         internal_assert(param != nullptr);
1188         user_assert(is_valid_name(param->name)) << "Invalid GeneratorParam name: " << param->name;
1189         user_assert(!names.count(param->name)) << "Duplicate GeneratorParam name: " << param->name;
1190         names.insert(param->name);
1191         internal_assert(param->generator == nullptr || param->generator == generator);
1192         param->generator = generator;
1193         filter_generator_params.push_back(param);
1194     }
1195 
1196     for (auto &g : owned_synthetic_params) {
1197         g->generator = generator;
1198     }
1199 }
1200 
param_info()1201 GeneratorParamInfo &GeneratorBase::param_info() {
1202     internal_assert(param_info_ptr != nullptr);
1203     return *param_info_ptr;
1204 }
1205 
get_output(const std::string & n)1206 Func GeneratorBase::get_output(const std::string &n) {
1207     check_min_phase(GenerateCalled);
1208     auto *output = find_output_by_name(n);
1209     // Call for the side-effect of asserting if the value isn't defined.
1210     (void)output->array_size();
1211     user_assert(!output->is_array() && output->funcs().size() == 1) << "Output " << n << " must be accessed via get_array_output()\n";
1212     Func f = output->funcs().at(0);
1213     user_assert(f.defined()) << "Output " << n << " was not defined.\n";
1214     return f;
1215 }
1216 
get_array_output(const std::string & n)1217 std::vector<Func> GeneratorBase::get_array_output(const std::string &n) {
1218     check_min_phase(GenerateCalled);
1219     auto *output = find_output_by_name(n);
1220     // Call for the side-effect of asserting if the value isn't defined.
1221     (void)output->array_size();
1222     for (const auto &f : output->funcs()) {
1223         user_assert(f.defined()) << "Output " << n << " was not fully defined.\n";
1224     }
1225     return output->funcs();
1226 }
1227 
1228 // Find output by name. If not found, assert-fail. Never returns null.
find_output_by_name(const std::string & name)1229 GeneratorOutputBase *GeneratorBase::find_output_by_name(const std::string &name) {
1230     // There usually are very few outputs, so a linear search is fine
1231     GeneratorParamInfo &pi = param_info();
1232     for (GeneratorOutputBase *output : pi.outputs()) {
1233         if (output->name() == name) {
1234             return output;
1235         }
1236     }
1237     internal_error << "Output " << name << " not found.";
1238     return nullptr;  // not reached
1239 }
1240 
set_generator_param_values(const GeneratorParamsMap & params)1241 void GeneratorBase::set_generator_param_values(const GeneratorParamsMap &params) {
1242     GeneratorParamInfo &pi = param_info();
1243 
1244     std::unordered_map<std::string, Internal::GeneratorParamBase *> generator_params_by_name;
1245     for (auto *g : pi.generator_params()) {
1246         generator_params_by_name[g->name] = g;
1247     }
1248 
1249     for (auto &key_value : params) {
1250         auto gp = generator_params_by_name.find(key_value.first);
1251         user_assert(gp != generator_params_by_name.end())
1252             << "Generator " << generator_registered_name << " has no GeneratorParam named: " << key_value.first << "\n";
1253         if (gp->second->is_looplevel_param()) {
1254             if (!key_value.second.string_value.empty()) {
1255                 gp->second->set_from_string(key_value.second.string_value);
1256             } else {
1257                 gp->second->set(key_value.second.loop_level);
1258             }
1259         } else {
1260             gp->second->set_from_string(key_value.second.string_value);
1261         }
1262     }
1263 }
1264 
// Copies the context settings via GeneratorContext::init_from_context() and
// then builds this Generator's GeneratorParamInfo. It is an internal error if
// the param info already exists when this is called.
void GeneratorBase::init_from_context(const Halide::GeneratorContext &context) {
    Halide::GeneratorContext::init_from_context(context);
    internal_assert(param_info_ptr == nullptr);
    // pre-emptively build our param_info now
    param_info_ptr.reset(new GeneratorParamInfo(this, size));
}
1271 
// Records the registered name and the stub name of this Generator.
// registered_name must pass is_valid_name() (user error otherwise). Both
// names must be non-empty, and neither name may have been set before
// (internal errors otherwise).
void GeneratorBase::set_generator_names(const std::string &registered_name, const std::string &stub_name) {
    user_assert(is_valid_name(registered_name)) << "Invalid Generator name: " << registered_name;
    internal_assert(!registered_name.empty() && !stub_name.empty());
    // Both stored names must still be empty, i.e. this is the first call.
    internal_assert(generator_registered_name.empty() && generator_stub_name.empty());
    generator_registered_name = registered_name;
    generator_stub_name = stub_name;
}
1279 
set_inputs_vector(const std::vector<std::vector<StubInput>> & inputs)1280 void GeneratorBase::set_inputs_vector(const std::vector<std::vector<StubInput>> &inputs) {
1281     advance_phase(InputsSet);
1282     internal_assert(!inputs_set) << "set_inputs_vector() must be called at most once per Generator instance.\n";
1283     GeneratorParamInfo &pi = param_info();
1284     user_assert(inputs.size() == pi.inputs().size())
1285         << "Expected exactly " << pi.inputs().size()
1286         << " inputs but got " << inputs.size() << "\n";
1287     for (size_t i = 0; i < pi.inputs().size(); ++i) {
1288         pi.inputs()[i]->set_inputs(inputs[i]);
1289     }
1290     inputs_set = true;
1291 }
1292 
track_parameter_values(bool include_outputs)1293 void GeneratorBase::track_parameter_values(bool include_outputs) {
1294     GeneratorParamInfo &pi = param_info();
1295     for (auto input : pi.inputs()) {
1296         if (input->kind() == IOKind::Buffer) {
1297             internal_assert(!input->parameters_.empty());
1298             for (auto &p : input->parameters_) {
1299                 // This must use p.name(), *not* input->name()
1300                 get_value_tracker()->track_values(p.name(), parameter_constraints(p));
1301             }
1302         }
1303     }
1304     if (include_outputs) {
1305         for (auto output : pi.outputs()) {
1306             if (output->kind() == IOKind::Buffer) {
1307                 internal_assert(!output->funcs().empty());
1308                 for (auto &f : output->funcs()) {
1309                     user_assert(f.defined()) << "Output " << output->name() << " is not fully defined.";
1310                     auto output_buffers = f.output_buffers();
1311                     for (auto &o : output_buffers) {
1312                         Parameter p = o.parameter();
1313                         // This must use p.name(), *not* output->name()
1314                         get_value_tracker()->track_values(p.name(), parameter_constraints(p));
1315                     }
1316                 }
1317             }
1318         }
1319     }
1320 }
1321 
// Raises a user error unless the current phase is expected_phase or a later
// one.
void GeneratorBase::check_min_phase(Phase expected_phase) const {
    user_assert(phase >= expected_phase) << "You may not do this operation at this phase.";
}
1325 
// Raises a user error unless the current phase is exactly expected_phase.
void GeneratorBase::check_exact_phase(Phase expected_phase) const {
    user_assert(phase == expected_phase) << "You may not do this operation at this phase.";
}
1329 
advance_phase(Phase new_phase)1330 void GeneratorBase::advance_phase(Phase new_phase) {
1331     switch (new_phase) {
1332     case Created:
1333         internal_error << "Impossible";
1334         break;
1335     case ConfigureCalled:
1336         internal_assert(phase == Created) << "pase is " << phase;
1337         break;
1338     case InputsSet:
1339         internal_assert(phase == Created || phase == ConfigureCalled);
1340         break;
1341     case GenerateCalled:
1342         // It's OK to advance directly to GenerateCalled.
1343         internal_assert(phase == Created || phase == ConfigureCalled || phase == InputsSet);
1344         break;
1345     case ScheduleCalled:
1346         internal_assert(phase == GenerateCalled);
1347         break;
1348     }
1349     phase = new_phase;
1350 }
1351 
// Advances the phase to ConfigureCalled (which advance_phase() only permits
// from the Created phase).
void GeneratorBase::pre_configure() {
    advance_phase(ConfigureCalled);
}
1355 
// Counterpart of pre_configure(). Currently does nothing.
void GeneratorBase::post_configure() {
}
1358 
pre_generate()1359 void GeneratorBase::pre_generate() {
1360     advance_phase(GenerateCalled);
1361     GeneratorParamInfo &pi = param_info();
1362     user_assert(!pi.outputs().empty()) << "Must use Output<> with generate() method.";
1363     user_assert(get_target() != Target()) << "The Generator target has not been set.";
1364 
1365     if (!inputs_set) {
1366         for (auto *input : pi.inputs()) {
1367             input->init_internals();
1368         }
1369         inputs_set = true;
1370     }
1371     for (auto *output : pi.outputs()) {
1372         output->init_internals();
1373     }
1374     track_parameter_values(false);
1375 }
1376 
// Counterpart of pre_generate(): tracks parameter values again, this time
// including the outputs.
void GeneratorBase::post_generate() {
    track_parameter_values(true);
}
1380 
// Advances the phase to ScheduleCalled (which advance_phase() only permits
// from GenerateCalled) and tracks parameter values, including the outputs.
void GeneratorBase::pre_schedule() {
    advance_phase(ScheduleCalled);
    track_parameter_values(true);
}
1385 
// Counterpart of pre_schedule(): tracks parameter values, including the
// outputs.
void GeneratorBase::post_schedule() {
    track_parameter_values(true);
}
1389 
pre_build()1390 void GeneratorBase::pre_build() {
1391     advance_phase(GenerateCalled);
1392     advance_phase(ScheduleCalled);
1393     GeneratorParamInfo &pi = param_info();
1394     user_assert(pi.outputs().empty()) << "May not use build() method with Output<>.";
1395     if (!inputs_set) {
1396         for (auto *input : pi.inputs()) {
1397             input->init_internals();
1398         }
1399         inputs_set = true;
1400     }
1401     track_parameter_values(false);
1402 }
1403 
// Counterpart of pre_build(): tracks parameter values, including the outputs.
void GeneratorBase::post_build() {
    track_parameter_values(true);
}
1407 
get_pipeline()1408 Pipeline GeneratorBase::get_pipeline() {
1409     check_min_phase(GenerateCalled);
1410     if (!pipeline.defined()) {
1411         GeneratorParamInfo &pi = param_info();
1412         user_assert(!pi.outputs().empty()) << "Must use get_pipeline<> with Output<>.";
1413         std::vector<Func> funcs;
1414         for (auto *output : pi.outputs()) {
1415             for (const auto &f : output->funcs()) {
1416                 user_assert(f.defined()) << "Output \"" << f.name() << "\" was not defined.\n";
1417                 if (output->dims_defined()) {
1418                     user_assert(f.dimensions() == output->dims()) << "Output \"" << f.name()
1419                                                                   << "\" requires dimensions=" << output->dims()
1420                                                                   << " but was defined as dimensions=" << f.dimensions() << ".\n";
1421                 }
1422                 if (output->types_defined()) {
1423                     user_assert((int)f.outputs() == (int)output->types().size()) << "Output \"" << f.name()
1424                                                                                  << "\" requires a Tuple of size " << output->types().size()
1425                                                                                  << " but was defined as Tuple of size " << f.outputs() << ".\n";
1426                     for (size_t i = 0; i < f.output_types().size(); ++i) {
1427                         Type expected = output->types().at(i);
1428                         Type actual = f.output_types()[i];
1429                         user_assert(expected == actual) << "Output \"" << f.name()
1430                                                         << "\" requires type " << expected
1431                                                         << " but was defined as type " << actual << ".\n";
1432                     }
1433                 }
1434                 funcs.push_back(f);
1435             }
1436         }
1437         pipeline = Pipeline(funcs);
1438     }
1439     return pipeline;
1440 }
1441 
build_module(const std::string & function_name,const LinkageType linkage_type)1442 Module GeneratorBase::build_module(const std::string &function_name,
1443                                    const LinkageType linkage_type) {
1444     AutoSchedulerResults auto_schedule_results;
1445     call_configure();
1446     Pipeline pipeline = build_pipeline();
1447     if (get_auto_schedule()) {
1448         auto_schedule_results = pipeline.auto_schedule(get_target(), get_machine_params());
1449     }
1450 
1451     const GeneratorParamInfo &pi = param_info();
1452     std::vector<Argument> filter_arguments;
1453     for (const auto *input : pi.inputs()) {
1454         for (const auto &p : input->parameters_) {
1455             filter_arguments.push_back(to_argument(p, p.is_buffer() ? Expr() : input->get_def_expr()));
1456         }
1457     }
1458 
1459     Module result = pipeline.compile_to_module(filter_arguments, function_name, get_target(), linkage_type);
1460     std::shared_ptr<ExternsMap> externs_map = get_externs_map();
1461     for (const auto &map_entry : *externs_map) {
1462         result.append(map_entry.second);
1463     }
1464 
1465     for (const auto *output : pi.outputs()) {
1466         for (size_t i = 0; i < output->funcs().size(); ++i) {
1467             auto from = output->funcs()[i].name();
1468             auto to = output->array_name(i);
1469             size_t tuple_size = output->types_defined() ? output->types().size() : 1;
1470             for (size_t t = 0; t < tuple_size; ++t) {
1471                 std::string suffix = (tuple_size > 1) ? ("." + std::to_string(t)) : "";
1472                 result.remap_metadata_name(from + suffix, to + suffix);
1473             }
1474         }
1475     }
1476 
1477     result.set_auto_scheduler_results(auto_schedule_results);
1478 
1479     return result;
1480 }
1481 
build_gradient_module(const std::string & function_name)1482 Module GeneratorBase::build_gradient_module(const std::string &function_name) {
1483     constexpr int DBG = 1;
1484 
1485     // I doubt these ever need customizing; if they do, we can make them arguments to this function.
1486     const std::string grad_input_pattern = "_grad_loss_for_$OUT$";
1487     const std::string grad_output_pattern = "_grad_loss_$OUT$_wrt_$IN$";
1488     const LinkageType linkage_type = LinkageType::ExternalPlusMetadata;
1489 
1490     user_assert(!function_name.empty()) << "build_gradient_module(): function_name cannot be empty\n";
1491 
1492     call_configure();
1493     Pipeline original_pipeline = build_pipeline();
1494     std::vector<Func> original_outputs = original_pipeline.outputs();
1495 
1496     // Construct the adjoint pipeline, which has:
1497     // - All the same inputs as the original, in the same order
1498     // - Followed by one grad-input for each original output
1499     // - Followed by one output for each unique pairing of original-output + original-input.
1500 
1501     const GeneratorParamInfo &pi = param_info();
1502 
1503     // Even though propagate_adjoints() supports Funcs-of-Tuples just fine,
1504     // we aren't going to support them here (yet); AFAICT, neither PyTorch nor
1505     // TF support Tensors with Tuples-as-values, so we'd have to split the
1506     // tuples up into separate Halide inputs and outputs anyway; since Generator
1507     // doesn't support Tuple-valued Inputs at all, and Tuple-valued Outputs
1508     // are quite rare, we're going to just fail up front, with the assumption
1509     // that the coder will explicitly adapt their code as needed. (Note that
1510     // support for Tupled outputs could be added with some effort, so if this
1511     // is somehow deemed critical, go for it)
1512     for (const auto *input : pi.inputs()) {
1513         const size_t tuple_size = input->types_defined() ? input->types().size() : 1;
1514         // Note: this should never happen
1515         internal_assert(tuple_size == 1) << "Tuple Inputs are not yet supported by build_gradient_module()";
1516     }
1517     for (const auto *output : pi.outputs()) {
1518         const size_t tuple_size = output->types_defined() ? output->types().size() : 1;
1519         internal_assert(tuple_size == 1) << "Tuple Outputs are not yet supported by build_gradient_module";
1520     }
1521 
1522     std::vector<Argument> gradient_inputs;
1523 
1524     // First: the original inputs. Note that scalar inputs remain scalar,
1525     // rather being promoted into zero-dimensional buffers.
1526     for (const auto *input : pi.inputs()) {
1527         // There can be multiple Funcs/Parameters per input if the input is an Array
1528         internal_assert(input->parameters_.size() == input->funcs_.size());
1529         for (const auto &p : input->parameters_) {
1530             gradient_inputs.push_back(to_argument(p, p.is_buffer() ? Expr() : input->get_def_expr()));
1531             debug(DBG) << "    gradient copied input is: " << gradient_inputs.back().name << "\n";
1532         }
1533     }
1534 
1535     // Next: add a grad-input for each *original* output; these will
1536     // be the same shape as the output (so we should copy estimates from
1537     // those outputs onto these estimates).
1538     // - If an output is an Array, we'll have a separate input for each array element.
1539 
1540     std::vector<ImageParam> d_output_imageparams;
1541     for (const auto *output : pi.outputs()) {
1542         for (size_t i = 0; i < output->funcs().size(); ++i) {
1543             const Func &f = output->funcs()[i];
1544             const std::string output_name = output->array_name(i);
1545             // output_name is something like "funcname_i"
1546             const std::string grad_in_name = replace_all(grad_input_pattern, "$OUT$", output_name);
1547             // TODO(srj): does it make sense for gradient to be a non-float type?
1548             // For now, assume it's always float32 (unless the output is already some float).
1549             const Type grad_in_type = output->type().is_float() ? output->type() : Float(32);
1550             const int grad_in_dimensions = f.dimensions();
1551             const ArgumentEstimates grad_in_estimates = f.output_buffer().parameter().get_argument_estimates();
1552             internal_assert((int)grad_in_estimates.buffer_estimates.size() == grad_in_dimensions);
1553 
1554             ImageParam d_im(grad_in_type, grad_in_dimensions, grad_in_name);
1555             for (int d = 0; d < grad_in_dimensions; d++) {
1556                 d_im.parameter().set_min_constraint_estimate(d, grad_in_estimates.buffer_estimates[i].min);
1557                 d_im.parameter().set_extent_constraint_estimate(d, grad_in_estimates.buffer_estimates[i].extent);
1558             }
1559             d_output_imageparams.push_back(d_im);
1560             gradient_inputs.push_back(to_argument(d_im.parameter(), Expr()));
1561 
1562             debug(DBG) << "    gradient synthesized input is: " << gradient_inputs.back().name << "\n";
1563         }
1564     }
1565 
1566     // Finally: define the output Func(s), one for each unique output/input pair.
1567     // Note that original_outputs.size() != pi.outputs().size() if any outputs are arrays.
1568     internal_assert(original_outputs.size() == d_output_imageparams.size());
1569     std::vector<Func> gradient_outputs;
1570     for (size_t i = 0; i < original_outputs.size(); ++i) {
1571         const Func &original_output = original_outputs.at(i);
1572         const ImageParam &d_output = d_output_imageparams.at(i);
1573         Region bounds;
1574         for (int i = 0; i < d_output.dimensions(); i++) {
1575             bounds.emplace_back(d_output.dim(i).min(), d_output.dim(i).extent());
1576         }
1577         Func adjoint_func = BoundaryConditions::constant_exterior(d_output, make_zero(d_output.type()));
1578         Derivative d = propagate_adjoints(original_output, adjoint_func, bounds);
1579 
1580         const std::string &output_name = original_output.name();
1581         for (const auto *input : pi.inputs()) {
1582             for (size_t i = 0; i < input->funcs_.size(); ++i) {
1583                 const std::string input_name = input->array_name(i);
1584                 const auto &f = input->funcs_[i];
1585                 const auto &p = input->parameters_[i];
1586 
1587                 Func d_f = d(f);
1588 
1589                 std::string grad_out_name = replace_all(replace_all(grad_output_pattern, "$OUT$", output_name), "$IN$", input_name);
1590                 if (!d_f.defined()) {
1591                     grad_out_name = "_dummy" + grad_out_name;
1592                 }
1593 
1594                 Func d_out_wrt_in(grad_out_name);
1595                 if (d_f.defined()) {
1596                     d_out_wrt_in(Halide::_) = d_f(Halide::_);
1597                 } else {
1598                     debug(DBG) << "    No Derivative found for output " << output_name << " wrt input " << input_name << "\n";
1599                     // If there was no Derivative found, don't skip the output;
1600                     // just replace with a dummy Func that is all zeros. This ensures
1601                     // that the signature of the Pipeline we produce is always predictable.
1602                     std::vector<Var> vars;
1603                     for (int i = 0; i < d_output.dimensions(); i++) {
1604                         vars.push_back(Var::implicit(i));
1605                     }
1606                     d_out_wrt_in(vars) = make_zero(d_output.type());
1607                 }
1608 
1609                 d_out_wrt_in.set_estimates(p.get_argument_estimates().buffer_estimates);
1610 
1611                 // Useful for debugging; ordinarily better to leave out
1612                 // debug(0) << "\n\n"
1613                 //          << "output:\n" << FuncWithDependencies(original_output) << "\n"
1614                 //          << "d_output:\n" << FuncWithDependencies(adjoint_func) << "\n"
1615                 //          << "input:\n" << FuncWithDependencies(f) << "\n"
1616                 //          << "d_out_wrt_in:\n" << FuncWithDependencies(d_out_wrt_in) << "\n";
1617 
1618                 gradient_outputs.push_back(d_out_wrt_in);
1619                 debug(DBG) << "    gradient output is: " << d_out_wrt_in.name() << "\n";
1620             }
1621         }
1622     }
1623 
1624     Pipeline grad_pipeline = Pipeline(gradient_outputs);
1625 
1626     AutoSchedulerResults auto_schedule_results;
1627     if (get_auto_schedule()) {
1628         auto_schedule_results = grad_pipeline.auto_schedule(get_target(), get_machine_params());
1629     } else {
1630         user_warning << "Autoscheduling is not enabled in build_gradient_module(), so the resulting "
1631                         "gradient module will be unscheduled; this is very unlikely to be what you want.\n";
1632     }
1633 
1634     Module result = grad_pipeline.compile_to_module(gradient_inputs, function_name, get_target(), linkage_type);
1635     user_assert(get_externs_map()->empty())
1636         << "Building a gradient-descent module for a Generator with ExternalCode is not supported.\n";
1637 
1638     result.set_auto_scheduler_results(auto_schedule_results);
1639 
1640     return result;
1641 }
1642 
emit_cpp_stub(const std::string & stub_file_path)1643 void GeneratorBase::emit_cpp_stub(const std::string &stub_file_path) {
1644     user_assert(!generator_registered_name.empty() && !generator_stub_name.empty()) << "Generator has no name.\n";
1645     // Make sure we call configure() so that extra inputs/outputs are added as necessary.
1646     call_configure();
1647     // StubEmitter will want to access the GP/SP values, so advance the phase to avoid assert-fails.
1648     advance_phase(GenerateCalled);
1649     advance_phase(ScheduleCalled);
1650     GeneratorParamInfo &pi = param_info();
1651     std::ofstream file(stub_file_path);
1652     StubEmitter emit(file, generator_registered_name, generator_stub_name, pi.generator_params(), pi.inputs(), pi.outputs());
1653     emit.emit();
1654 }
1655 
check_scheduled(const char * m) const1656 void GeneratorBase::check_scheduled(const char *m) const {
1657     check_min_phase(ScheduleCalled);
1658 }
1659 
check_input_is_singular(Internal::GeneratorInputBase * in)1660 void GeneratorBase::check_input_is_singular(Internal::GeneratorInputBase *in) {
1661     user_assert(!in->is_array())
1662         << "Input " << in->name() << " is an array, and must be set with a vector type.";
1663 }
1664 
check_input_is_array(Internal::GeneratorInputBase * in)1665 void GeneratorBase::check_input_is_array(Internal::GeneratorInputBase *in) {
1666     user_assert(in->is_array())
1667         << "Input " << in->name() << " is not an array, and must not be set with a vector type.";
1668 }
1669 
check_input_kind(Internal::GeneratorInputBase * in,Internal::IOKind kind)1670 void GeneratorBase::check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind) {
1671     user_assert(in->kind() == kind)
1672         << "Input " << in->name() << " cannot be set with the type specified.";
1673 }
1674 
GIOBase(size_t array_size,const std::string & name,IOKind kind,const std::vector<Type> & types,int dims)1675 GIOBase::GIOBase(size_t array_size,
1676                  const std::string &name,
1677                  IOKind kind,
1678                  const std::vector<Type> &types,
1679                  int dims)
1680     : array_size_(array_size), name_(name), kind_(kind), types_(types), dims_(dims) {
1681 }
1682 
~GIOBase()1683 GIOBase::~GIOBase() {
1684     // nothing
1685 }
1686 
array_size_defined() const1687 bool GIOBase::array_size_defined() const {
1688     return array_size_ != -1;
1689 }
1690 
array_size() const1691 size_t GIOBase::array_size() const {
1692     user_assert(array_size_defined()) << "ArraySize is unspecified for " << input_or_output() << "'" << name() << "'; you need to explicitly set it via the resize() method or by setting '"
1693                                       << name() << ".size' in your build rules.";
1694     return (size_t)array_size_;
1695 }
1696 
is_array() const1697 bool GIOBase::is_array() const {
1698     internal_error << "Unimplemented";
1699     return false;
1700 }
1701 
name() const1702 const std::string &GIOBase::name() const {
1703     return name_;
1704 }
1705 
kind() const1706 IOKind GIOBase::kind() const {
1707     return kind_;
1708 }
1709 
types_defined() const1710 bool GIOBase::types_defined() const {
1711     return !types_.empty();
1712 }
1713 
types() const1714 const std::vector<Type> &GIOBase::types() const {
1715     // If types aren't defined, but we have one Func that is,
1716     // we probably just set an Output<Func> and should propagate the types.
1717     if (!types_defined()) {
1718         // use funcs_, not funcs(): the latter could give a much-less-helpful error message
1719         // in this case.
1720         const auto &f = funcs_;
1721         if (f.size() == 1 && f.at(0).defined()) {
1722             check_matching_types(f.at(0).output_types());
1723         }
1724     }
1725     user_assert(types_defined()) << "Type is not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".type' as a GeneratorParam.\n";
1726     return types_;
1727 }
1728 
type() const1729 Type GIOBase::type() const {
1730     const auto &t = types();
1731     internal_assert(t.size() == 1) << "Expected types_.size() == 1, saw " << t.size() << " for " << name() << "\n";
1732     return t.at(0);
1733 }
1734 
dims_defined() const1735 bool GIOBase::dims_defined() const {
1736     return dims_ != -1;
1737 }
1738 
dims() const1739 int GIOBase::dims() const {
1740     // If types aren't defined, but we have one Func that is,
1741     // we probably just set an Output<Func> and should propagate the types.
1742     if (!dims_defined()) {
1743         // use funcs_, not funcs(): the latter could give a much-less-helpful error message
1744         // in this case.
1745         const auto &f = funcs_;
1746         if (f.size() == 1 && f.at(0).defined()) {
1747             check_matching_dims(funcs().at(0).dimensions());
1748         }
1749     }
1750     user_assert(dims_defined()) << "Dimensions are not defined for " << input_or_output() << " '" << name() << "'; you may need to specify '" << name() << ".dim' as a GeneratorParam.\n";
1751     return dims_;
1752 }
1753 
funcs() const1754 const std::vector<Func> &GIOBase::funcs() const {
1755     internal_assert(funcs_.size() == array_size() && exprs_.empty());
1756     return funcs_;
1757 }
1758 
exprs() const1759 const std::vector<Expr> &GIOBase::exprs() const {
1760     internal_assert(exprs_.size() == array_size() && funcs_.empty());
1761     return exprs_;
1762 }
1763 
verify_internals()1764 void GIOBase::verify_internals() {
1765     user_assert(dims_ >= 0) << "Generator Input/Output Dimensions must have positive values";
1766 
1767     if (kind() != IOKind::Scalar) {
1768         for (const Func &f : funcs()) {
1769             user_assert(f.defined()) << "Input/Output " << name() << " is not defined.\n";
1770             user_assert(f.dimensions() == dims())
1771                 << "Expected dimensions " << dims()
1772                 << " but got " << f.dimensions()
1773                 << " for " << name() << "\n";
1774             user_assert(f.outputs() == 1)
1775                 << "Expected outputs() == " << 1
1776                 << " but got " << f.outputs()
1777                 << " for " << name() << "\n";
1778             user_assert(f.output_types().size() == 1)
1779                 << "Expected output_types().size() == " << 1
1780                 << " but got " << f.outputs()
1781                 << " for " << name() << "\n";
1782             user_assert(f.output_types()[0] == type())
1783                 << "Expected type " << type()
1784                 << " but got " << f.output_types()[0]
1785                 << " for " << name() << "\n";
1786         }
1787     } else {
1788         for (const Expr &e : exprs()) {
1789             user_assert(e.defined()) << "Input/Ouput " << name() << " is not defined.\n";
1790             user_assert(e.type() == type())
1791                 << "Expected type " << type()
1792                 << " but got " << e.type()
1793                 << " for " << name() << "\n";
1794         }
1795     }
1796 }
1797 
array_name(size_t i) const1798 std::string GIOBase::array_name(size_t i) const {
1799     std::string n = name();
1800     if (is_array()) {
1801         n += "_" + std::to_string(i);
1802     }
1803     return n;
1804 }
1805 
1806 // If our type(s) are defined, ensure it matches the ones passed in, asserting if not.
1807 // If our type(s) are not defined, just set to the ones passed in.
check_matching_types(const std::vector<Type> & t) const1808 void GIOBase::check_matching_types(const std::vector<Type> &t) const {
1809     if (types_defined()) {
1810         user_assert(types().size() == t.size()) << "Type mismatch for " << name() << ": expected " << types().size() << " types but saw " << t.size();
1811         for (size_t i = 0; i < t.size(); ++i) {
1812             user_assert(types().at(i) == t.at(i)) << "Type mismatch for " << name() << ": expected " << types().at(i) << " saw " << t.at(i);
1813         }
1814     } else {
1815         types_ = t;
1816     }
1817 }
1818 
check_gio_access() const1819 void GIOBase::check_gio_access() const {
1820     // // Allow reading when no Generator is set, to avoid having to special-case ctor initing code
1821     if (!generator) return;
1822     user_assert(generator->phase > GeneratorBase::InputsSet)
1823         << "The " << input_or_output() << " \"" << name() << "\" cannot be examined before build() or generate() is called.\n";
1824 }
1825 
1826 // If our dims are defined, ensure it matches the one passed in, asserting if not.
1827 // If our dims are not defined, just set to the one passed in.
check_matching_dims(int d) const1828 void GIOBase::check_matching_dims(int d) const {
1829     internal_assert(d >= 0);
1830     if (dims_defined()) {
1831         user_assert(dims() == d) << "Dimensions mismatch for " << name() << ": expected " << dims() << " saw " << d;
1832     } else {
1833         dims_ = d;
1834     }
1835 }
1836 
check_matching_array_size(size_t size) const1837 void GIOBase::check_matching_array_size(size_t size) const {
1838     if (array_size_defined()) {
1839         user_assert(array_size() == size) << "ArraySize mismatch for " << name() << ": expected " << array_size() << " saw " << size;
1840     } else {
1841         array_size_ = size;
1842     }
1843 }
1844 
GeneratorInputBase(size_t array_size,const std::string & name,IOKind kind,const std::vector<Type> & t,int d)1845 GeneratorInputBase::GeneratorInputBase(size_t array_size,
1846                                        const std::string &name,
1847                                        IOKind kind,
1848                                        const std::vector<Type> &t,
1849                                        int d)
1850     : GIOBase(array_size, name, kind, t, d) {
1851     ObjectInstanceRegistry::register_instance(this, 0, ObjectInstanceRegistry::GeneratorInput, this, nullptr);
1852 }
1853 
GeneratorInputBase(const std::string & name,IOKind kind,const std::vector<Type> & t,int d)1854 GeneratorInputBase::GeneratorInputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1855     : GeneratorInputBase(1, name, kind, t, d) {
1856     // nothing
1857 }
1858 
~GeneratorInputBase()1859 GeneratorInputBase::~GeneratorInputBase() {
1860     ObjectInstanceRegistry::unregister_instance(this);
1861 }
1862 
check_value_writable() const1863 void GeneratorInputBase::check_value_writable() const {
1864     user_assert(generator && generator->phase == GeneratorBase::InputsSet)
1865         << "The Input " << name() << " cannot be set at this point.\n";
1866 }
1867 
set_def_min_max()1868 void GeneratorInputBase::set_def_min_max() {
1869     // nothing
1870 }
1871 
get_def_expr() const1872 Expr GeneratorInputBase::get_def_expr() const {
1873     return Expr();
1874 }
1875 
parameter() const1876 Parameter GeneratorInputBase::parameter() const {
1877     user_assert(!this->is_array()) << "Cannot call the parameter() method on Input<[]> " << name() << "; use an explicit subscript operator instead.";
1878     return parameters_.at(0);
1879 }
1880 
verify_internals()1881 void GeneratorInputBase::verify_internals() {
1882     GIOBase::verify_internals();
1883 
1884     const size_t expected = (kind() != IOKind::Scalar) ? funcs().size() : exprs().size();
1885     user_assert(parameters_.size() == expected) << "Expected parameters_.size() == "
1886                                                 << expected << ", saw " << parameters_.size() << " for " << name() << "\n";
1887 }
1888 
init_internals()1889 void GeneratorInputBase::init_internals() {
1890     // Call these for the side-effect of asserting if the values aren't defined.
1891     (void)array_size();
1892     (void)types();
1893     (void)dims();
1894 
1895     parameters_.clear();
1896     exprs_.clear();
1897     funcs_.clear();
1898     for (size_t i = 0; i < array_size(); ++i) {
1899         auto name = array_name(i);
1900         parameters_.emplace_back(type(), kind() != IOKind::Scalar, dims(), name);
1901         auto &p = parameters_[i];
1902         if (kind() != IOKind::Scalar) {
1903             internal_assert(dims() == p.dimensions());
1904             funcs_.push_back(make_param_func(p, name));
1905         } else {
1906             Expr e = Internal::Variable::make(type(), name, p);
1907             exprs_.push_back(e);
1908         }
1909     }
1910 
1911     set_def_min_max();
1912     verify_internals();
1913 }
1914 
set_inputs(const std::vector<StubInput> & inputs)1915 void GeneratorInputBase::set_inputs(const std::vector<StubInput> &inputs) {
1916     generator->check_exact_phase(GeneratorBase::InputsSet);
1917     parameters_.clear();
1918     exprs_.clear();
1919     funcs_.clear();
1920     check_matching_array_size(inputs.size());
1921     for (size_t i = 0; i < inputs.size(); ++i) {
1922         const StubInput &in = inputs.at(i);
1923         user_assert(in.kind() == kind()) << "An input for " << name() << " is not of the expected kind.\n";
1924         if (kind() == IOKind::Function) {
1925             auto f = in.func();
1926             user_assert(f.defined()) << "The input for " << name() << " is an undefined Func. Please define it.\n";
1927             check_matching_types(f.output_types());
1928             check_matching_dims(f.dimensions());
1929             funcs_.push_back(f);
1930             parameters_.emplace_back(f.output_types().at(0), true, f.dimensions(), array_name(i));
1931         } else if (kind() == IOKind::Buffer) {
1932             auto p = in.parameter();
1933             user_assert(p.defined()) << "The input for " << name() << " is an undefined Buffer. Please define it.\n";
1934             check_matching_types({p.type()});
1935             check_matching_dims(p.dimensions());
1936             funcs_.push_back(make_param_func(p, name()));
1937             parameters_.push_back(p);
1938         } else {
1939             auto e = in.expr();
1940             user_assert(e.defined()) << "The input for " << name() << " is an undefined Expr. Please define it.\n";
1941             check_matching_types({e.type()});
1942             check_matching_dims(0);
1943             exprs_.push_back(e);
1944             parameters_.emplace_back(e.type(), false, 0, array_name(i));
1945         }
1946     }
1947 
1948     set_def_min_max();
1949     verify_internals();
1950 }
1951 
set_estimate_impl(const Var & var,const Expr & min,const Expr & extent)1952 void GeneratorInputBase::set_estimate_impl(const Var &var, const Expr &min, const Expr &extent) {
1953     internal_assert(exprs_.empty() && !funcs_.empty() && parameters_.size() == funcs_.size());
1954     for (size_t i = 0; i < funcs_.size(); ++i) {
1955         Func &f = funcs_[i];
1956         f.set_estimate(var, min, extent);
1957         // Propagate the estimate into the Parameter as well, just in case
1958         // we end up compiling this for toplevel.
1959         std::vector<Var> args = f.args();
1960         int dim = -1;
1961         for (size_t a = 0; a < args.size(); ++a) {
1962             if (args[a].same_as(var)) {
1963                 dim = a;
1964                 break;
1965             }
1966         }
1967         internal_assert(dim >= 0);
1968         Parameter &p = parameters_[i];
1969         p.set_min_constraint_estimate(dim, min);
1970         p.set_extent_constraint_estimate(dim, extent);
1971     }
1972 }
1973 
set_estimates_impl(const Region & estimates)1974 void GeneratorInputBase::set_estimates_impl(const Region &estimates) {
1975     internal_assert(exprs_.empty() && !funcs_.empty() && parameters_.size() == funcs_.size());
1976     for (size_t i = 0; i < funcs_.size(); ++i) {
1977         Func &f = funcs_[i];
1978         f.set_estimates(estimates);
1979         // Propagate the estimate into the Parameter as well, just in case
1980         // we end up compiling this for toplevel.
1981         for (size_t dim = 0; dim < estimates.size(); ++dim) {
1982             Parameter &p = parameters_[i];
1983             const Range &r = estimates[dim];
1984             p.set_min_constraint_estimate(dim, r.min);
1985             p.set_extent_constraint_estimate(dim, r.extent);
1986         }
1987     }
1988 }
1989 
GeneratorOutputBase(size_t array_size,const std::string & name,IOKind kind,const std::vector<Type> & t,int d)1990 GeneratorOutputBase::GeneratorOutputBase(size_t array_size, const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1991     : GIOBase(array_size, name, kind, t, d) {
1992     internal_assert(kind != IOKind::Scalar);
1993     ObjectInstanceRegistry::register_instance(this, 0, ObjectInstanceRegistry::GeneratorOutput,
1994                                               this, nullptr);
1995 }
1996 
GeneratorOutputBase(const std::string & name,IOKind kind,const std::vector<Type> & t,int d)1997 GeneratorOutputBase::GeneratorOutputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
1998     : GeneratorOutputBase(1, name, kind, t, d) {
1999     // nothing
2000 }
2001 
~GeneratorOutputBase()2002 GeneratorOutputBase::~GeneratorOutputBase() {
2003     ObjectInstanceRegistry::unregister_instance(this);
2004 }
2005 
check_value_writable() const2006 void GeneratorOutputBase::check_value_writable() const {
2007     user_assert(generator && generator->phase == GeneratorBase::GenerateCalled)
2008         << "The Output " << name() << " can only be set inside generate().\n";
2009 }
2010 
init_internals()2011 void GeneratorOutputBase::init_internals() {
2012     exprs_.clear();
2013     funcs_.clear();
2014     if (array_size_defined()) {
2015         for (size_t i = 0; i < array_size(); ++i) {
2016             funcs_.emplace_back(array_name(i));
2017         }
2018     }
2019 }
2020 
resize(size_t size)2021 void GeneratorOutputBase::resize(size_t size) {
2022     internal_assert(is_array());
2023     internal_assert(!array_size_defined()) << "You may only call " << name()
2024                                            << ".resize() when then size is undefined\n";
2025     array_size_ = (int)size;
2026     init_internals();
2027 }
2028 
check_scheduled(const char * m) const2029 void StubOutputBufferBase::check_scheduled(const char *m) const {
2030     generator->check_scheduled(m);
2031 }
2032 
get_target() const2033 Target StubOutputBufferBase::get_target() const {
2034     return generator->get_target();
2035 }
2036 
generator_test()2037 void generator_test() {
2038     GeneratorContext context(get_host_target());
2039 
2040     // Verify that the Generator's internal phase actually prevents unsupported
2041     // order of operations.
2042     {
2043         class Tester : public Generator<Tester> {
2044         public:
2045             GeneratorParam<int> gp0{"gp0", 0};
2046             GeneratorParam<float> gp1{"gp1", 1.f};
2047             GeneratorParam<uint64_t> gp2{"gp2", 2};
2048 
2049             Input<int> input{"input"};
2050             Output<Func> output{"output", Int(32), 1};
2051 
2052             void generate() {
2053                 internal_assert(gp0 == 1);
2054                 internal_assert(gp1 == 2.f);
2055                 internal_assert(gp2 == (uint64_t)2);  // unchanged
2056                 Var x;
2057                 output(x) = input + gp0;
2058             }
2059             void schedule() {
2060                 // empty
2061             }
2062         };
2063 
2064         Tester tester;
2065         tester.init_from_context(context);
2066         internal_assert(tester.phase == GeneratorBase::Created);
2067 
2068         // Verify that calling GeneratorParam::set() works.
2069         tester.gp0.set(1);
2070 
2071         tester.set_inputs_vector({{StubInput(42)}});
2072         internal_assert(tester.phase == GeneratorBase::InputsSet);
2073 
2074         // tester.set_inputs_vector({{StubInput(43)}});  // This will assert-fail.
2075 
2076         // Also ok to call in this phase.
2077         tester.gp1.set(2.f);
2078 
2079         tester.call_generate();
2080         internal_assert(tester.phase == GeneratorBase::GenerateCalled);
2081 
2082         // tester.set_inputs_vector({{StubInput(44)}});  // This will assert-fail.
2083         // tester.gp2.set(2);  // This will assert-fail.
2084 
2085         tester.call_schedule();
2086         internal_assert(tester.phase == GeneratorBase::ScheduleCalled);
2087 
2088         // tester.set_inputs_vector({{StubInput(45)}});  // This will assert-fail.
2089         // tester.gp2.set(2);  // This will assert-fail.
2090         // tester.sp2.set(202);  // This will assert-fail.
2091     }
2092 
2093     // Verify that the Generator's internal phase actually prevents unsupported
2094     // order of operations (with old-style Generator)
2095     {
2096         class Tester : public Generator<Tester> {
2097         public:
2098             GeneratorParam<int> gp0{"gp0", 0};
2099             GeneratorParam<float> gp1{"gp1", 1.f};
2100             GeneratorParam<uint64_t> gp2{"gp2", 2};
2101             GeneratorParam<uint8_t> gp_uint8{"gp_uint8", 65};
2102             GeneratorParam<int8_t> gp_int8{"gp_int8", 66};
2103             GeneratorParam<char> gp_char{"gp_char", 97};
2104             GeneratorParam<signed char> gp_schar{"gp_schar", 98};
2105             GeneratorParam<unsigned char> gp_uchar{"gp_uchar", 99};
2106             GeneratorParam<bool> gp_bool{"gp_bool", true};
2107 
2108             Input<int> input{"input"};
2109 
2110             Func build() {
2111                 internal_assert(gp0 == 1);
2112                 internal_assert(gp1 == 2.f);
2113                 internal_assert(gp2 == (uint64_t)2);  // unchanged
2114                 internal_assert(gp_uint8 == 67);
2115                 internal_assert(gp_int8 == 68);
2116                 internal_assert(gp_bool == false);
2117                 internal_assert(gp_char == 107);
2118                 internal_assert(gp_schar == 108);
2119                 internal_assert(gp_uchar == 109);
2120                 Var x;
2121                 Func output;
2122                 output(x) = input + gp0;
2123                 return output;
2124             }
2125         };
2126 
2127         Tester tester;
2128         tester.init_from_context(context);
2129         internal_assert(tester.phase == GeneratorBase::Created);
2130 
2131         // Verify that calling GeneratorParam::set() works.
2132         tester.gp0.set(1);
2133 
2134         // set_inputs_vector() can't be called on an old-style Generator;
2135         // that's OK, since we can skip from Created -> GenerateCalled anyway
2136         // tester.set_inputs_vector({{StubInput(42)}});
2137         // internal_assert(tester.phase == GeneratorBase::InputsSet);
2138 
2139         // tester.set_inputs_vector({{StubInput(43)}});  // This will assert-fail.
2140 
2141         // Also ok to call in this phase.
2142         tester.gp1.set(2.f);
2143 
2144         // Verify that 8-bit non-boolean GP values are parsed as integers, not chars.
2145         tester.gp_int8.set_from_string("68");
2146         tester.gp_uint8.set_from_string("67");
2147         tester.gp_char.set_from_string("107");
2148         tester.gp_schar.set_from_string("108");
2149         tester.gp_uchar.set_from_string("109");
2150         tester.gp_bool.set_from_string("false");
2151 
2152         tester.build_pipeline();
2153         internal_assert(tester.phase == GeneratorBase::ScheduleCalled);
2154 
2155         // tester.set_inputs_vector({{StubInput(45)}});  // This will assert-fail.
2156         // tester.gp2.set(2);  // This will assert-fail.
2157         // tester.sp2.set(202);  // This will assert-fail.
2158     }
2159 
2160     // Verify that set_inputs() works properly, even if the specific subtype of Generator is not known.
2161     {
2162         class Tester : public Generator<Tester> {
2163         public:
2164             Input<int> input_int{"input_int"};
2165             Input<float> input_float{"input_float"};
2166             Input<uint8_t> input_byte{"input_byte"};
2167             Input<uint64_t[4]> input_scalar_array{"input_scalar_array"};
2168             Input<Func> input_func_typed{"input_func_typed", Int(16), 1};
2169             Input<Func> input_func_untyped{"input_func_untyped", 1};
2170             Input<Func[]> input_func_array{"input_func_array", 1};
2171             Input<Buffer<uint8_t>> input_buffer_typed{"input_buffer_typed", 3};
2172             Input<Buffer<>> input_buffer_untyped{"input_buffer_untyped"};
2173             Output<Func> output{"output", Float(32), 1};
2174 
2175             void generate() {
2176                 Var x;
2177                 output(x) = input_int +
2178                             input_float +
2179                             input_byte +
2180                             input_scalar_array[3] +
2181                             input_func_untyped(x) +
2182                             input_func_typed(x) +
2183                             input_func_array[0](x) +
2184                             input_buffer_typed(x, 0, 0) +
2185                             input_buffer_untyped(x, Halide::_);
2186             }
2187             void schedule() {
2188                 // nothing
2189             }
2190         };
2191 
2192         Tester tester_instance;
2193         tester_instance.init_from_context(context);
2194         // Use a base-typed reference to verify the code below doesn't know about subtype
2195         GeneratorBase &tester = tester_instance;
2196 
2197         const int i = 1234;
2198         const float f = 2.25f;
2199         const uint8_t b = 0x42;
2200         const std::vector<uint64_t> a = {1, 2, 3, 4};
2201         Var x;
2202         Func fn_typed, fn_untyped;
2203         fn_typed(x) = cast<int16_t>(38);
2204         fn_untyped(x) = 32.f;
2205         const std::vector<Func> fn_array = {fn_untyped, fn_untyped};
2206 
2207         Buffer<uint8_t> buf_typed(1, 1, 1);
2208         Buffer<float> buf_untyped(1);
2209 
2210         buf_typed.fill(33);
2211         buf_untyped.fill(34);
2212 
2213         // set_inputs() requires inputs in Input<>-decl-order,
2214         // and all inputs match type exactly.
2215         tester.set_inputs(i, f, b, a, fn_typed, fn_untyped, fn_array, buf_typed, buf_untyped);
2216         tester.call_generate();
2217         tester.call_schedule();
2218 
2219         Buffer<float> im = tester_instance.realize(1);
2220         internal_assert(im.dim(0).extent() == 1);
2221         internal_assert(im(0) == 1475.25f) << "Expected 1475.25 but saw " << im(0);
2222     }
2223 
2224     // Verify that array inputs and outputs are typed correctly.
2225     {
2226         class Tester : public Generator<Tester> {
2227         public:
2228             Input<int[]> expr_array_input{"expr_array_input"};
2229             Input<Func[]> func_array_input{"input_func_array"};
2230             Input<Buffer<>[]> buffer_array_input { "buffer_array_input" };
2231 
2232             Input<int[]> expr_array_output{"expr_array_output"};
2233             Output<Func[]> func_array_output{"func_array_output"};
2234             Output<Buffer<>[]> buffer_array_output { "buffer_array_output" };
2235 
2236             void generate() {
2237             }
2238         };
2239 
2240         Tester tester_instance;
2241 
2242         static_assert(std::is_same<decltype(tester_instance.expr_array_input[0]), const Expr &>::value, "type mismatch");
2243         static_assert(std::is_same<decltype(tester_instance.expr_array_output[0]), const Expr &>::value, "type mismatch");
2244 
2245         static_assert(std::is_same<decltype(tester_instance.func_array_input[0]), const Func &>::value, "type mismatch");
2246         static_assert(std::is_same<decltype(tester_instance.func_array_output[0]), Func &>::value, "type mismatch");
2247 
2248         static_assert(std::is_same<decltype(tester_instance.buffer_array_input[0]), ImageParam>::value, "type mismatch");
2249         static_assert(std::is_same<decltype(tester_instance.buffer_array_output[0]), const Func &>::value, "type mismatch");
2250     }
2251 
2252     class GPTester : public Generator<GPTester> {
2253     public:
2254         GeneratorParam<int> gp{"gp", 0};
2255         Output<Func> output{"output", Int(32), 0};
2256         void generate() {
2257             output() = 0;
2258         }
2259         void schedule() {
2260         }
2261     };
2262     GPTester gp_tester;
2263     gp_tester.init_from_context(context);
2264     // Accessing the GeneratorParam will assert-fail if we
2265     // don't do some minimal setup here.
2266     gp_tester.set_inputs_vector({});
2267     gp_tester.call_generate();
2268     gp_tester.call_schedule();
2269     auto &gp = gp_tester.gp;
2270 
2271     // Verify that RDom parameter-pack variants can convert GeneratorParam to Expr
2272     RDom rdom(0, gp, 0, gp);
2273 
2274     // Verify that Func parameter-pack variants can convert GeneratorParam to Expr
2275     Var x, y;
2276     Func f, g;
2277     f(x, y) = x + y;
2278     g(x, y) = f(gp, gp);  // check Func::operator() overloads
2279     g(rdom.x, rdom.y) += f(rdom.x, rdom.y);
2280     g.update(0).reorder(rdom.y, rdom.x);  // check Func::reorder() overloads for RDom::operator RVar()
2281 
2282     // Verify that print() parameter-pack variants can convert GeneratorParam to Expr
2283     print(f(0, 0), g(1, 1), gp);
2284     print_when(true, f(0, 0), g(1, 1), gp);
2285 
2286     // Verify that Tuple parameter-pack variants can convert GeneratorParam to Expr
2287     Tuple t(gp, gp, gp);
2288 
2289     std::cout << "Generator test passed" << std::endl;
2290 }
2291 
2292 }  // namespace Internal
2293 }  // namespace Halide
2294