1 #include "AddImageChecks.h"
2 #include "ExternFuncArgument.h"
3 #include "Function.h"
4 #include "IRMutator.h"
5 #include "IROperator.h"
6 #include "IRVisitor.h"
7 #include "Simplify.h"
8 #include "Substitute.h"
9 #include "Target.h"
10
11 namespace Halide {
12 namespace Internal {
13
14 using std::map;
15 using std::pair;
16 using std::string;
17 using std::vector;
18
19 /* Find all the externally referenced buffers in a stmt */
20 class FindBuffers : public IRGraphVisitor {
21 public:
22 struct Result {
23 Buffer<> image;
24 Parameter param;
25 Type type;
26 int dimensions{0};
27 bool used_on_host{false};
28 };
29
30 map<string, Result> buffers;
31 bool in_device_loop = false;
32
33 using IRGraphVisitor::visit;
34
visit(const For * op)35 void visit(const For *op) override {
36 op->min.accept(this);
37 op->extent.accept(this);
38 bool old = in_device_loop;
39 if (op->device_api != DeviceAPI::None &&
40 op->device_api != DeviceAPI::Host) {
41 in_device_loop = true;
42 }
43 op->body.accept(this);
44 in_device_loop = old;
45 }
46
visit(const Call * op)47 void visit(const Call *op) override {
48 IRGraphVisitor::visit(op);
49 if (op->image.defined()) {
50 Result &r = buffers[op->name];
51 r.image = op->image;
52 r.type = op->type.element_of();
53 r.dimensions = (int)op->args.size();
54 r.used_on_host = r.used_on_host || (!in_device_loop);
55 } else if (op->param.defined()) {
56 Result &r = buffers[op->name];
57 r.param = op->param;
58 r.type = op->type.element_of();
59 r.dimensions = (int)op->args.size();
60 r.used_on_host = r.used_on_host || (!in_device_loop);
61 }
62 }
63
visit(const Provide * op)64 void visit(const Provide *op) override {
65 IRGraphVisitor::visit(op);
66 if (op->values.size() == 1) {
67 auto it = buffers.find(op->name);
68 if (it != buffers.end() && !in_device_loop) {
69 it->second.used_on_host = true;
70 }
71 } else {
72 for (size_t i = 0; i < op->values.size(); i++) {
73 string name = op->name + "." + std::to_string(i);
74 auto it = buffers.find(name);
75 if (it != buffers.end() && !in_device_loop) {
76 it->second.used_on_host = true;
77 }
78 }
79 }
80 }
81
visit(const Variable * op)82 void visit(const Variable *op) override {
83 if (op->param.defined() &&
84 op->param.is_buffer() &&
85 buffers.find(op->param.name()) == buffers.end()) {
86 Result r;
87 r.param = op->param;
88 r.type = op->param.type();
89 r.dimensions = op->param.dimensions();
90 r.used_on_host = false;
91 buffers[op->param.name()] = r;
92 } else if (op->reduction_domain.defined()) {
93 // The bounds of reduction domains are not yet defined,
94 // and they may be the only reference to some parameters.
95 op->reduction_domain.accept(this);
96 }
97 }
98 };
99
100 class TrimStmtToPartsThatAccessBuffers : public IRMutator {
101 bool touches_buffer = false;
102 const map<string, FindBuffers::Result> &buffers;
103
104 using IRMutator::visit;
105
visit(const Call * op)106 Expr visit(const Call *op) override {
107 touches_buffer |= (buffers.count(op->name) > 0);
108 return IRMutator::visit(op);
109 }
visit(const Provide * op)110 Stmt visit(const Provide *op) override {
111 touches_buffer |= (buffers.find(op->name) != buffers.end());
112 return IRMutator::visit(op);
113 }
visit(const Variable * op)114 Expr visit(const Variable *op) override {
115 if (op->type.is_handle() && op->param.defined() && op->param.is_buffer()) {
116 touches_buffer |= (buffers.find(op->param.name()) != buffers.end());
117 }
118 return IRMutator::visit(op);
119 }
120
visit(const Block * op)121 Stmt visit(const Block *op) override {
122 bool old_touches_buffer = touches_buffer;
123 touches_buffer = false;
124 Stmt first = mutate(op->first);
125 old_touches_buffer |= touches_buffer;
126 if (!touches_buffer) {
127 first = Evaluate::make(0);
128 }
129 touches_buffer = false;
130 Stmt rest = mutate(op->rest);
131 old_touches_buffer |= touches_buffer;
132 if (!touches_buffer) {
133 rest = Evaluate::make(0);
134 }
135 touches_buffer = old_touches_buffer;
136 return Block::make(first, rest);
137 }
138
139 public:
TrimStmtToPartsThatAccessBuffers(const map<string,FindBuffers::Result> & bufs)140 TrimStmtToPartsThatAccessBuffers(const map<string, FindBuffers::Result> &bufs)
141 : buffers(bufs) {
142 }
143 };
144
/** Inject, at the front of stmt 's', all the checks and bounds-query
 * machinery for every input and output buffer it references:
 * type/dimension checks, out-of-bounds checks, 32-bit size overflow
 * checks, host-pointer null/alignment checks, device-dirty checks,
 * user-supplied constraint checks, and the buffer rewrites performed
 * when running in bounds inference mode. Returns the augmented stmt.
 * Respects Target::NoAsserts and Target::NoBoundsQuery. */
Stmt add_image_checks_inner(Stmt s,
                            const vector<Function> &outputs,
                            const Target &t,
                            const vector<string> &order,
                            const map<string, Function> &env,
                            const FuncValueBounds &fb,
                            bool will_inject_host_copies) {

    bool no_asserts = t.has_feature(Target::NoAsserts);
    bool no_bounds_query = t.has_feature(Target::NoBoundsQuery);

    // First hunt for all the referenced buffers
    FindBuffers finder;
    map<string, FindBuffers::Result> &bufs = finder.buffers;

    // Add the output buffer(s).
    for (Function f : outputs) {
        for (size_t i = 0; i < f.values().size(); i++) {
            FindBuffers::Result output_buffer;
            output_buffer.type = f.values()[i].type();
            output_buffer.param = f.output_buffers()[i];
            output_buffer.dimensions = f.dimensions();
            if (f.values().size() > 1) {
                // Tuple-valued outputs get one buffer per component.
                bufs[f.name() + '.' + std::to_string(i)] = output_buffer;
            } else {
                bufs[f.name()] = output_buffer;
            }
        }
    }

    // Add the input buffer(s) and annotate which output buffers are
    // used on host.
    s.accept(&finder);

    // Compute the region touched of each buffer, using a version of
    // the stmt trimmed down to just the parts that access buffers.
    Scope<Interval> empty_scope;
    Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs).mutate(s);
    map<string, Box> boxes = boxes_touched(sub_stmt, empty_scope, fb);

    // Now iterate through all the buffers, creating a list of lets
    // and a list of asserts. These are accumulated here and prepended
    // to the stmt (in reverse) at the end of this function.
    vector<pair<string, Expr>> lets_overflow;
    vector<pair<string, Expr>> lets_required;
    vector<pair<string, Expr>> lets_constrained;
    vector<pair<string, Expr>> lets_proposed;
    vector<Stmt> dims_no_overflow_asserts;
    vector<Stmt> asserts_required;
    vector<Stmt> asserts_constrained;
    vector<Stmt> asserts_proposed;
    vector<Stmt> asserts_type_checks;
    vector<Stmt> asserts_host_alignment;
    vector<Stmt> asserts_host_non_null;
    vector<Stmt> asserts_device_not_dirty;
    vector<Stmt> buffer_rewrites;
    vector<Stmt> msan_checks;

    // Inject the code that conditionally returns if we're in inference mode
    Expr maybe_return_condition = const_false();

    // We're also going to apply the constraints to the required min
    // and extent. To do this we have to substitute all references to
    // the actual sizes of the input images in the constraints with
    // references to the required sizes.
    map<string, Expr> replace_with_required;

    for (const pair<const string, FindBuffers::Result> &buf : bufs) {
        const string &name = buf.first;

        for (int i = 0; i < buf.second.dimensions; i++) {
            string dim = std::to_string(i);

            Expr min_required = Variable::make(Int(32), name + ".min." + dim + ".required");
            replace_with_required[name + ".min." + dim] = min_required;

            Expr extent_required = Variable::make(Int(32), name + ".extent." + dim + ".required");
            replace_with_required[name + ".extent." + dim] = simplify(extent_required);

            Expr stride_required = Variable::make(Int(32), name + ".stride." + dim + ".required");
            replace_with_required[name + ".stride." + dim] = stride_required;
        }
    }

    // We also want to build a map that lets us replace values passed
    // in with the constrained version. This is applied to the rest of
    // the lowered pipeline to take advantage of the constraints,
    // e.g. for constant folding.
    map<string, Expr> replace_with_constrained;

    for (pair<const string, FindBuffers::Result> &buf : bufs) {
        const string &name = buf.first;
        Buffer<> &image = buf.second.image;
        Parameter &param = buf.second.param;
        Type type = buf.second.type;
        int dimensions = buf.second.dimensions;
        bool used_on_host = buf.second.used_on_host;

        // Detect if this is one of the outputs of a multi-output pipeline.
        bool is_output_buffer = false;
        bool is_secondary_output_buffer = false;
        string buffer_name = name;
        for (Function f : outputs) {
            for (size_t i = 0; i < f.output_buffers().size(); i++) {
                if (param.defined() &&
                    param.same_as(f.output_buffers()[i])) {
                    is_output_buffer = true;
                    // If we're one of multiple output buffers, we should use the
                    // region inferred for the func in general.
                    buffer_name = f.name();
                    if (i > 0) {
                        is_secondary_output_buffer = true;
                    }
                }
            }
        }

        Box touched = boxes[buffer_name];
        internal_assert(touched.empty() || (int)(touched.size()) == dimensions);

        // The buffer may be used in one or more extern stage. If so we need to
        // expand the box touched to include the results of the
        // top-level bounds query calls to those extern stages.
        if (param.defined()) {
            // Find the extern users.
            vector<string> extern_users;
            for (size_t i = 0; i < order.size(); i++) {
                Function f = env.find(order[i])->second;
                if (f.has_extern_definition() &&
                    !f.extern_definition_proxy_expr().defined()) {
                    const vector<ExternFuncArgument> &args = f.extern_arguments();
                    for (size_t j = 0; j < args.size(); j++) {
                        if ((args[j].image_param.defined() &&
                             args[j].image_param.name() == param.name()) ||
                            (args[j].buffer.defined() &&
                             args[j].buffer.name() == param.name())) {
                            extern_users.push_back(order[i]);
                        }
                    }
                }
            }

            // Expand the box by the result of the bounds query from each.
            for (size_t i = 0; i < extern_users.size(); i++) {
                const string &extern_user = extern_users[i];
                Box query_box;
                Expr query_buf = Variable::make(type_of<struct halide_buffer_t *>(),
                                                param.name() + ".bounds_query." + extern_user);
                for (int j = 0; j < dimensions; j++) {
                    Expr min = Call::make(Int(32), Call::buffer_get_min,
                                          {query_buf, j}, Call::Extern);
                    Expr max = Call::make(Int(32), Call::buffer_get_max,
                                          {query_buf, j}, Call::Extern);
                    query_box.push_back(Interval(min, max));
                }
                merge_boxes(touched, query_box);
            }
        }

        ReductionDomain rdom;

        // An expression returning whether or not we're in inference mode
        string buf_name = name + ".buffer";
        Expr handle = Variable::make(type_of<halide_buffer_t *>(), buf_name,
                                     image, param, rdom);
        Expr inference_mode = Call::make(Bool(), Call::buffer_is_bounds_query,
                                         {handle}, Call::Extern);
        maybe_return_condition = maybe_return_condition || inference_mode;

        // Come up with a name to refer to this buffer in the error messages
        string error_name = (is_output_buffer ? "Output" : "Input");
        error_name += " buffer " + name;

        // Under MSAN, verify that input buffer contents are initialized.
        if (!is_output_buffer && t.has_feature(Target::MSAN)) {
            Expr buffer = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
            Stmt check_contents = Evaluate::make(
                Call::make(Int(32), "halide_msan_check_buffer_is_initialized", {buffer, Expr(buf_name)}, Call::Extern));
            msan_checks.push_back(check_contents);
        }

        // Check the type matches the internally-understood type
        {
            string type_name = name + ".type";
            Expr type_var = Variable::make(UInt(32), type_name, image, param, rdom);
            uint32_t correct_type_bits = ((halide_type_t)type).as_u32();
            Expr correct_type_expr = make_const(UInt(32), correct_type_bits);
            Expr error = Call::make(Int(32), "halide_error_bad_type",
                                    {error_name, type_var, correct_type_expr},
                                    Call::Extern);
            Stmt type_check = AssertStmt::make(type_var == correct_type_expr, error);
            asserts_type_checks.push_back(type_check);
        }

        // Check the dimensions matches the internally-understood dimensions
        {
            string dimensions_name = name + ".dimensions";
            Expr dimensions_given = Variable::make(Int(32), dimensions_name, image, param, rdom);
            Expr error = Call::make(Int(32), "halide_error_bad_dimensions",
                                    {error_name,
                                     dimensions_given, make_const(Int(32), dimensions)},
                                    Call::Extern);
            asserts_type_checks.push_back(
                AssertStmt::make(dimensions_given == dimensions, error));
        }

        if (touched.maybe_unused()) {
            debug(3) << "Image " << name << " is only used when " << touched.used << "\n";
        }

        // Check that the region passed in (after applying constraints) is within the region used
        debug(3) << "In image " << name << " region touched is:\n";

        for (int j = 0; j < dimensions; j++) {
            string dim = std::to_string(j);
            string actual_min_name = name + ".min." + dim;
            string actual_extent_name = name + ".extent." + dim;
            string actual_stride_name = name + ".stride." + dim;
            Expr actual_min = Variable::make(Int(32), actual_min_name, image, param, rdom);
            Expr actual_extent = Variable::make(Int(32), actual_extent_name, image, param, rdom);
            Expr actual_stride = Variable::make(Int(32), actual_stride_name, image, param, rdom);

            if (!touched.empty() && !touched[j].is_bounded()) {
                user_error << "Buffer " << name
                           << " may be accessed in an unbounded way in dimension "
                           << j << "\n";
            }

            Expr min_required = touched.empty() ? actual_min : touched[j].min;
            Expr extent_required = touched.empty() ? actual_extent : (touched[j].max + 1 - touched[j].min);

            // Conditionally-used buffers only need to cover the touched
            // region when the condition holds.
            if (touched.maybe_unused()) {
                min_required = select(touched.used, min_required, actual_min);
                extent_required = select(touched.used, extent_required, actual_extent);
            }

            string min_required_name = name + ".min." + dim + ".required";
            string extent_required_name = name + ".extent." + dim + ".required";

            Expr min_required_var = Variable::make(Int(32), min_required_name);
            Expr extent_required_var = Variable::make(Int(32), extent_required_name);

            lets_required.emplace_back(extent_required_name, extent_required);
            lets_required.emplace_back(min_required_name, min_required);

            Expr actual_max = actual_min + actual_extent - 1;
            Expr max_required = min_required_var + extent_required_var - 1;

            if (touched.maybe_unused()) {
                max_required = select(touched.used, max_required, actual_max);
            }

            // The region passed in must cover the required region.
            Expr oob_condition = actual_min <= min_required_var && actual_max >= max_required;

            Expr oob_error = Call::make(Int(32), "halide_error_access_out_of_bounds",
                                        {error_name, j, min_required_var, max_required, actual_min, actual_max},
                                        Call::Extern);

            asserts_required.push_back(AssertStmt::make(oob_condition, oob_error));

            // Come up with a required stride to use in bounds
            // inference mode. We don't assert it. It's just used to
            // apply the constraints to come up with a proposed
            // stride. Strides actually passed in may not be in this
            // order (e.g if storage is swizzled relative to dimension
            // order).
            Expr stride_required;
            if (j == 0) {
                stride_required = 1;
            } else {
                string last_dim = std::to_string(j - 1);
                stride_required = (Variable::make(Int(32), name + ".stride." + last_dim + ".required") *
                                   Variable::make(Int(32), name + ".extent." + last_dim + ".required"));
            }
            lets_required.emplace_back(name + ".stride." + dim + ".required", stride_required);

            // On 32-bit systems, insert checks to make sure the total
            // size of all input and output buffers is <= 2^31 - 1.
            // And that no product of extents overflows 2^31 - 1. This
            // second test is likely only needed if a fuse directive
            // is used in the schedule to combine multiple extents,
            // but it is here for extra safety. On 64-bit targets with the
            // LargeBuffers feature, the maximum size is 2^63 - 1.
            Expr max_size = make_const(UInt(64), t.maximum_buffer_size());
            Expr max_extent = make_const(UInt(64), 0x7fffffff);
            Expr actual_size = abs(cast<int64_t>(actual_extent) * actual_stride);
            Expr allocation_size_error = Call::make(Int(32), "halide_error_buffer_allocation_too_large",
                                                    {name, actual_size, max_size}, Call::Extern);
            Stmt check = AssertStmt::make(actual_size <= max_size, allocation_size_error);
            dims_no_overflow_asserts.push_back(check);

            // Don't repeat extents check for secondary buffers as extents must be the same as for the first one.
            if (!is_secondary_output_buffer) {
                if (j == 0) {
                    lets_overflow.emplace_back(name + ".total_extent." + dim, cast<int64_t>(actual_extent));
                } else {
                    // Accumulate a running product of extents in 64-bit
                    // and check it against the maximum buffer size.
                    max_size = cast<int64_t>(max_size);
                    Expr last_dim = Variable::make(Int(64), name + ".total_extent." + std::to_string(j - 1));
                    Expr this_dim = actual_extent * last_dim;
                    Expr this_dim_var = Variable::make(Int(64), name + ".total_extent." + dim);
                    lets_overflow.emplace_back(name + ".total_extent." + dim, this_dim);
                    Expr error = Call::make(Int(32), "halide_error_buffer_extents_too_large",
                                            {name, this_dim_var, max_size}, Call::Extern);
                    Stmt check = AssertStmt::make(this_dim_var <= max_size, error);
                    dims_no_overflow_asserts.push_back(check);
                }

                // It is never legal to have a negative buffer extent.
                Expr negative_extent_condition = actual_extent >= 0;
                Expr negative_extent_error = Call::make(Int(32), "halide_error_buffer_extents_negative",
                                                        {error_name, j, actual_extent}, Call::Extern);
                asserts_required.push_back(AssertStmt::make(negative_extent_condition, negative_extent_error));
            }
        }

        // Create code that mutates the input buffers if we're in bounds inference mode.
        BufferBuilder builder;
        builder.buffer_memory = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
        builder.shape_memory = Call::make(type_of<struct halide_dimension_t *>(),
                                          Call::buffer_get_shape, {builder.buffer_memory},
                                          Call::Extern);
        builder.type = type;
        builder.dimensions = dimensions;
        for (int i = 0; i < dimensions; i++) {
            string dim = std::to_string(i);
            builder.mins.push_back(Variable::make(Int(32), name + ".min." + dim + ".proposed"));
            builder.extents.push_back(Variable::make(Int(32), name + ".extent." + dim + ".proposed"));
            builder.strides.push_back(Variable::make(Int(32), name + ".stride." + dim + ".proposed"));
        }
        Stmt rewrite = Evaluate::make(builder.build());

        rewrite = IfThenElse::make(inference_mode, rewrite);
        buffer_rewrites.push_back(rewrite);

        // Build the constraints tests and proposed sizes.
        vector<pair<Expr, Expr>> constraints;
        for (int i = 0; i < dimensions; i++) {
            string dim = std::to_string(i);
            string min_name = name + ".min." + dim;
            string stride_name = name + ".stride." + dim;
            string extent_name = name + ".extent." + dim;

            Expr stride_constrained, extent_constrained, min_constrained;

            Expr stride_orig = Variable::make(Int(32), stride_name, image, param, rdom);
            Expr extent_orig = Variable::make(Int(32), extent_name, image, param, rdom);
            Expr min_orig = Variable::make(Int(32), min_name, image, param, rdom);

            Expr stride_required = Variable::make(Int(32), stride_name + ".required");
            Expr extent_required = Variable::make(Int(32), extent_name + ".required");
            Expr min_required = Variable::make(Int(32), min_name + ".required");

            Expr stride_proposed = Variable::make(Int(32), stride_name + ".proposed");
            Expr extent_proposed = Variable::make(Int(32), extent_name + ".proposed");
            Expr min_proposed = Variable::make(Int(32), min_name + ".proposed");

            debug(2) << "Injecting constraints for " << name << "." << i << "\n";
            if (is_secondary_output_buffer) {
                // For multi-output (Tuple) pipelines, output buffers
                // beyond the first implicitly have their min and extent
                // constrained to match the first output.

                if (param.defined()) {
                    user_assert(!param.extent_constraint(i).defined() &&
                                !param.min_constraint(i).defined())
                        << "Can't constrain the min or extent of an output buffer beyond the "
                        << "first. They are implicitly constrained to have the same min and extent "
                        << "as the first output buffer.\n";

                    stride_constrained = param.stride_constraint(i);
                } else if (image.defined() && (int)i < image.dimensions()) {
                    stride_constrained = image.dim(i).stride();
                }

                // Reuse the first output's (possibly constrained) min.
                std::string min0_name = buffer_name + ".0.min." + dim;
                if (replace_with_constrained.count(min0_name) > 0) {
                    min_constrained = replace_with_constrained[min0_name];
                } else {
                    min_constrained = Variable::make(Int(32), min0_name);
                }

                // Likewise for the extent.
                std::string extent0_name = buffer_name + ".0.extent." + dim;
                if (replace_with_constrained.count(extent0_name) > 0) {
                    extent_constrained = replace_with_constrained[extent0_name];
                } else {
                    extent_constrained = Variable::make(Int(32), extent0_name);
                }
            } else if (image.defined() && (int)i < image.dimensions()) {
                stride_constrained = image.dim(i).stride();
                extent_constrained = image.dim(i).extent();
                min_constrained = image.dim(i).min();
            } else if (param.defined()) {
                stride_constrained = param.stride_constraint(i);
                extent_constrained = param.extent_constraint(i);
                min_constrained = param.min_constraint(i);
            }

            if (stride_constrained.defined()) {
                // Come up with a suggested stride by passing the
                // required region through this constraint.
                constraints.emplace_back(stride_orig, stride_constrained);
                stride_constrained = substitute(replace_with_required, stride_constrained);
                lets_proposed.emplace_back(stride_name + ".proposed", stride_constrained);
            } else {
                lets_proposed.emplace_back(stride_name + ".proposed", stride_required);
            }

            if (min_constrained.defined()) {
                constraints.emplace_back(min_orig, min_constrained);
                min_constrained = substitute(replace_with_required, min_constrained);
                lets_proposed.emplace_back(min_name + ".proposed", min_constrained);
            } else {
                lets_proposed.emplace_back(min_name + ".proposed", min_required);
            }

            if (extent_constrained.defined()) {
                constraints.emplace_back(extent_orig, extent_constrained);
                extent_constrained = substitute(replace_with_required, extent_constrained);
                lets_proposed.emplace_back(extent_name + ".proposed", extent_constrained);
            } else {
                lets_proposed.emplace_back(extent_name + ".proposed", extent_required);
            }

            // In bounds inference mode, make sure the proposed
            // versions still satisfy the constraints.
            Expr max_proposed = min_proposed + extent_proposed - 1;
            Expr max_required = min_required + extent_required - 1;
            Expr check = (min_proposed <= min_required) && (max_proposed >= max_required);
            Expr error = Call::make(Int(32), "halide_error_constraints_make_required_region_smaller",
                                    {error_name, i, min_proposed, max_proposed, min_required, max_required},
                                    Call::Extern);
            asserts_proposed.push_back(AssertStmt::make((!inference_mode) || check, error));

            // stride_required is just a suggestion. It's ok if the
            // constraints shuffle them around in ways that make it
            // smaller.
            /*
            check = (stride_proposed >= stride_required);
            error = "Applying the constraints to the required stride made it smaller";
            asserts_proposed.push_back(AssertStmt::make((!inference_mode) || check, error, vector<Expr>()));
            */
        }

        // Assert all the conditions, and set the new values
        for (size_t i = 0; i < constraints.size(); i++) {
            Expr var = constraints[i].first;
            const string &name = var.as<Variable>()->name;
            Expr constrained_var = Variable::make(Int(32), name + ".constrained");

            // Stringify the constraint for the error message.
            std::ostringstream ss;
            ss << constraints[i].second;
            string constrained_var_str = ss.str();

            replace_with_constrained[name] = constrained_var;

            lets_constrained.emplace_back(name + ".constrained", constraints[i].second);

            Expr error = 0;
            if (!no_asserts) {
                error = Call::make(Int(32), "halide_error_constraint_violated",
                                   {name, var, constrained_var_str, constrained_var},
                                   Call::Extern);
            }

            // Check the var passed in equals the constrained version (when not in inference mode)
            asserts_constrained.push_back(AssertStmt::make(var == constrained_var, error));
        }

        // For the buffers used on host, check the host field is non-null
        Expr host_ptr = Variable::make(Handle(), name, image, param, ReductionDomain());
        if (used_on_host) {
            Expr error = Call::make(Int(32), "halide_error_host_is_null",
                                    {error_name}, Call::Extern);
            Expr check = (host_ptr != make_zero(host_ptr.type()));
            if (touched.maybe_unused()) {
                check = !touched.used || check;
            }
            asserts_host_non_null.push_back(AssertStmt::make(check, error));

            if (!will_inject_host_copies) {
                Expr device_dirty = Variable::make(Bool(), name + ".device_dirty",
                                                   image, param, ReductionDomain());

                Expr error = Call::make(Int(32), "halide_error_device_dirty_with_no_device_support",
                                        {error_name}, Call::Extern);

                // If we have no device support, we can't handle
                // device_dirty, so every buffer touched needs checking.
                asserts_device_not_dirty.push_back(AssertStmt::make(!device_dirty, error));
            }
        }

        // and check alignment of the host field
        if (param.defined() && param.host_alignment() != param.type().bytes()) {
            int alignment_required = param.host_alignment();
            Expr u64t_host_ptr = reinterpret<uint64_t>(host_ptr);
            Expr align_condition = (u64t_host_ptr % alignment_required) == 0;
            Expr error = Call::make(Int(32), "halide_error_unaligned_host_ptr",
                                    {name, alignment_required}, Call::Extern);
            asserts_host_alignment.push_back(AssertStmt::make(align_condition, error));
        }
    }

    // Helper: prepend a list of stmts to 's', consuming the list.
    auto prepend_stmts = [&](vector<Stmt> *stmts) {
        while (!stmts->empty()) {
            s = Block::make(std::move(stmts->back()), s);
            stmts->pop_back();
        }
    };

    // Helper: wrap 's' in a list of LetStmts, consuming the list.
    auto prepend_lets = [&](vector<pair<string, Expr>> *lets) {
        while (!lets->empty()) {
            auto &p = lets->back();
            s = LetStmt::make(p.first, std::move(p.second), s);
            lets->pop_back();
        }
    };

    if (!no_asserts) {
        // Inject the code that checks the host pointers.
        prepend_stmts(&asserts_host_non_null);
        prepend_stmts(&asserts_host_alignment);
        prepend_stmts(&asserts_device_not_dirty);
        prepend_stmts(&dims_no_overflow_asserts);
        prepend_lets(&lets_overflow);
    }

    // Replace uses of the var with the constrained versions in the
    // rest of the program. We also need to respect the existence of
    // constrained versions during storage flattening and bounds
    // inference.
    s = substitute(replace_with_constrained, s);

    // Now we add a bunch of code to the top of the pipeline. This is
    // all in reverse order compared to execution, as we incrementally
    // prepending code.

    // Inject the code that checks the constraints are correct. We
    // need these regardless of how NoAsserts is set, because they are
    // what gets Halide to actually exploit the constraint.
    prepend_stmts(&asserts_constrained);

    if (!no_asserts) {
        prepend_stmts(&asserts_required);
        prepend_stmts(&asserts_type_checks);
    }

    // Inject the code that returns early for inference mode.
    if (!no_bounds_query) {
        s = IfThenElse::make(!maybe_return_condition, s);
        prepend_stmts(&buffer_rewrites);
    }

    if (!no_asserts) {
        prepend_stmts(&asserts_proposed);
    }

    // Inject the code that defines the proposed sizes.
    prepend_lets(&lets_proposed);

    // Inject the code that defines the constrained sizes.
    prepend_lets(&lets_constrained);

    // Inject the code that defines the required sizes produced by bounds inference.
    prepend_lets(&lets_required);

    // Inject the msan checks. (Note that this ignores no_asserts.)
    prepend_stmts(&msan_checks);

    return s;
}
712
713 // The following function repeats the arguments list it just passes
714 // through six times. Surely there is a better way?
// The following function repeats the arguments list it just passes
// through six times. Surely there is a better way?
Stmt add_image_checks(const Stmt &s,
                      const vector<Function> &outputs,
                      const Target &t,
                      const vector<string> &order,
                      const map<string, Function> &env,
                      const FuncValueBounds &fb,
                      bool will_inject_host_copies) {

    // Checks for images go at the marker deposited by computation
    // bounds inference.
    class Injector : public IRMutator {
    public:
        Injector(const vector<Function> &outputs,
                 const Target &t,
                 const vector<string> &order,
                 const map<string, Function> &env,
                 const FuncValueBounds &fb,
                 bool will_inject_host_copies)
            : outputs(outputs), t(t), order(order), env(env), fb(fb), will_inject_host_copies(will_inject_host_copies) {
        }

    private:
        using IRMutator::visit;

        Stmt visit(const Block *op) override {
            // The marker is an Evaluate of the add_image_checks_marker
            // intrinsic sitting at the head of a Block.
            const Evaluate *eval = op->first.as<Evaluate>();
            const Call *call = eval ? eval->value.as<Call>() : nullptr;
            const bool at_marker =
                (call != nullptr) && call->is_intrinsic(Call::add_image_checks_marker);
            if (!at_marker) {
                return IRMutator::visit(op);
            }
            // Replace the marker (and everything after it) with the
            // checks wrapped around the rest of the block.
            return add_image_checks_inner(op->rest, outputs, t, order, env, fb, will_inject_host_copies);
        }

        const vector<Function> &outputs;
        const Target &t;
        const vector<string> &order;
        const map<string, Function> &env;
        const FuncValueBounds &fb;
        bool will_inject_host_copies;
    } injector(outputs, t, order, env, fb, will_inject_host_copies);

    return injector.mutate(s);
}
758
759 } // namespace Internal
760 } // namespace Halide
761