1 #include "AddImageChecks.h"
2 #include "ExternFuncArgument.h"
3 #include "Function.h"
4 #include "IRMutator.h"
5 #include "IROperator.h"
6 #include "IRVisitor.h"
7 #include "Simplify.h"
8 #include "Substitute.h"
9 #include "Target.h"
10 
11 namespace Halide {
12 namespace Internal {
13 
14 using std::map;
15 using std::pair;
16 using std::string;
17 using std::vector;
18 
19 /* Find all the externally referenced buffers in a stmt */
20 class FindBuffers : public IRGraphVisitor {
21 public:
22     struct Result {
23         Buffer<> image;
24         Parameter param;
25         Type type;
26         int dimensions{0};
27         bool used_on_host{false};
28     };
29 
30     map<string, Result> buffers;
31     bool in_device_loop = false;
32 
33     using IRGraphVisitor::visit;
34 
visit(const For * op)35     void visit(const For *op) override {
36         op->min.accept(this);
37         op->extent.accept(this);
38         bool old = in_device_loop;
39         if (op->device_api != DeviceAPI::None &&
40             op->device_api != DeviceAPI::Host) {
41             in_device_loop = true;
42         }
43         op->body.accept(this);
44         in_device_loop = old;
45     }
46 
visit(const Call * op)47     void visit(const Call *op) override {
48         IRGraphVisitor::visit(op);
49         if (op->image.defined()) {
50             Result &r = buffers[op->name];
51             r.image = op->image;
52             r.type = op->type.element_of();
53             r.dimensions = (int)op->args.size();
54             r.used_on_host = r.used_on_host || (!in_device_loop);
55         } else if (op->param.defined()) {
56             Result &r = buffers[op->name];
57             r.param = op->param;
58             r.type = op->type.element_of();
59             r.dimensions = (int)op->args.size();
60             r.used_on_host = r.used_on_host || (!in_device_loop);
61         }
62     }
63 
visit(const Provide * op)64     void visit(const Provide *op) override {
65         IRGraphVisitor::visit(op);
66         if (op->values.size() == 1) {
67             auto it = buffers.find(op->name);
68             if (it != buffers.end() && !in_device_loop) {
69                 it->second.used_on_host = true;
70             }
71         } else {
72             for (size_t i = 0; i < op->values.size(); i++) {
73                 string name = op->name + "." + std::to_string(i);
74                 auto it = buffers.find(name);
75                 if (it != buffers.end() && !in_device_loop) {
76                     it->second.used_on_host = true;
77                 }
78             }
79         }
80     }
81 
visit(const Variable * op)82     void visit(const Variable *op) override {
83         if (op->param.defined() &&
84             op->param.is_buffer() &&
85             buffers.find(op->param.name()) == buffers.end()) {
86             Result r;
87             r.param = op->param;
88             r.type = op->param.type();
89             r.dimensions = op->param.dimensions();
90             r.used_on_host = false;
91             buffers[op->param.name()] = r;
92         } else if (op->reduction_domain.defined()) {
93             // The bounds of reduction domains are not yet defined,
94             // and they may be the only reference to some parameters.
95             op->reduction_domain.accept(this);
96         }
97     }
98 };
99 
100 class TrimStmtToPartsThatAccessBuffers : public IRMutator {
101     bool touches_buffer = false;
102     const map<string, FindBuffers::Result> &buffers;
103 
104     using IRMutator::visit;
105 
visit(const Call * op)106     Expr visit(const Call *op) override {
107         touches_buffer |= (buffers.count(op->name) > 0);
108         return IRMutator::visit(op);
109     }
visit(const Provide * op)110     Stmt visit(const Provide *op) override {
111         touches_buffer |= (buffers.find(op->name) != buffers.end());
112         return IRMutator::visit(op);
113     }
visit(const Variable * op)114     Expr visit(const Variable *op) override {
115         if (op->type.is_handle() && op->param.defined() && op->param.is_buffer()) {
116             touches_buffer |= (buffers.find(op->param.name()) != buffers.end());
117         }
118         return IRMutator::visit(op);
119     }
120 
visit(const Block * op)121     Stmt visit(const Block *op) override {
122         bool old_touches_buffer = touches_buffer;
123         touches_buffer = false;
124         Stmt first = mutate(op->first);
125         old_touches_buffer |= touches_buffer;
126         if (!touches_buffer) {
127             first = Evaluate::make(0);
128         }
129         touches_buffer = false;
130         Stmt rest = mutate(op->rest);
131         old_touches_buffer |= touches_buffer;
132         if (!touches_buffer) {
133             rest = Evaluate::make(0);
134         }
135         touches_buffer = old_touches_buffer;
136         return Block::make(first, rest);
137     }
138 
139 public:
TrimStmtToPartsThatAccessBuffers(const map<string,FindBuffers::Result> & bufs)140     TrimStmtToPartsThatAccessBuffers(const map<string, FindBuffers::Result> &bufs)
141         : buffers(bufs) {
142     }
143 };
144 
add_image_checks_inner(Stmt s,const vector<Function> & outputs,const Target & t,const vector<string> & order,const map<string,Function> & env,const FuncValueBounds & fb,bool will_inject_host_copies)145 Stmt add_image_checks_inner(Stmt s,
146                             const vector<Function> &outputs,
147                             const Target &t,
148                             const vector<string> &order,
149                             const map<string, Function> &env,
150                             const FuncValueBounds &fb,
151                             bool will_inject_host_copies) {
152 
153     bool no_asserts = t.has_feature(Target::NoAsserts);
154     bool no_bounds_query = t.has_feature(Target::NoBoundsQuery);
155 
156     // First hunt for all the referenced buffers
157     FindBuffers finder;
158     map<string, FindBuffers::Result> &bufs = finder.buffers;
159 
160     // Add the output buffer(s).
161     for (Function f : outputs) {
162         for (size_t i = 0; i < f.values().size(); i++) {
163             FindBuffers::Result output_buffer;
164             output_buffer.type = f.values()[i].type();
165             output_buffer.param = f.output_buffers()[i];
166             output_buffer.dimensions = f.dimensions();
167             if (f.values().size() > 1) {
168                 bufs[f.name() + '.' + std::to_string(i)] = output_buffer;
169             } else {
170                 bufs[f.name()] = output_buffer;
171             }
172         }
173     }
174 
175     // Add the input buffer(s) and annotate which output buffers are
176     // used on host.
177     s.accept(&finder);
178 
179     Scope<Interval> empty_scope;
180     Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs).mutate(s);
181     map<string, Box> boxes = boxes_touched(sub_stmt, empty_scope, fb);
182 
183     // Now iterate through all the buffers, creating a list of lets
184     // and a list of asserts.
185     vector<pair<string, Expr>> lets_overflow;
186     vector<pair<string, Expr>> lets_required;
187     vector<pair<string, Expr>> lets_constrained;
188     vector<pair<string, Expr>> lets_proposed;
189     vector<Stmt> dims_no_overflow_asserts;
190     vector<Stmt> asserts_required;
191     vector<Stmt> asserts_constrained;
192     vector<Stmt> asserts_proposed;
193     vector<Stmt> asserts_type_checks;
194     vector<Stmt> asserts_host_alignment;
195     vector<Stmt> asserts_host_non_null;
196     vector<Stmt> asserts_device_not_dirty;
197     vector<Stmt> buffer_rewrites;
198     vector<Stmt> msan_checks;
199 
200     // Inject the code that conditionally returns if we're in inference mode
201     Expr maybe_return_condition = const_false();
202 
203     // We're also going to apply the constraints to the required min
204     // and extent. To do this we have to substitute all references to
205     // the actual sizes of the input images in the constraints with
206     // references to the required sizes.
207     map<string, Expr> replace_with_required;
208 
209     for (const pair<const string, FindBuffers::Result> &buf : bufs) {
210         const string &name = buf.first;
211 
212         for (int i = 0; i < buf.second.dimensions; i++) {
213             string dim = std::to_string(i);
214 
215             Expr min_required = Variable::make(Int(32), name + ".min." + dim + ".required");
216             replace_with_required[name + ".min." + dim] = min_required;
217 
218             Expr extent_required = Variable::make(Int(32), name + ".extent." + dim + ".required");
219             replace_with_required[name + ".extent." + dim] = simplify(extent_required);
220 
221             Expr stride_required = Variable::make(Int(32), name + ".stride." + dim + ".required");
222             replace_with_required[name + ".stride." + dim] = stride_required;
223         }
224     }
225 
226     // We also want to build a map that lets us replace values passed
227     // in with the constrained version. This is applied to the rest of
228     // the lowered pipeline to take advantage of the constraints,
229     // e.g. for constant folding.
230     map<string, Expr> replace_with_constrained;
231 
232     for (pair<const string, FindBuffers::Result> &buf : bufs) {
233         const string &name = buf.first;
234         Buffer<> &image = buf.second.image;
235         Parameter &param = buf.second.param;
236         Type type = buf.second.type;
237         int dimensions = buf.second.dimensions;
238         bool used_on_host = buf.second.used_on_host;
239 
240         // Detect if this is one of the outputs of a multi-output pipeline.
241         bool is_output_buffer = false;
242         bool is_secondary_output_buffer = false;
243         string buffer_name = name;
244         for (Function f : outputs) {
245             for (size_t i = 0; i < f.output_buffers().size(); i++) {
246                 if (param.defined() &&
247                     param.same_as(f.output_buffers()[i])) {
248                     is_output_buffer = true;
249                     // If we're one of multiple output buffers, we should use the
250                     // region inferred for the func in general.
251                     buffer_name = f.name();
252                     if (i > 0) {
253                         is_secondary_output_buffer = true;
254                     }
255                 }
256             }
257         }
258 
259         Box touched = boxes[buffer_name];
260         internal_assert(touched.empty() || (int)(touched.size()) == dimensions);
261 
262         // The buffer may be used in one or more extern stage. If so we need to
263         // expand the box touched to include the results of the
264         // top-level bounds query calls to those extern stages.
265         if (param.defined()) {
266             // Find the extern users.
267             vector<string> extern_users;
268             for (size_t i = 0; i < order.size(); i++) {
269                 Function f = env.find(order[i])->second;
270                 if (f.has_extern_definition() &&
271                     !f.extern_definition_proxy_expr().defined()) {
272                     const vector<ExternFuncArgument> &args = f.extern_arguments();
273                     for (size_t j = 0; j < args.size(); j++) {
274                         if ((args[j].image_param.defined() &&
275                              args[j].image_param.name() == param.name()) ||
276                             (args[j].buffer.defined() &&
277                              args[j].buffer.name() == param.name())) {
278                             extern_users.push_back(order[i]);
279                         }
280                     }
281                 }
282             }
283 
284             // Expand the box by the result of the bounds query from each.
285             for (size_t i = 0; i < extern_users.size(); i++) {
286                 const string &extern_user = extern_users[i];
287                 Box query_box;
288                 Expr query_buf = Variable::make(type_of<struct halide_buffer_t *>(),
289                                                 param.name() + ".bounds_query." + extern_user);
290                 for (int j = 0; j < dimensions; j++) {
291                     Expr min = Call::make(Int(32), Call::buffer_get_min,
292                                           {query_buf, j}, Call::Extern);
293                     Expr max = Call::make(Int(32), Call::buffer_get_max,
294                                           {query_buf, j}, Call::Extern);
295                     query_box.push_back(Interval(min, max));
296                 }
297                 merge_boxes(touched, query_box);
298             }
299         }
300 
301         ReductionDomain rdom;
302 
303         // An expression returning whether or not we're in inference mode
304         string buf_name = name + ".buffer";
305         Expr handle = Variable::make(type_of<halide_buffer_t *>(), buf_name,
306                                      image, param, rdom);
307         Expr inference_mode = Call::make(Bool(), Call::buffer_is_bounds_query,
308                                          {handle}, Call::Extern);
309         maybe_return_condition = maybe_return_condition || inference_mode;
310 
311         // Come up with a name to refer to this buffer in the error messages
312         string error_name = (is_output_buffer ? "Output" : "Input");
313         error_name += " buffer " + name;
314 
315         if (!is_output_buffer && t.has_feature(Target::MSAN)) {
316             Expr buffer = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
317             Stmt check_contents = Evaluate::make(
318                 Call::make(Int(32), "halide_msan_check_buffer_is_initialized", {buffer, Expr(buf_name)}, Call::Extern));
319             msan_checks.push_back(check_contents);
320         }
321 
322         // Check the type matches the internally-understood type
323         {
324             string type_name = name + ".type";
325             Expr type_var = Variable::make(UInt(32), type_name, image, param, rdom);
326             uint32_t correct_type_bits = ((halide_type_t)type).as_u32();
327             Expr correct_type_expr = make_const(UInt(32), correct_type_bits);
328             Expr error = Call::make(Int(32), "halide_error_bad_type",
329                                     {error_name, type_var, correct_type_expr},
330                                     Call::Extern);
331             Stmt type_check = AssertStmt::make(type_var == correct_type_expr, error);
332             asserts_type_checks.push_back(type_check);
333         }
334 
335         // Check the dimensions matches the internally-understood dimensions
336         {
337             string dimensions_name = name + ".dimensions";
338             Expr dimensions_given = Variable::make(Int(32), dimensions_name, image, param, rdom);
339             Expr error = Call::make(Int(32), "halide_error_bad_dimensions",
340                                     {error_name,
341                                      dimensions_given, make_const(Int(32), dimensions)},
342                                     Call::Extern);
343             asserts_type_checks.push_back(
344                 AssertStmt::make(dimensions_given == dimensions, error));
345         }
346 
347         if (touched.maybe_unused()) {
348             debug(3) << "Image " << name << " is only used when " << touched.used << "\n";
349         }
350 
351         // Check that the region passed in (after applying constraints) is within the region used
352         debug(3) << "In image " << name << " region touched is:\n";
353 
354         for (int j = 0; j < dimensions; j++) {
355             string dim = std::to_string(j);
356             string actual_min_name = name + ".min." + dim;
357             string actual_extent_name = name + ".extent." + dim;
358             string actual_stride_name = name + ".stride." + dim;
359             Expr actual_min = Variable::make(Int(32), actual_min_name, image, param, rdom);
360             Expr actual_extent = Variable::make(Int(32), actual_extent_name, image, param, rdom);
361             Expr actual_stride = Variable::make(Int(32), actual_stride_name, image, param, rdom);
362 
363             if (!touched.empty() && !touched[j].is_bounded()) {
364                 user_error << "Buffer " << name
365                            << " may be accessed in an unbounded way in dimension "
366                            << j << "\n";
367             }
368 
369             Expr min_required = touched.empty() ? actual_min : touched[j].min;
370             Expr extent_required = touched.empty() ? actual_extent : (touched[j].max + 1 - touched[j].min);
371 
372             if (touched.maybe_unused()) {
373                 min_required = select(touched.used, min_required, actual_min);
374                 extent_required = select(touched.used, extent_required, actual_extent);
375             }
376 
377             string min_required_name = name + ".min." + dim + ".required";
378             string extent_required_name = name + ".extent." + dim + ".required";
379 
380             Expr min_required_var = Variable::make(Int(32), min_required_name);
381             Expr extent_required_var = Variable::make(Int(32), extent_required_name);
382 
383             lets_required.emplace_back(extent_required_name, extent_required);
384             lets_required.emplace_back(min_required_name, min_required);
385 
386             Expr actual_max = actual_min + actual_extent - 1;
387             Expr max_required = min_required_var + extent_required_var - 1;
388 
389             if (touched.maybe_unused()) {
390                 max_required = select(touched.used, max_required, actual_max);
391             }
392 
393             Expr oob_condition = actual_min <= min_required_var && actual_max >= max_required;
394 
395             Expr oob_error = Call::make(Int(32), "halide_error_access_out_of_bounds",
396                                         {error_name, j, min_required_var, max_required, actual_min, actual_max},
397                                         Call::Extern);
398 
399             asserts_required.push_back(AssertStmt::make(oob_condition, oob_error));
400 
401             // Come up with a required stride to use in bounds
402             // inference mode. We don't assert it. It's just used to
403             // apply the constraints to to come up with a proposed
404             // stride. Strides actually passed in may not be in this
405             // order (e.g if storage is swizzled relative to dimension
406             // order).
407             Expr stride_required;
408             if (j == 0) {
409                 stride_required = 1;
410             } else {
411                 string last_dim = std::to_string(j - 1);
412                 stride_required = (Variable::make(Int(32), name + ".stride." + last_dim + ".required") *
413                                    Variable::make(Int(32), name + ".extent." + last_dim + ".required"));
414             }
415             lets_required.emplace_back(name + ".stride." + dim + ".required", stride_required);
416 
417             // On 32-bit systems, insert checks to make sure the total
418             // size of all input and output buffers is <= 2^31 - 1.
419             // And that no product of extents overflows 2^31 - 1. This
420             // second test is likely only needed if a fuse directive
421             // is used in the schedule to combine multiple extents,
422             // but it is here for extra safety. On 64-bit targets with the
423             // LargeBuffers feature, the maximum size is 2^63 - 1.
424             Expr max_size = make_const(UInt(64), t.maximum_buffer_size());
425             Expr max_extent = make_const(UInt(64), 0x7fffffff);
426             Expr actual_size = abs(cast<int64_t>(actual_extent) * actual_stride);
427             Expr allocation_size_error = Call::make(Int(32), "halide_error_buffer_allocation_too_large",
428                                                     {name, actual_size, max_size}, Call::Extern);
429             Stmt check = AssertStmt::make(actual_size <= max_size, allocation_size_error);
430             dims_no_overflow_asserts.push_back(check);
431 
432             // Don't repeat extents check for secondary buffers as extents must be the same as for the first one.
433             if (!is_secondary_output_buffer) {
434                 if (j == 0) {
435                     lets_overflow.emplace_back(name + ".total_extent." + dim, cast<int64_t>(actual_extent));
436                 } else {
437                     max_size = cast<int64_t>(max_size);
438                     Expr last_dim = Variable::make(Int(64), name + ".total_extent." + std::to_string(j - 1));
439                     Expr this_dim = actual_extent * last_dim;
440                     Expr this_dim_var = Variable::make(Int(64), name + ".total_extent." + dim);
441                     lets_overflow.emplace_back(name + ".total_extent." + dim, this_dim);
442                     Expr error = Call::make(Int(32), "halide_error_buffer_extents_too_large",
443                                             {name, this_dim_var, max_size}, Call::Extern);
444                     Stmt check = AssertStmt::make(this_dim_var <= max_size, error);
445                     dims_no_overflow_asserts.push_back(check);
446                 }
447 
448                 // It is never legal to have a negative buffer extent.
449                 Expr negative_extent_condition = actual_extent >= 0;
450                 Expr negative_extent_error = Call::make(Int(32), "halide_error_buffer_extents_negative",
451                                                         {error_name, j, actual_extent}, Call::Extern);
452                 asserts_required.push_back(AssertStmt::make(negative_extent_condition, negative_extent_error));
453             }
454         }
455 
456         // Create code that mutates the input buffers if we're in bounds inference mode.
457         BufferBuilder builder;
458         builder.buffer_memory = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
459         builder.shape_memory = Call::make(type_of<struct halide_dimension_t *>(),
460                                           Call::buffer_get_shape, {builder.buffer_memory},
461                                           Call::Extern);
462         builder.type = type;
463         builder.dimensions = dimensions;
464         for (int i = 0; i < dimensions; i++) {
465             string dim = std::to_string(i);
466             builder.mins.push_back(Variable::make(Int(32), name + ".min." + dim + ".proposed"));
467             builder.extents.push_back(Variable::make(Int(32), name + ".extent." + dim + ".proposed"));
468             builder.strides.push_back(Variable::make(Int(32), name + ".stride." + dim + ".proposed"));
469         }
470         Stmt rewrite = Evaluate::make(builder.build());
471 
472         rewrite = IfThenElse::make(inference_mode, rewrite);
473         buffer_rewrites.push_back(rewrite);
474 
475         // Build the constraints tests and proposed sizes.
476         vector<pair<Expr, Expr>> constraints;
477         for (int i = 0; i < dimensions; i++) {
478             string dim = std::to_string(i);
479             string min_name = name + ".min." + dim;
480             string stride_name = name + ".stride." + dim;
481             string extent_name = name + ".extent." + dim;
482 
483             Expr stride_constrained, extent_constrained, min_constrained;
484 
485             Expr stride_orig = Variable::make(Int(32), stride_name, image, param, rdom);
486             Expr extent_orig = Variable::make(Int(32), extent_name, image, param, rdom);
487             Expr min_orig = Variable::make(Int(32), min_name, image, param, rdom);
488 
489             Expr stride_required = Variable::make(Int(32), stride_name + ".required");
490             Expr extent_required = Variable::make(Int(32), extent_name + ".required");
491             Expr min_required = Variable::make(Int(32), min_name + ".required");
492 
493             Expr stride_proposed = Variable::make(Int(32), stride_name + ".proposed");
494             Expr extent_proposed = Variable::make(Int(32), extent_name + ".proposed");
495             Expr min_proposed = Variable::make(Int(32), min_name + ".proposed");
496 
497             debug(2) << "Injecting constraints for " << name << "." << i << "\n";
498             if (is_secondary_output_buffer) {
499                 // For multi-output (Tuple) pipelines, output buffers
500                 // beyond the first implicitly have their min and extent
501                 // constrained to match the first output.
502 
503                 if (param.defined()) {
504                     user_assert(!param.extent_constraint(i).defined() &&
505                                 !param.min_constraint(i).defined())
506                         << "Can't constrain the min or extent of an output buffer beyond the "
507                         << "first. They are implicitly constrained to have the same min and extent "
508                         << "as the first output buffer.\n";
509 
510                     stride_constrained = param.stride_constraint(i);
511                 } else if (image.defined() && (int)i < image.dimensions()) {
512                     stride_constrained = image.dim(i).stride();
513                 }
514 
515                 std::string min0_name = buffer_name + ".0.min." + dim;
516                 if (replace_with_constrained.count(min0_name) > 0) {
517                     min_constrained = replace_with_constrained[min0_name];
518                 } else {
519                     min_constrained = Variable::make(Int(32), min0_name);
520                 }
521 
522                 std::string extent0_name = buffer_name + ".0.extent." + dim;
523                 if (replace_with_constrained.count(extent0_name) > 0) {
524                     extent_constrained = replace_with_constrained[extent0_name];
525                 } else {
526                     extent_constrained = Variable::make(Int(32), extent0_name);
527                 }
528             } else if (image.defined() && (int)i < image.dimensions()) {
529                 stride_constrained = image.dim(i).stride();
530                 extent_constrained = image.dim(i).extent();
531                 min_constrained = image.dim(i).min();
532             } else if (param.defined()) {
533                 stride_constrained = param.stride_constraint(i);
534                 extent_constrained = param.extent_constraint(i);
535                 min_constrained = param.min_constraint(i);
536             }
537 
538             if (stride_constrained.defined()) {
539                 // Come up with a suggested stride by passing the
540                 // required region through this constraint.
541                 constraints.emplace_back(stride_orig, stride_constrained);
542                 stride_constrained = substitute(replace_with_required, stride_constrained);
543                 lets_proposed.emplace_back(stride_name + ".proposed", stride_constrained);
544             } else {
545                 lets_proposed.emplace_back(stride_name + ".proposed", stride_required);
546             }
547 
548             if (min_constrained.defined()) {
549                 constraints.emplace_back(min_orig, min_constrained);
550                 min_constrained = substitute(replace_with_required, min_constrained);
551                 lets_proposed.emplace_back(min_name + ".proposed", min_constrained);
552             } else {
553                 lets_proposed.emplace_back(min_name + ".proposed", min_required);
554             }
555 
556             if (extent_constrained.defined()) {
557                 constraints.emplace_back(extent_orig, extent_constrained);
558                 extent_constrained = substitute(replace_with_required, extent_constrained);
559                 lets_proposed.emplace_back(extent_name + ".proposed", extent_constrained);
560             } else {
561                 lets_proposed.emplace_back(extent_name + ".proposed", extent_required);
562             }
563 
564             // In bounds inference mode, make sure the proposed
565             // versions still satisfy the constraints.
566             Expr max_proposed = min_proposed + extent_proposed - 1;
567             Expr max_required = min_required + extent_required - 1;
568             Expr check = (min_proposed <= min_required) && (max_proposed >= max_required);
569             Expr error = Call::make(Int(32), "halide_error_constraints_make_required_region_smaller",
570                                     {error_name, i, min_proposed, max_proposed, min_required, max_required},
571                                     Call::Extern);
572             asserts_proposed.push_back(AssertStmt::make((!inference_mode) || check, error));
573 
574             // stride_required is just a suggestion. It's ok if the
575             // constraints shuffle them around in ways that make it
576             // smaller.
577             /*
578             check = (stride_proposed >= stride_required);
579             error = "Applying the constraints to the required stride made it smaller";
580             asserts_proposed.push_back(AssertStmt::make((!inference_mode) || check, error, vector<Expr>()));
581             */
582         }
583 
584         // Assert all the conditions, and set the new values
585         for (size_t i = 0; i < constraints.size(); i++) {
586             Expr var = constraints[i].first;
587             const string &name = var.as<Variable>()->name;
588             Expr constrained_var = Variable::make(Int(32), name + ".constrained");
589 
590             std::ostringstream ss;
591             ss << constraints[i].second;
592             string constrained_var_str = ss.str();
593 
594             replace_with_constrained[name] = constrained_var;
595 
596             lets_constrained.emplace_back(name + ".constrained", constraints[i].second);
597 
598             Expr error = 0;
599             if (!no_asserts) {
600                 error = Call::make(Int(32), "halide_error_constraint_violated",
601                                    {name, var, constrained_var_str, constrained_var},
602                                    Call::Extern);
603             }
604 
605             // Check the var passed in equals the constrained version (when not in inference mode)
606             asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error));
607         }
608 
609         // For the buffers used on host, check the host field is non-null
610         Expr host_ptr = Variable::make(Handle(), name, image, param, ReductionDomain());
611         if (used_on_host) {
612             Expr error = Call::make(Int(32), "halide_error_host_is_null",
613                                     {error_name}, Call::Extern);
614             Expr check = (host_ptr != make_zero(host_ptr.type()));
615             if (touched.maybe_unused()) {
616                 check = !touched.used || check;
617             }
618             asserts_host_non_null.push_back(AssertStmt::make(check, error));
619 
620             if (!will_inject_host_copies) {
621                 Expr device_dirty = Variable::make(Bool(), name + ".device_dirty",
622                                                    image, param, ReductionDomain());
623 
624                 Expr error = Call::make(Int(32), "halide_error_device_dirty_with_no_device_support",
625                                         {error_name}, Call::Extern);
626 
627                 // If we have no device support, we can't handle
628                 // device_dirty, so every buffer touched needs checking.
629                 asserts_device_not_dirty.push_back(AssertStmt::make(!device_dirty, error));
630             }
631         }
632 
633         // and check alignment of the host field
634         if (param.defined() && param.host_alignment() != param.type().bytes()) {
635             int alignment_required = param.host_alignment();
636             Expr u64t_host_ptr = reinterpret<uint64_t>(host_ptr);
637             Expr align_condition = (u64t_host_ptr % alignment_required) == 0;
638             Expr error = Call::make(Int(32), "halide_error_unaligned_host_ptr",
639                                     {name, alignment_required}, Call::Extern);
640             asserts_host_alignment.push_back(AssertStmt::make(align_condition, error));
641         }
642     }
643 
644     auto prepend_stmts = [&](vector<Stmt> *stmts) {
645         while (!stmts->empty()) {
646             s = Block::make(std::move(stmts->back()), s);
647             stmts->pop_back();
648         }
649     };
650 
651     auto prepend_lets = [&](vector<pair<string, Expr>> *lets) {
652         while (!lets->empty()) {
653             auto &p = lets->back();
654             s = LetStmt::make(p.first, std::move(p.second), s);
655             lets->pop_back();
656         }
657     };
658 
659     if (!no_asserts) {
660         // Inject the code that checks the host pointers.
661         prepend_stmts(&asserts_host_non_null);
662         prepend_stmts(&asserts_host_alignment);
663         prepend_stmts(&asserts_device_not_dirty);
664         prepend_stmts(&dims_no_overflow_asserts);
665         prepend_lets(&lets_overflow);
666     }
667 
668     // Replace uses of the var with the constrained versions in the
669     // rest of the program. We also need to respect the existence of
670     // constrained versions during storage flattening and bounds
671     // inference.
672     s = substitute(replace_with_constrained, s);
673 
674     // Now we add a bunch of code to the top of the pipeline. This is
675     // all in reverse order compared to execution, as we incrementally
676     // prepending code.
677 
678     // Inject the code that checks the constraints are correct. We
679     // need these regardless of how NoAsserts is set, because they are
680     // what gets Halide to actually exploit the constraint.
681     prepend_stmts(&asserts_constrained);
682 
683     if (!no_asserts) {
684         prepend_stmts(&asserts_required);
685         prepend_stmts(&asserts_type_checks);
686     }
687 
688     // Inject the code that returns early for inference mode.
689     if (!no_bounds_query) {
690         s = IfThenElse::make(!maybe_return_condition, s);
691         prepend_stmts(&buffer_rewrites);
692     }
693 
694     if (!no_asserts) {
695         prepend_stmts(&asserts_proposed);
696     }
697 
698     // Inject the code that defines the proposed sizes.
699     prepend_lets(&lets_proposed);
700 
701     // Inject the code that defines the constrained sizes.
702     prepend_lets(&lets_constrained);
703 
704     // Inject the code that defines the required sizes produced by bounds inference.
705     prepend_lets(&lets_required);
706 
707     // Inject the code that checks that does msan checks. (Note that this ignores no_asserts.)
708     prepend_stmts(&msan_checks);
709 
710     return s;
711 }
712 
713 // The following function repeats the arguments list it just passes
714 // through six times. Surely there is a better way?
add_image_checks(const Stmt & s,const vector<Function> & outputs,const Target & t,const vector<string> & order,const map<string,Function> & env,const FuncValueBounds & fb,bool will_inject_host_copies)715 Stmt add_image_checks(const Stmt &s,
716                       const vector<Function> &outputs,
717                       const Target &t,
718                       const vector<string> &order,
719                       const map<string, Function> &env,
720                       const FuncValueBounds &fb,
721                       bool will_inject_host_copies) {
722 
723     // Checks for images go at the marker deposited by computation
724     // bounds inference.
725     class Injector : public IRMutator {
726         using IRMutator::visit;
727 
728         Stmt visit(const Block *op) override {
729             const Evaluate *e = op->first.as<Evaluate>();
730             const Call *c = e ? e->value.as<Call>() : nullptr;
731             if (c && c->is_intrinsic(Call::add_image_checks_marker)) {
732                 return add_image_checks_inner(op->rest, outputs, t, order, env, fb, will_inject_host_copies);
733             } else {
734                 return IRMutator::visit(op);
735             }
736         }
737 
738         const vector<Function> &outputs;
739         const Target &t;
740         const vector<string> &order;
741         const map<string, Function> &env;
742         const FuncValueBounds &fb;
743         bool will_inject_host_copies;
744 
745     public:
746         Injector(const vector<Function> &outputs,
747                  const Target &t,
748                  const vector<string> &order,
749                  const map<string, Function> &env,
750                  const FuncValueBounds &fb,
751                  bool will_inject_host_copies)
752             : outputs(outputs), t(t), order(order), env(env), fb(fb), will_inject_host_copies(will_inject_host_copies) {
753         }
754     } injector(outputs, t, order, env, fb, will_inject_host_copies);
755 
756     return injector.mutate(s);
757 }
758 
759 }  // namespace Internal
760 }  // namespace Halide
761