1 #include <algorithm>
2 #include <atomic>
3 #include <cmath>
4 #include <iostream>
5 #include <sstream>
6 #include <utility>
7
8 #include "CSE.h"
9 #include "Debug.h"
10 #include "IREquality.h"
11 #include "IRMutator.h"
12 #include "IROperator.h"
13 #include "IRPrinter.h"
14 #include "Util.h"
15 #include "Var.h"
16
17 namespace Halide {
18
19 // Evaluate a float polynomial efficiently, taking instruction latency
20 // into account. The high order terms come first. n is the number of
21 // terms, which is the degree plus one.
22 namespace {
23
evaluate_polynomial(Expr x,float * coeff,int n)24 Expr evaluate_polynomial(Expr x, float *coeff, int n) {
25 internal_assert(n >= 2);
26
27 Expr x2 = x * x;
28
29 Expr even_terms = coeff[0];
30 Expr odd_terms = coeff[1];
31
32 for (int i = 2; i < n; i++) {
33 if ((i & 1) == 0) {
34 if (coeff[i] == 0.0f) {
35 even_terms *= x2;
36 } else {
37 even_terms = even_terms * x2 + coeff[i];
38 }
39 } else {
40 if (coeff[i] == 0.0f) {
41 odd_terms *= x2;
42 } else {
43 odd_terms = odd_terms * x2 + coeff[i];
44 }
45 }
46 }
47
48 if ((n & 1) == 0) {
49 return even_terms * std::move(x) + odd_terms;
50 } else {
51 return odd_terms * std::move(x) + even_terms;
52 }
53 }
54
stringify(const std::vector<Expr> & args)55 Expr stringify(const std::vector<Expr> &args) {
56 if (args.empty()) {
57 return Expr("");
58 }
59
60 return Internal::Call::make(type_of<const char *>(), Internal::Call::stringify,
61 args, Internal::Call::Intrinsic);
62 }
63
combine_strings(const std::vector<Expr> & args)64 Expr combine_strings(const std::vector<Expr> &args) {
65 // Insert spaces between each expr.
66 std::vector<Expr> strings(args.size() * 2);
67 for (size_t i = 0; i < args.size(); i++) {
68 strings[i * 2] = args[i];
69 if (i < args.size() - 1) {
70 strings[i * 2 + 1] = Expr(" ");
71 } else {
72 strings[i * 2 + 1] = Expr("\n");
73 }
74 }
75
76 return stringify(strings);
77 }
78
79 } // namespace
80
81 namespace Internal {
82
is_const(const Expr & e)83 bool is_const(const Expr &e) {
84 if (e.as<IntImm>() ||
85 e.as<UIntImm>() ||
86 e.as<FloatImm>() ||
87 e.as<StringImm>()) {
88 return true;
89 } else if (const Cast *c = e.as<Cast>()) {
90 return is_const(c->value);
91 } else if (const Ramp *r = e.as<Ramp>()) {
92 return is_const(r->base) && is_const(r->stride);
93 } else if (const Broadcast *b = e.as<Broadcast>()) {
94 return is_const(b->value);
95 } else {
96 return false;
97 }
98 }
99
is_const(const Expr & e,int64_t value)100 bool is_const(const Expr &e, int64_t value) {
101 if (const IntImm *i = e.as<IntImm>()) {
102 return i->value == value;
103 } else if (const UIntImm *i = e.as<UIntImm>()) {
104 return (value >= 0) && (i->value == (uint64_t)value);
105 } else if (const FloatImm *i = e.as<FloatImm>()) {
106 return i->value == value;
107 } else if (const Cast *c = e.as<Cast>()) {
108 return is_const(c->value, value);
109 } else if (const Broadcast *b = e.as<Broadcast>()) {
110 return is_const(b->value, value);
111 } else {
112 return false;
113 }
114 }
115
is_no_op(const Stmt & s)116 bool is_no_op(const Stmt &s) {
117 if (!s.defined()) return true;
118 const Evaluate *e = s.as<Evaluate>();
119 return e && is_const(e->value);
120 }
121
122 namespace {
123
124 class ExprIsPure : public IRGraphVisitor {
125 using IRVisitor::visit;
126
visit(const Call * op)127 void visit(const Call *op) override {
128 if (!op->is_pure()) {
129 result = false;
130 } else {
131 IRGraphVisitor::visit(op);
132 }
133 }
134
visit(const Load * op)135 void visit(const Load *op) override {
136 if (!op->image.defined() && !op->param.defined()) {
137 // It's a load from an internal buffer, which could
138 // mutate.
139 result = false;
140 } else {
141 IRGraphVisitor::visit(op);
142 }
143 }
144
145 public:
146 bool result = true;
147 };
148
149 } // namespace
150
is_pure(const Expr & e)151 bool is_pure(const Expr &e) {
152 ExprIsPure pure;
153 e.accept(&pure);
154 return pure.result;
155 }
156
as_const_int(const Expr & e)157 const int64_t *as_const_int(const Expr &e) {
158 if (!e.defined()) {
159 return nullptr;
160 } else if (const Broadcast *b = e.as<Broadcast>()) {
161 return as_const_int(b->value);
162 } else if (const IntImm *i = e.as<IntImm>()) {
163 return &(i->value);
164 } else {
165 return nullptr;
166 }
167 }
168
as_const_uint(const Expr & e)169 const uint64_t *as_const_uint(const Expr &e) {
170 if (!e.defined()) {
171 return nullptr;
172 } else if (const Broadcast *b = e.as<Broadcast>()) {
173 return as_const_uint(b->value);
174 } else if (const UIntImm *i = e.as<UIntImm>()) {
175 return &(i->value);
176 } else {
177 return nullptr;
178 }
179 }
180
as_const_float(const Expr & e)181 const double *as_const_float(const Expr &e) {
182 if (!e.defined()) {
183 return nullptr;
184 } else if (const Broadcast *b = e.as<Broadcast>()) {
185 return as_const_float(b->value);
186 } else if (const FloatImm *f = e.as<FloatImm>()) {
187 return &(f->value);
188 } else {
189 return nullptr;
190 }
191 }
192
is_const_power_of_two_integer(const Expr & e,int * bits)193 bool is_const_power_of_two_integer(const Expr &e, int *bits) {
194 if (!(e.type().is_int() || e.type().is_uint())) return false;
195
196 const Broadcast *b = e.as<Broadcast>();
197 if (b) return is_const_power_of_two_integer(b->value, bits);
198
199 const Cast *c = e.as<Cast>();
200 if (c) return is_const_power_of_two_integer(c->value, bits);
201
202 uint64_t val = 0;
203
204 if (const int64_t *i = as_const_int(e)) {
205 if (*i < 0) return false;
206 val = (uint64_t)(*i);
207 } else if (const uint64_t *u = as_const_uint(e)) {
208 val = *u;
209 }
210
211 if (val && ((val & (val - 1)) == 0)) {
212 *bits = 0;
213 for (; val; val >>= 1) {
214 if (val == 1) return true;
215 (*bits)++;
216 }
217 }
218
219 return false;
220 }
221
is_positive_const(const Expr & e)222 bool is_positive_const(const Expr &e) {
223 if (const IntImm *i = e.as<IntImm>()) return i->value > 0;
224 if (const UIntImm *u = e.as<UIntImm>()) return u->value > 0;
225 if (const FloatImm *f = e.as<FloatImm>()) return f->value > 0.0f;
226 if (const Cast *c = e.as<Cast>()) {
227 return is_positive_const(c->value);
228 }
229 if (const Ramp *r = e.as<Ramp>()) {
230 // slightly conservative
231 return is_positive_const(r->base) && is_positive_const(r->stride);
232 }
233 if (const Broadcast *b = e.as<Broadcast>()) {
234 return is_positive_const(b->value);
235 }
236 return false;
237 }
238
is_negative_const(const Expr & e)239 bool is_negative_const(const Expr &e) {
240 if (const IntImm *i = e.as<IntImm>()) return i->value < 0;
241 if (const FloatImm *f = e.as<FloatImm>()) return f->value < 0.0f;
242 if (const Cast *c = e.as<Cast>()) {
243 return is_negative_const(c->value);
244 }
245 if (const Ramp *r = e.as<Ramp>()) {
246 // slightly conservative
247 return is_negative_const(r->base) && is_negative_const(r->stride);
248 }
249 if (const Broadcast *b = e.as<Broadcast>()) {
250 return is_negative_const(b->value);
251 }
252 return false;
253 }
254
is_negative_negatable_const(const Expr & e,Type T)255 bool is_negative_negatable_const(const Expr &e, Type T) {
256 if (const IntImm *i = e.as<IntImm>()) {
257 return (i->value < 0 && !T.is_min(i->value));
258 }
259 if (const FloatImm *f = e.as<FloatImm>()) return f->value < 0.0f;
260 if (const Cast *c = e.as<Cast>()) {
261 return is_negative_negatable_const(c->value, c->type);
262 }
263 if (const Ramp *r = e.as<Ramp>()) {
264 // slightly conservative
265 return is_negative_negatable_const(r->base) && is_negative_const(r->stride);
266 }
267 if (const Broadcast *b = e.as<Broadcast>()) {
268 return is_negative_negatable_const(b->value);
269 }
270 return false;
271 }
272
is_negative_negatable_const(const Expr & e)273 bool is_negative_negatable_const(const Expr &e) {
274 return is_negative_negatable_const(e, e.type());
275 }
276
is_undef(const Expr & e)277 bool is_undef(const Expr &e) {
278 if (const Call *c = e.as<Call>()) return c->is_intrinsic(Call::undef);
279 return false;
280 }
281
is_zero(const Expr & e)282 bool is_zero(const Expr &e) {
283 if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 0;
284 if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 0;
285 if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 0.0;
286 if (const Cast *c = e.as<Cast>()) return is_zero(c->value);
287 if (const Broadcast *b = e.as<Broadcast>()) return is_zero(b->value);
288 if (const Call *c = e.as<Call>()) {
289 return (c->is_intrinsic(Call::bool_to_mask) || c->is_intrinsic(Call::cast_mask)) &&
290 is_zero(c->args[0]);
291 }
292 return false;
293 }
294
is_one(const Expr & e)295 bool is_one(const Expr &e) {
296 if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 1;
297 if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 1;
298 if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 1.0;
299 if (const Cast *c = e.as<Cast>()) return is_one(c->value);
300 if (const Broadcast *b = e.as<Broadcast>()) return is_one(b->value);
301 if (const Call *c = e.as<Call>()) {
302 return (c->is_intrinsic(Call::bool_to_mask) || c->is_intrinsic(Call::cast_mask)) &&
303 is_one(c->args[0]);
304 }
305 return false;
306 }
307
is_two(const Expr & e)308 bool is_two(const Expr &e) {
309 if (e.type().bits() < 2) return false;
310 if (const IntImm *int_imm = e.as<IntImm>()) return int_imm->value == 2;
311 if (const UIntImm *uint_imm = e.as<UIntImm>()) return uint_imm->value == 2;
312 if (const FloatImm *float_imm = e.as<FloatImm>()) return float_imm->value == 2.0;
313 if (const Cast *c = e.as<Cast>()) return is_two(c->value);
314 if (const Broadcast *b = e.as<Broadcast>()) return is_two(b->value);
315 return false;
316 }
317
318 namespace {
319
320 template<typename T>
make_const_helper(Type t,T val)321 Expr make_const_helper(Type t, T val) {
322 if (t.is_vector()) {
323 return Broadcast::make(make_const(t.element_of(), val), t.lanes());
324 } else if (t.is_int()) {
325 return IntImm::make(t, (int64_t)val);
326 } else if (t.is_uint()) {
327 return UIntImm::make(t, (uint64_t)val);
328 } else if (t.is_float()) {
329 return FloatImm::make(t, (double)val);
330 } else {
331 internal_error << "Can't make a constant of type " << t << "\n";
332 return Expr();
333 }
334 }
335
336 } // namespace
337
make_const(Type t,int64_t val)338 Expr make_const(Type t, int64_t val) {
339 return make_const_helper(t, val);
340 }
341
make_const(Type t,uint64_t val)342 Expr make_const(Type t, uint64_t val) {
343 return make_const_helper(t, val);
344 }
345
make_const(Type t,double val)346 Expr make_const(Type t, double val) {
347 return make_const_helper(t, val);
348 }
349
make_bool(bool val,int w)350 Expr make_bool(bool val, int w) {
351 return make_const(UInt(1, w), val);
352 }
353
make_zero(Type t)354 Expr make_zero(Type t) {
355 if (t.is_handle()) {
356 return reinterpret(t, make_zero(UInt(64)));
357 } else {
358 return make_const(t, 0);
359 }
360 }
361
make_one(Type t)362 Expr make_one(Type t) {
363 return make_const(t, 1);
364 }
365
make_two(Type t)366 Expr make_two(Type t) {
367 return make_const(t, 2);
368 }
369
make_signed_integer_overflow(Type type)370 Expr make_signed_integer_overflow(Type type) {
371 static std::atomic<int> counter{0};
372 return Call::make(type, Call::signed_integer_overflow, {counter++}, Call::Intrinsic);
373 }
374
const_true(int w)375 Expr const_true(int w) {
376 return make_one(UInt(1, w));
377 }
378
const_false(int w)379 Expr const_false(int w) {
380 return make_zero(UInt(1, w));
381 }
382
lossless_cast(Type t,Expr e)383 Expr lossless_cast(Type t, Expr e) {
384 if (!e.defined() || t == e.type()) {
385 return e;
386 } else if (t.can_represent(e.type())) {
387 return cast(t, std::move(e));
388 }
389
390 if (const Cast *c = e.as<Cast>()) {
391 if (t.can_represent(c->value.type())) {
392 // We can recurse into widening casts.
393 return lossless_cast(t, c->value);
394 } else {
395 return Expr();
396 }
397 }
398
399 if (const Broadcast *b = e.as<Broadcast>()) {
400 Expr v = lossless_cast(t.element_of(), b->value);
401 if (v.defined()) {
402 return Broadcast::make(v, b->lanes);
403 } else {
404 return Expr();
405 }
406 }
407
408 if (const IntImm *i = e.as<IntImm>()) {
409 if (t.can_represent(i->value)) {
410 return make_const(t, i->value);
411 } else {
412 return Expr();
413 }
414 }
415
416 if (const UIntImm *i = e.as<UIntImm>()) {
417 if (t.can_represent(i->value)) {
418 return make_const(t, i->value);
419 } else {
420 return Expr();
421 }
422 }
423
424 if (const FloatImm *f = e.as<FloatImm>()) {
425 if (t.can_represent(f->value)) {
426 return make_const(t, f->value);
427 } else {
428 return Expr();
429 }
430 }
431
432 if ((t.is_int() || t.is_uint()) && t.bits() >= 16) {
433 if (const Add *add = e.as<Add>()) {
434 // If we can losslessly narrow the args even more
435 // aggressively, we're good.
436 // E.g. lossless_cast(uint16, (uint32)(some_u8) + 37)
437 // = (uint16)(some_u8) + 37
438 Expr a = lossless_cast(t.with_bits(t.bits() / 2), add->a);
439 Expr b = lossless_cast(t.with_bits(t.bits() / 2), add->b);
440 if (a.defined() && b.defined()) {
441 return cast(t, a) + cast(t, b);
442 } else {
443 return Expr();
444 }
445 }
446
447 if (const Sub *sub = e.as<Sub>()) {
448 Expr a = lossless_cast(t.with_bits(t.bits() / 2), sub->a);
449 Expr b = lossless_cast(t.with_bits(t.bits() / 2), sub->b);
450 if (a.defined() && b.defined()) {
451 return cast(t, a) + cast(t, b);
452 } else {
453 return Expr();
454 }
455 }
456
457 if (const Mul *mul = e.as<Mul>()) {
458 Expr a = lossless_cast(t.with_bits(t.bits() / 2), mul->a);
459 Expr b = lossless_cast(t.with_bits(t.bits() / 2), mul->b);
460 if (a.defined() && b.defined()) {
461 return cast(t, a) * cast(t, b);
462 } else {
463 return Expr();
464 }
465 }
466
467 if (const VectorReduce *reduce = e.as<VectorReduce>()) {
468 const int factor = reduce->value.type().lanes() / reduce->type.lanes();
469 switch (reduce->op) {
470 case VectorReduce::Add:
471 // A horizontal add requires one extra bit per factor
472 // of two in the reduction factor. E.g. a reduction of
473 // 8 vector lanes down to 2 requires 2 extra bits in
474 // the output. We only deal with power-of-two types
475 // though, so just make sure the reduction factor
476 // isn't so large that it will more than double the
477 // number of bits required.
478 if (factor < (1 << (t.bits() / 2))) {
479 Type narrower = reduce->value.type().with_bits(t.bits() / 2);
480 Expr val = lossless_cast(narrower, reduce->value);
481 if (val.defined()) {
482 return VectorReduce::make(reduce->op, val, reduce->type.lanes());
483 }
484 }
485 break;
486 case VectorReduce::Max:
487 case VectorReduce::Min: {
488 Expr val = lossless_cast(t, reduce->value);
489 if (val.defined()) {
490 return VectorReduce::make(reduce->op, val, reduce->type.lanes());
491 }
492 break;
493 }
494 default:
495 break;
496 }
497 }
498 }
499
500 return Expr();
501 }
502
check_representable(Type dst,int64_t x)503 void check_representable(Type dst, int64_t x) {
504 if (dst.is_handle()) {
505 user_assert(dst.can_represent(x))
506 << "Integer constant " << x
507 << " will be implicitly coerced to type " << dst
508 << ", but Halide does not support pointer arithmetic.\n";
509 } else {
510 user_assert(dst.can_represent(x))
511 << "Integer constant " << x
512 << " will be implicitly coerced to type " << dst
513 << ", which changes its value to " << make_const(dst, x)
514 << ".\n";
515 }
516 }
517
match_types(Expr & a,Expr & b)518 void match_types(Expr &a, Expr &b) {
519 if (a.type() == b.type()) return;
520
521 user_assert(!a.type().is_handle() && !b.type().is_handle())
522 << "Can't do arithmetic on opaque pointer types: "
523 << a << ", " << b << "\n";
524
525 // Broadcast scalar to match vector
526 if (a.type().is_scalar() && b.type().is_vector()) {
527 a = Broadcast::make(std::move(a), b.type().lanes());
528 } else if (a.type().is_vector() && b.type().is_scalar()) {
529 b = Broadcast::make(std::move(b), a.type().lanes());
530 } else {
531 internal_assert(a.type().lanes() == b.type().lanes()) << "Can't match types of differing widths";
532 }
533
534 Type ta = a.type(), tb = b.type();
535
536 // If type broadcasting has made the types match no additional casts are needed
537 if (ta == tb) return;
538
539 if (!ta.is_float() && tb.is_float()) {
540 // int(a) * float(b) -> float(b)
541 // uint(a) * float(b) -> float(b)
542 a = cast(tb, std::move(a));
543 } else if (ta.is_float() && !tb.is_float()) {
544 b = cast(ta, std::move(b));
545 } else if (ta.is_float() && tb.is_float()) {
546 // float(a) * float(b) -> float(max(a, b))
547 if (ta.bits() > tb.bits())
548 b = cast(ta, std::move(b));
549 else
550 a = cast(tb, std::move(a));
551 } else if (ta.is_uint() && tb.is_uint()) {
552 // uint(a) * uint(b) -> uint(max(a, b))
553 if (ta.bits() > tb.bits())
554 b = cast(ta, std::move(b));
555 else
556 a = cast(tb, std::move(a));
557 } else if (!ta.is_float() && !tb.is_float()) {
558 // int(a) * (u)int(b) -> int(max(a, b))
559 int bits = std::max(ta.bits(), tb.bits());
560 int lanes = a.type().lanes();
561 a = cast(Int(bits, lanes), std::move(a));
562 b = cast(Int(bits, lanes), std::move(b));
563 } else {
564 internal_error << "Could not match types: " << ta << ", " << tb << "\n";
565 }
566 }
567
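// Widen whichever of x and y has fewer bits to the bit width of the
// other. A signed-integer operand stays signed; any other operand is
// cast to an unsigned integer type. The widened operand takes the lane
// count of the wider operand.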
match_bits(Expr & x,Expr & y)570 void match_bits(Expr &x, Expr &y) {
571 // The signedness doesn't match, so just match the bits.
572 if (x.type().bits() < y.type().bits()) {
573 Type t;
574 if (x.type().is_int()) {
575 t = Int(y.type().bits(), y.type().lanes());
576 } else {
577 t = UInt(y.type().bits(), y.type().lanes());
578 }
579 x = cast(t, x);
580 } else if (y.type().bits() < x.type().bits()) {
581 Type t;
582 if (y.type().is_int()) {
583 t = Int(x.type().bits(), x.type().lanes());
584 } else {
585 t = UInt(x.type().bits(), x.type().lanes());
586 }
587 y = cast(t, y);
588 }
589 }
590
match_types_bitwise(Expr & x,Expr & y,const char * op_name)591 void match_types_bitwise(Expr &x, Expr &y, const char *op_name) {
592 user_assert(x.defined() && y.defined()) << op_name << " of undefined Expr\n";
593 user_assert(x.type().is_int() || x.type().is_uint())
594 << "The first argument to " << op_name << " must be an integer or unsigned integer";
595 user_assert(y.type().is_int() || y.type().is_uint())
596 << "The second argument to " << op_name << " must be an integer or unsigned integer";
597 user_assert(y.type().is_int() == x.type().is_int())
598 << "Arguments to " << op_name
599 << " must be both be signed or both be unsigned.\n"
600 << "LHS type: " << x.type() << " RHS type: " << y.type() << "\n"
601 << "LHS value: " << x << " RHS value: " << y << "\n";
602
603 // Broadcast scalar to match vector
604 if (x.type().is_scalar() && y.type().is_vector()) {
605 x = Broadcast::make(std::move(x), y.type().lanes());
606 } else if (x.type().is_vector() && y.type().is_scalar()) {
607 y = Broadcast::make(std::move(y), x.type().lanes());
608 } else {
609 internal_assert(x.type().lanes() == y.type().lanes()) << "Can't match types of differing widths";
610 }
611
612 // Cast to the wider type of the two. Already guaranteed to leave
613 // signed/unsigned on number of lanes unchanged.
614 if (x.type().bits() < y.type().bits()) {
615 x = cast(y.type(), x);
616 } else if (y.type().bits() < x.type().bits()) {
617 y = cast(x.type(), y);
618 }
619 }
620
621 // Fast math ops based on those from Syrah (http://github.com/boulos/syrah). Thanks, Solomon!
622
// Factor a float into 2^exponent * reduced, where reduced is between 0.75 and 1.5.
//
// The work is done on the bit pattern of the input, reinterpreted as a
// 32-bit integer with the same lane count.
//
// input:    the value to factor. This function does not check its
//           type; the callers visible in this file assert Float(32).
// reduced:  out-param, receives the reduced value, in input's type.
// exponent: out-param, receives the power of two as a 32-bit integer Expr.
//
// NOTE(review): the mask below keeps the sign bit, and both shifts act
// on the signed integer view, so this presumably expects a positive
// input — confirm against callers.
void range_reduce_log(const Expr &input, Expr *reduced, Expr *exponent) {
    Type type = input.type();
    Type int_type = Int(32, type.lanes());
    Expr int_version = reinterpret(int_type, input);

    // single precision = SEEE EEEE EMMM MMMM MMMM MMMM MMMM MMMM
    // exponent mask    = 0111 1111 1000 0000 0000 0000 0000 0000
    //                    0x7  0xF  0x8  0x0  0x0  0x0  0x0  0x0
    // non-exponent     = 1000 0000 0111 1111 1111 1111 1111 1111
    //                  = 0x8  0x0  0x7  0xF  0xF  0xF  0xF  0xF
    Expr non_exponent_mask = make_const(int_type, 0x807fffff);

    // Extract a version with no exponent (between 1.0 and 2.0)
    Expr no_exponent = int_version & non_exponent_mask;

    // If > 1.5, we want to divide by two, to normalize back into the
    // range (0.75, 1.5). We can detect this by sniffing the high bit
    // of the mantissa.
    Expr new_exponent = no_exponent >> 22;

    // Biased exponent given to the reduced value: 127 when the high
    // mantissa bit is clear, 126 when it is set.
    Expr new_biased_exponent = 127 - new_exponent;
    Expr old_biased_exponent = int_version >> 23;
    // The power of two factored out is the difference of the two
    // biased exponents.
    *exponent = old_biased_exponent - new_biased_exponent;

    // Combine the original sign and mantissa bits with the new exponent.
    Expr blended = (int_version & non_exponent_mask) | (new_biased_exponent << 23);

    *reduced = reinterpret(type, blended);
}
652
// Build an Expr computing the natural log of x_full, whose element
// type must be Float(32) (vectors are allowed).
//
// The input is range-reduced to reduced * 2^exponent, a polynomial in
// (reduced - 1) is evaluated, and exponent * ln(2) is added. Negative
// inputs yield nan and zero yields -inf; both are patched in with
// selects around the main computation.
Expr halide_log(const Expr &x_full) {
    Type type = x_full.type();
    internal_assert(type.element_of() == Float(32));

    Expr nan = Call::make(type, "nan_f32", {}, Call::PureExtern);
    Expr neg_inf = Call::make(type, "neg_inf_f32", {}, Call::PureExtern);

    Expr use_nan = x_full < 0.0f;       // log of a negative returns nan
    Expr use_neg_inf = x_full == 0.0f;  // log of zero is -inf
    Expr exceptional = use_nan | use_neg_inf;

    // Avoid producing nans or infs by generating ln(1.0f) instead and
    // then fixing it later.
    Expr patched = select(exceptional, make_one(type), x_full);
    Expr reduced, exponent;
    range_reduce_log(patched, &reduced, &exponent);

    // Very close to the Taylor series for log about 1, but tuned to
    // have minimum relative error in the reduced domain (0.75 - 1.5).

    // Coefficients are listed highest order first, as
    // evaluate_polynomial expects.
    float coeff[] = {
        0.05111976432738144643f,
        -0.11793923497136414580f,
        0.14971993724699017569f,
        -0.16862004708254804686f,
        0.19980668101718729313f,
        -0.24991211576292837737f,
        0.33333435275479328386f,
        -0.50000106292873236491f,
        1.0f,
        0.0f};
    Expr x1 = reduced - 1.0f;
    Expr result = evaluate_polynomial(x1, coeff, sizeof(coeff) / sizeof(coeff[0]));

    // Add back the contribution of the factored-out power of two.
    result += cast(type, exponent) * logf(2.0);

    // Substitute the exceptional results: nan for negative inputs,
    // -inf for zero.
    result = select(exceptional, select(use_nan, nan, neg_inf), result);

    // This introduces lots of common subexpressions
    result = common_subexpression_elimination(result);

    return result;
}
696
// Build an Expr computing e^x_full, whose element type must be
// Float(32) (vectors are allowed).
//
// The input is split as x_full = k * ln(2) + x with k an integer. A
// polynomial is evaluated at x and the result is scaled by 2^k, which
// is built directly from its bit pattern. Overflow yields inf and
// underflow yields zero.
Expr halide_exp(const Expr &x_full) {
    Type type = x_full.type();
    internal_assert(type.element_of() == Float(32));

    // ln(2) is subtracted in two pieces below.
    // NOTE(review): presumably split into a high and a low part to
    // reduce rounding error in the range reduction — confirm.
    float ln2_part1 = 0.6931457519f;
    float ln2_part2 = 1.4286067653e-6f;
    float one_over_ln2 = 1.0f / logf(2.0f);

    // k = floor(x_full / ln(2)), kept both as a float and as an int32.
    Expr scaled = x_full * one_over_ln2;
    Expr k_real = floor(scaled);
    Expr k = cast(Int(32, type.lanes()), k_real);

    // Remainder after removing k * ln(2).
    Expr x = x_full - k_real * ln2_part1;
    x -= k_real * ln2_part2;

    // Polynomial coefficients, highest order first.
    float coeff[] = {
        0.00031965933071842413f,
        0.00119156835564003744f,
        0.00848988645943932717f,
        0.04160188091348320655f,
        0.16667983794100929562f,
        0.49999899033463041098f,
        1.0f,
        1.0f};
    Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));

    // Compute 2^k.
    int fpbias = 127;
    Expr biased = k + fpbias;

    Expr inf = Call::make(type, "inf_f32", {}, Call::PureExtern);

    // Shift the bits up into the exponent field and reinterpret this
    // thing as float.
    Expr two_to_the_n = reinterpret(type, biased << 23);
    result *= two_to_the_n;

    // Catch overflow and underflow
    // A biased exponent of 255 or more becomes inf; one of 0 or less
    // becomes zero.
    result = select(biased < 255, result, inf);
    result = select(biased > 0, result, make_zero(type));

    // This introduces lots of common subexpressions
    result = common_subexpression_elimination(result);

    return result;
}
743
// Build an Expr approximating erf(x_full). The input must be of type
// Float(32) exactly; anything else is a user error.
//
// The approximation is computed on |x_full| using one of two
// polynomials, chosen by whether |x_full| exceeds 1, and the sign of
// the input is multiplied back in at the end.
Expr halide_erf(const Expr &x_full) {
    user_assert(x_full.type() == Float(32)) << "halide_erf only works for Float(32)";

    // Extract the sign and magnitude.
    Expr sign = select(x_full < 0, -1.0f, 1.0f);
    Expr x = abs(x_full);

    // An approximation very similar to one from Abramowitz and
    // Stegun, but tuned for values > 1. Takes the form 1 - P(x)^-16.
    // Coefficients are listed highest order first.
    float c1[] = {0.0000818502f,
                  -0.0000026500f,
                  0.0009353904f,
                  0.0081960206f,
                  0.0430054424f,
                  0.0703310579f,
                  1.0f};
    Expr approx1 = evaluate_polynomial(x, c1, sizeof(c1) / sizeof(c1[0]));

    approx1 = 1.0f - pow(approx1, -16);

    // An odd polynomial tuned for values < 1. Similar to the Taylor
    // expansion of erf.
    float c2[] = {-0.0005553339f,
                  0.0048937243f,
                  -0.0266849239f,
                  0.1127890132f,
                  -0.3761207240f,
                  1.1283789803f};

    // Evaluated as a polynomial in x*x and then multiplied by x, which
    // gives only odd powers of x.
    Expr approx2 = evaluate_polynomial(x * x, c2, sizeof(c2) / sizeof(c2[0]));
    approx2 *= x;

    // Switch between the two approximations based on the magnitude.
    Expr y = select(x > 1.0f, approx1, approx2);

    Expr result = common_subexpression_elimination(sign * y);

    return result;
}
783
raise_to_integer_power(Expr e,int64_t p)784 Expr raise_to_integer_power(Expr e, int64_t p) {
785 Expr result;
786 if (p == 0) {
787 result = make_one(e.type());
788 } else if (p == 1) {
789 result = std::move(e);
790 } else if (p < 0) {
791 result = make_one(e.type());
792 result /= raise_to_integer_power(std::move(e), -p);
793 } else {
794 // p is at least 2
795 if (p & 1) {
796 Expr y = raise_to_integer_power(e, p >> 1);
797 result = y * y * std::move(e);
798 } else {
799 e = raise_to_integer_power(std::move(e), p >> 1);
800 result = e * e;
801 }
802 }
803 return result;
804 }
805
split_into_ands(const Expr & cond,std::vector<Expr> & result)806 void split_into_ands(const Expr &cond, std::vector<Expr> &result) {
807 if (!cond.defined()) {
808 return;
809 }
810 internal_assert(cond.type().is_bool()) << "Should be a boolean condition\n";
811 if (const And *a = cond.as<And>()) {
812 split_into_ands(a->a, result);
813 split_into_ands(a->b, result);
814 } else if (!is_one(cond)) {
815 result.push_back(cond);
816 }
817 }
818
// Build an Expr that calls Call::buffer_init with ten arguments
// describing a halide_buffer_t, assembled from this builder's fields.
// Fields that were left undefined get defaults, as noted per argument
// below. The result has type halide_buffer_t *.
Expr BufferBuilder::build() const {
    std::vector<Expr> args(10);
    // args[0]: storage for the buffer struct itself. If none was
    // supplied, it is an alloca of size_of_halide_buffer_t.
    if (buffer_memory.defined()) {
        args[0] = buffer_memory;
    } else {
        Expr sz = Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic);
        args[0] = Call::make(type_of<struct halide_buffer_t *>(), Call::alloca, {sz}, Call::Intrinsic);
    }

    // args[1]: storage for the shape. Either the caller's memory, a
    // null pointer for zero dimensions, or a variable that is bound to
    // the shape struct by the Let at the end of this function.
    std::string shape_var_name = unique_name('t');
    Expr shape_var = Variable::make(type_of<halide_dimension_t *>(), shape_var_name);
    if (shape_memory.defined()) {
        args[1] = shape_memory;
    } else if (dimensions == 0) {
        args[1] = make_zero(type_of<halide_dimension_t *>());
    } else {
        args[1] = shape_var;
    }

    // args[2]: host pointer, null by default.
    if (host.defined()) {
        args[2] = host;
    } else {
        args[2] = make_zero(type_of<void *>());
    }

    // args[3]: device handle, zero by default.
    if (device.defined()) {
        args[3] = device;
    } else {
        args[3] = make_zero(UInt(64));
    }

    // args[4]: device interface pointer, null by default.
    if (device_interface.defined()) {
        args[4] = device_interface;
    } else {
        args[4] = make_zero(type_of<struct halide_device_interface_t *>());
    }

    // args[5..7]: element type code, element bit width, dimension count.
    args[5] = (int)type.code();
    args[6] = type.bits();
    args[7] = dimensions;

    // Flatten the shape into four values per dimension: min, extent,
    // stride and flags. Entries missing from mins, extents or strides
    // are filled with zero.
    std::vector<Expr> shape;
    for (size_t i = 0; i < (size_t)dimensions; i++) {
        if (i < mins.size()) {
            shape.push_back(mins[i]);
        } else {
            shape.emplace_back(0);
        }
        if (i < extents.size()) {
            shape.push_back(extents[i]);
        } else {
            shape.emplace_back(0);
        }
        if (i < strides.size()) {
            shape.push_back(strides[i]);
        } else {
            shape.emplace_back(0);
        }
        // per-dimension flags, currently unused.
        shape.emplace_back(0);
    }
    for (const Expr &e : shape) {
        internal_assert(e.type() == Int(32))
            << "Buffer shape fields must be int32_t:" << e << "\n";
    }
    // args[8]: the shape values. With caller-supplied shape memory the
    // struct is passed directly; otherwise this mirrors args[1] (null
    // for zero dimensions, else the Let-bound variable).
    Expr shape_arg = Call::make(type_of<halide_dimension_t *>(), Call::make_struct, shape, Call::Intrinsic);
    if (shape_memory.defined()) {
        args[8] = shape_arg;
    } else if (dimensions == 0) {
        args[8] = make_zero(type_of<halide_dimension_t *>());
    } else {
        args[8] = shape_var;
    }

    // args[9]: flags word. The host-dirty and device-dirty bits are
    // each selected by their condition, if defined, and or-ed together.
    Expr flags = make_zero(UInt(64));
    if (host_dirty.defined()) {
        flags = select(host_dirty,
                       make_const(UInt(64), halide_buffer_flag_host_dirty),
                       make_zero(UInt(64)));
    }
    if (device_dirty.defined()) {
        flags = flags | select(device_dirty,
                               make_const(UInt(64), halide_buffer_flag_device_dirty),
                               make_zero(UInt(64)));
    }
    args[9] = flags;

    Expr e = Call::make(type_of<struct halide_buffer_t *>(), Call::buffer_init, args, Call::Extern);

    // When the shape variable was used above, bind it to the shape
    // struct around the call.
    if (!shape_memory.defined() && dimensions != 0) {
        e = Let::make(shape_var_name, shape_arg, e);
    }

    return e;
}
914
// If e is a Ramp whose stride is the integer immediate `stride`,
// return the ramp's base. Otherwise return an undefined Expr.
Expr strided_ramp_base(const Expr &e, int stride) {
    if (const Ramp *ramp = e.as<Ramp>()) {
        const IntImm *step = ramp->stride.as<IntImm>();
        if (step != nullptr && step->value == stride) {
            return ramp->base;
        }
    }
    return Expr();
}
928
929 namespace {
930
931 struct RemoveLikelies : public IRMutator {
932 using IRMutator::visit;
visitHalide::Internal::__anon460940500411::RemoveLikelies933 Expr visit(const Call *op) override {
934 if (op->is_intrinsic(Call::likely) ||
935 op->is_intrinsic(Call::likely_if_innermost)) {
936 return mutate(op->args[0]);
937 } else {
938 return IRMutator::visit(op);
939 }
940 }
941 };
942
943 } // namespace
944
// Strip likely / likely_if_innermost tags from an expression.
Expr remove_likelies(const Expr &e) {
    RemoveLikelies stripper;
    return stripper.mutate(e);
}
948
// Strip likely / likely_if_innermost tags from a statement.
Stmt remove_likelies(const Stmt &s) {
    RemoveLikelies stripper;
    return stripper.mutate(s);
}
952
requirement_failed_error(Expr condition,const std::vector<Expr> & args)953 Expr requirement_failed_error(Expr condition, const std::vector<Expr> &args) {
954 return Internal::Call::make(Int(32),
955 "halide_error_requirement_failed",
956 {stringify({std::move(condition)}), combine_strings(args)},
957 Internal::Call::Extern);
958 }
959
memoize_tag_helper(Expr result,const std::vector<Expr> & cache_key_values)960 Expr memoize_tag_helper(Expr result, const std::vector<Expr> &cache_key_values) {
961 Type t = result.type();
962 std::vector<Expr> args;
963 args.push_back(std::move(result));
964 args.insert(args.end(), cache_key_values.begin(), cache_key_values.end());
965 return Internal::Call::make(t, Internal::Call::memoize_expr,
966 args, Internal::Call::PureIntrinsic);
967 }
968
969 } // namespace Internal
970
// Build an Expr for a fast approximate natural log of x, which must be
// of type Float(32). The input is range-reduced to
// reduced * 2^exponent, a polynomial in (reduced - 1) is evaluated,
// and exponent * ln(2) is added.
Expr fast_log(const Expr &x) {
    user_assert(x.type() == Float(32)) << "fast_log only works for Float(32)";

    // Polynomial coefficients, highest order first.
    float poly[] = {
        0.07640318789187280912f,
        -0.16252961013874300811f,
        0.20625219040645212387f,
        -0.25110261010892864775f,
        0.33320464908377461777f,
        -0.49997513376789826101f,
        1.0f,
        0.0f};
    const int num_terms = sizeof(poly) / sizeof(poly[0]);

    Expr mantissa, power_of_two;
    range_reduce_log(x, &mantissa, &power_of_two);

    Expr shifted = mantissa - 1.0f;
    Expr approx = evaluate_polynomial(shifted, poly, num_terms);
    approx = approx + cast<float>(power_of_two) * logf(2);
    return common_subexpression_elimination(approx);
}
994
995 namespace {
996
997 // A vectorizable sine and cosine implementation. Based on syrah fast vector math
998 // https://github.com/boulos/syrah/blob/master/src/include/syrah/FixedVectorMath.h#L55
// Shared implementation of fast_sin and fast_cos. Builds an Expr for
// sine (is_sin == true) or cosine (is_sin == false) of x_full.
//
// The angle is reduced by multiples of pi/2. The multiple k, modulo 4,
// decides whether the sine or the cosine polynomial is evaluated on
// the reduced angle, and whether the result is negated.
Expr fast_sin_cos(const Expr &x_full, bool is_sin) {
    const float two_over_pi = 0.636619746685028076171875f;
    const float pi_over_two = 1.57079637050628662109375f;
    // k = floor(x_full / (pi/2)), kept both as a float and as an int.
    Expr scaled = x_full * two_over_pi;
    Expr k_real = floor(scaled);
    Expr k = cast<int>(k_real);
    Expr k_mod4 = k % 4;
    // True when the cosine polynomial is the one to evaluate: for sine
    // when k mod 4 is 1 or 3, for cosine when it is 0 or 2.
    Expr sin_usecos = is_sin ? ((k_mod4 == 1) || (k_mod4 == 3)) : ((k_mod4 == 0) || (k_mod4 == 2));
    // True when the result must be negated.
    Expr flip_sign = is_sin ? (k_mod4 > 1) : ((k_mod4 == 1) || (k_mod4 == 2));

    // Reduce the angle modulo pi/2.
    Expr x = x_full - k_real * pi_over_two;

    // Coefficients of the sine polynomial, applied to powers of x*x.
    const float sin_c2 = -0.16666667163372039794921875f;
    const float sin_c4 = 8.333347737789154052734375e-3;
    const float sin_c6 = -1.9842604524455964565277099609375e-4;
    const float sin_c8 = 2.760012648650445044040679931640625e-6;
    const float sin_c10 = -2.50293279435709337121807038784027099609375e-8;

    // Coefficients of the cosine polynomial, applied to powers of x*x.
    const float cos_c2 = -0.5f;
    const float cos_c4 = 4.166664183139801025390625e-2;
    const float cos_c6 = -1.388833043165504932403564453125e-3;
    const float cos_c8 = 2.47562347794882953166961669921875e-5;
    const float cos_c10 = -2.59630184018533327616751194000244140625e-7;

    // Both polynomials share the form outside * (1 + c2*x^2 + ... +
    // c10*x^10); outside is 1 for the cosine form and x for the sine
    // form. The coefficient set is chosen with selects.
    Expr outside = select(sin_usecos, 1, x);
    Expr c2 = select(sin_usecos, cos_c2, sin_c2);
    Expr c4 = select(sin_usecos, cos_c4, sin_c4);
    Expr c6 = select(sin_usecos, cos_c6, sin_c6);
    Expr c8 = select(sin_usecos, cos_c8, sin_c8);
    Expr c10 = select(sin_usecos, cos_c10, sin_c10);

    Expr x2 = x * x;
    Expr tri_func = outside * (x2 * (x2 * (x2 * (x2 * (x2 * c10 + c8) + c6) + c4) + c2) + 1);
    return select(flip_sign, -tri_func, tri_func);
}
1035
1036 } // namespace
1037
// Fast sine approximation: the sine variant of fast_sin_cos.
Expr fast_sin(const Expr &x_full) {
    const bool want_sine = true;
    return fast_sin_cos(x_full, want_sine);
}
1041
// Fast cosine approximation: the cosine variant of fast_sin_cos.
Expr fast_cos(const Expr &x_full) {
    const bool want_sine = false;
    return fast_sin_cos(x_full, want_sine);
}
1045
// Approximate exp for Float(32) inputs. x_full is split as
// k * log(2) + remainder with k = floor(x_full / log(2)); a fixed
// polynomial is evaluated at the remainder and the result is scaled by 2^k,
// which is assembled directly from exponent bits. The final expression is
// run through common subexpression elimination.
Expr fast_exp(const Expr &x_full) {
    user_assert(x_full.type() == Float(32)) << "fast_exp only works for Float(32)";

    const float ln2 = logf(2.0);

    Expr k_real = floor(x_full / ln2);
    Expr k = cast<int>(k_real);
    Expr remainder = x_full - k_real * ln2;

    // Polynomial coefficients, highest-order term first.
    float coeff[] = {
        0.01314350012789660196f,
        0.03668965196652099192f,
        0.16873890085469545053f,
        0.49970514590562437052f,
        1.0f,
        1.0f};
    const int num_terms = sizeof(coeff) / sizeof(coeff[0]);
    Expr approx = evaluate_polynomial(remainder, coeff, num_terms);

    // Compute 2^k: bias the exponent and clamp it to [0, 255].
    const int fpbias = 127;
    Expr biased_exponent = clamp(k + fpbias, 0, 255);

    // Shift the bits up into the exponent field and reinterpret the
    // result as a float.
    Expr two_to_the_k = reinterpret<float>(biased_exponent << 23);

    approx *= two_to_the_k;
    return common_subexpression_elimination(approx);
}
1074
print(const std::vector<Expr> & args)1075 Expr print(const std::vector<Expr> &args) {
1076 Expr combined_string = combine_strings(args);
1077
1078 // Call halide_print.
1079 Expr print_call =
1080 Internal::Call::make(Int(32), "halide_print",
1081 {combined_string}, Internal::Call::Extern);
1082
1083 // Return the first argument.
1084 Expr result =
1085 Internal::Call::make(args[0].type(), Internal::Call::return_second,
1086 {print_call, args[0]}, Internal::Call::PureIntrinsic);
1087 return result;
1088 }
1089
print_when(Expr condition,const std::vector<Expr> & args)1090 Expr print_when(Expr condition, const std::vector<Expr> &args) {
1091 Expr p = print(args);
1092 return Internal::Call::make(p.type(),
1093 Internal::Call::if_then_else,
1094 {std::move(condition), p, args[0]},
1095 Internal::Call::PureIntrinsic);
1096 }
1097
require(Expr condition,const std::vector<Expr> & args)1098 Expr require(Expr condition, const std::vector<Expr> &args) {
1099 user_assert(condition.defined()) << "Require of undefined condition.\n";
1100 user_assert(condition.type().is_bool()) << "Require condition must be a boolean type.\n";
1101 user_assert(args.at(0).defined()) << "Require of undefined value.\n";
1102
1103 Expr err = requirement_failed_error(condition, args);
1104
1105 return Internal::Call::make(args[0].type(),
1106 Internal::Call::require,
1107 {likely(std::move(condition)), args[0], std::move(err)},
1108 Internal::Call::PureIntrinsic);
1109 }
1110
saturating_cast(Type t,Expr e)1111 Expr saturating_cast(Type t, Expr e) {
1112 // For float to float, guarantee infinities are always pinned to range.
1113 if (t.is_float() && e.type().is_float()) {
1114 if (t.bits() < e.type().bits()) {
1115 e = cast(t, clamp(std::move(e), t.min(), t.max()));
1116 } else {
1117 e = clamp(cast(t, std::move(e)), t.min(), t.max());
1118 }
1119 } else if (e.type() != t) {
1120 // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n)
1121 if (e.type().is_float() && !t.is_float() && t.bits() >= e.type().bits()) {
1122 e = max(std::move(e), t.min()); // min values turn out to be always representable
1123
1124 // This line depends on t.max() rounding upward, which should always
1125 // be the case as it is one less than a representable value, thus
1126 // the one larger is always the closest.
1127 e = select(e >= cast(e.type(), t.max()), t.max(), cast(t, e));
1128 } else {
1129 Expr min_bound;
1130 if (!e.type().is_uint()) {
1131 min_bound = lossless_cast(e.type(), t.min());
1132 }
1133 Expr max_bound = lossless_cast(e.type(), t.max());
1134
1135 if (min_bound.defined() && max_bound.defined()) {
1136 e = clamp(std::move(e), min_bound, max_bound);
1137 } else if (min_bound.defined()) {
1138 e = max(std::move(e), min_bound);
1139 } else if (max_bound.defined()) {
1140 e = min(std::move(e), max_bound);
1141 }
1142 e = cast(t, std::move(e));
1143 }
1144 }
1145 return e;
1146 }
1147
select(Expr condition,Expr true_value,Expr false_value)1148 Expr select(Expr condition, Expr true_value, Expr false_value) {
1149 if (as_const_int(condition)) {
1150 // Why are you doing this? We'll preserve the select node until constant folding for you.
1151 condition = cast(Bool(), std::move(condition));
1152 }
1153
1154 // Coerce int literals to the type of the other argument
1155 if (as_const_int(true_value)) {
1156 true_value = cast(false_value.type(), std::move(true_value));
1157 }
1158 if (as_const_int(false_value)) {
1159 false_value = cast(true_value.type(), std::move(false_value));
1160 }
1161
1162 user_assert(condition.type().is_bool())
1163 << "The first argument to a select must be a boolean:\n"
1164 << " " << condition << " has type " << condition.type() << "\n";
1165
1166 user_assert(true_value.type() == false_value.type())
1167 << "The second and third arguments to a select do not have a matching type:\n"
1168 << " " << true_value << " has type " << true_value.type() << "\n"
1169 << " " << false_value << " has type " << false_value.type() << "\n";
1170
1171 return Internal::Select::make(std::move(condition), std::move(true_value), std::move(false_value));
1172 }
1173
// Element-wise select over three Tuples of identical size: element i of the
// result is select(condition[i], true_value[i], false_value[i]).
Tuple tuple_select(const Tuple &condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(condition.size() == true_value.size() && true_value.size() == false_value.size())
        << "tuple_select() requires all Tuples to have identical sizes.";

    const size_t count = condition.size();
    std::vector<Expr> picked;
    picked.reserve(count);
    for (size_t idx = 0; idx < count; ++idx) {
        picked.push_back(select(condition[idx], true_value[idx], false_value[idx]));
    }
    return Tuple(picked);
}
1183
// Select between two Tuples of identical size using a single condition:
// element i of the result is select(condition, true_value[i], false_value[i]).
Tuple tuple_select(const Expr &condition, const Tuple &true_value, const Tuple &false_value) {
    user_assert(true_value.size() == false_value.size())
        << "tuple_select() requires all Tuples to have identical sizes.";

    const size_t count = true_value.size();
    std::vector<Expr> picked;
    picked.reserve(count);
    for (size_t idx = 0; idx < count; ++idx) {
        picked.push_back(select(condition, true_value[idx], false_value[idx]));
    }
    return Tuple(picked);
}
1193
// Pick one of `values` by index: the result is a chain of selects comparing
// `id` against 0, 1, ..., size-2, with the last value as the fall-through.
// Requires at least two values, all of the same type.
Expr mux(const Expr &id, const std::vector<Expr> &values) {
    user_assert(values.size() >= 2) << "mux() only accepts values with size >= 2.\n";

    // Check if all the values have the same type.
    const Type t = values.front().type();
    for (size_t i = 1; i < values.size(); ++i) {
        user_assert(values[i].type() == t) << "mux() requires all the values to have the same type.";
    }

    // Build the chain from the back so that index 0 ends up outermost.
    const int last = (int)values.size() - 1;
    Expr chain = values[last];
    for (int idx = last - 1; idx >= 0; --idx) {
        chain = select(id == idx, values[idx], chain);
    }
    return chain;
}
1207
// Tuple overload: forwards the Tuple's elements to the vector overload.
Expr mux(const Expr &id, const Tuple &tup) {
    const std::vector<Expr> &elements = tup.as_vector();
    return mux(id, elements);
}
1211
// Initializer-list overload: copies the list into a vector and forwards to
// the vector overload.
Expr mux(const Expr &id, const std::initializer_list<Expr> &values) {
    std::vector<Expr> as_vector(values);
    return mux(id, as_vector);
}
1215
// Build an unsafe_promise_clamped intrinsic around `value`. A defined bound
// is passed through lossless_cast to value's type; an undefined bound is
// replaced by the corresponding limit of value's type.
Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
    user_assert(value.defined()) << "unsafe_promise_clamped with undefined value.\n";

    const Type value_type = value.type();

    // Min and max are allowed to be undefined with the meaning of no bound on that side.
    Expr lower;
    if (min.defined()) {
        lower = lossless_cast(value_type, min);
    } else {
        lower = value_type.min();
    }

    Expr upper;
    if (max.defined()) {
        upper = lossless_cast(value_type, max);
    } else {
        upper = value_type.max();
    }

    return Internal::Call::make(value_type,
                                Internal::Call::unsafe_promise_clamped,
                                {value, lower, upper},
                                Internal::Call::Intrinsic);
}
1228
1229 namespace Internal {
promise_clamped(const Expr & value,const Expr & min,const Expr & max)1230 Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max) {
1231 internal_assert(value.defined()) << "promise_clamped with undefined value.\n";
1232 Expr n_min_val = min.defined() ? lossless_cast(value.type(), min) : value.type().min();
1233 Expr n_max_val = max.defined() ? lossless_cast(value.type(), max) : value.type().max();
1234
1235 // Min and max are allowed to be undefined with the meaning of no bound on that side.
1236 return Internal::Call::make(value.type(),
1237 Internal::Call::promise_clamped,
1238 {value, n_min_val, n_max_val},
1239 Internal::Call::Intrinsic);
1240 }
1241 } // namespace Internal
1242
operator +(Expr a,Expr b)1243 Expr operator+(Expr a, Expr b) {
1244 user_assert(a.defined() && b.defined()) << "operator+ of undefined Expr\n";
1245 Internal::match_types(a, b);
1246 return Internal::Add::make(std::move(a), std::move(b));
1247 }
1248
operator +(Expr a,int b)1249 Expr operator+(Expr a, int b) {
1250 user_assert(a.defined()) << "operator+ of undefined Expr\n";
1251 Type t = a.type();
1252 Internal::check_representable(t, b);
1253 return Internal::Add::make(std::move(a), Internal::make_const(t, b));
1254 }
1255
operator +(int a,Expr b)1256 Expr operator+(int a, Expr b) {
1257 user_assert(b.defined()) << "operator+ of undefined Expr\n";
1258 Type t = b.type();
1259 Internal::check_representable(t, a);
1260 return Internal::Add::make(Internal::make_const(t, a), std::move(b));
1261 }
1262
operator +=(Expr & a,Expr b)1263 Expr &operator+=(Expr &a, Expr b) {
1264 user_assert(a.defined() && b.defined()) << "operator+= of undefined Expr\n";
1265 Type t = a.type();
1266 a = Internal::Add::make(std::move(a), cast(t, std::move(b)));
1267 return a;
1268 }
1269
operator -(Expr a,Expr b)1270 Expr operator-(Expr a, Expr b) {
1271 user_assert(a.defined() && b.defined()) << "operator- of undefined Expr\n";
1272 Internal::match_types(a, b);
1273 return Internal::Sub::make(std::move(a), std::move(b));
1274 }
1275
operator -(Expr a,int b)1276 Expr operator-(Expr a, int b) {
1277 user_assert(a.defined()) << "operator- of undefined Expr\n";
1278 Type t = a.type();
1279 Internal::check_representable(t, b);
1280 return Internal::Sub::make(std::move(a), Internal::make_const(t, b));
1281 }
1282
operator -(int a,Expr b)1283 Expr operator-(int a, Expr b) {
1284 user_assert(b.defined()) << "operator- of undefined Expr\n";
1285 Type t = b.type();
1286 Internal::check_representable(t, a);
1287 return Internal::Sub::make(Internal::make_const(t, a), std::move(b));
1288 }
1289
operator -(Expr a)1290 Expr operator-(Expr a) {
1291 user_assert(a.defined()) << "operator- of undefined Expr\n";
1292 Type t = a.type();
1293 return Internal::Sub::make(Internal::make_zero(t), std::move(a));
1294 }
1295
operator -=(Expr & a,Expr b)1296 Expr &operator-=(Expr &a, Expr b) {
1297 user_assert(a.defined() && b.defined()) << "operator-= of undefined Expr\n";
1298 Type t = a.type();
1299 a = Internal::Sub::make(std::move(a), cast(t, std::move(b)));
1300 return a;
1301 }
1302
operator *(Expr a,Expr b)1303 Expr operator*(Expr a, Expr b) {
1304 user_assert(a.defined() && b.defined()) << "operator* of undefined Expr\n";
1305 Internal::match_types(a, b);
1306 return Internal::Mul::make(std::move(a), std::move(b));
1307 }
1308
operator *(Expr a,int b)1309 Expr operator*(Expr a, int b) {
1310 user_assert(a.defined()) << "operator* of undefined Expr\n";
1311 Type t = a.type();
1312 Internal::check_representable(t, b);
1313 return Internal::Mul::make(std::move(a), Internal::make_const(t, b));
1314 }
1315
operator *(int a,Expr b)1316 Expr operator*(int a, Expr b) {
1317 user_assert(b.defined()) << "operator* of undefined Expr\n";
1318 Type t = b.type();
1319 Internal::check_representable(t, a);
1320 return Internal::Mul::make(Internal::make_const(t, a), std::move(b));
1321 }
1322
operator *=(Expr & a,Expr b)1323 Expr &operator*=(Expr &a, Expr b) {
1324 user_assert(a.defined() && b.defined()) << "operator*= of undefined Expr\n";
1325 Type t = a.type();
1326 a = Internal::Mul::make(std::move(a), cast(t, std::move(b)));
1327 return a;
1328 }
1329
operator /(Expr a,Expr b)1330 Expr operator/(Expr a, Expr b) {
1331 user_assert(a.defined() && b.defined()) << "operator/ of undefined Expr\n";
1332 Internal::match_types(a, b);
1333 return Internal::Div::make(std::move(a), std::move(b));
1334 }
1335
operator /=(Expr & a,Expr b)1336 Expr &operator/=(Expr &a, Expr b) {
1337 user_assert(a.defined() && b.defined()) << "operator/= of undefined Expr\n";
1338 Type t = a.type();
1339 a = Internal::Div::make(std::move(a), cast(t, std::move(b)));
1340 return a;
1341 }
1342
operator /(Expr a,int b)1343 Expr operator/(Expr a, int b) {
1344 user_assert(a.defined()) << "operator/ of undefined Expr\n";
1345 Type t = a.type();
1346 Internal::check_representable(t, b);
1347 return Internal::Div::make(std::move(a), Internal::make_const(t, b));
1348 }
1349
operator /(int a,Expr b)1350 Expr operator/(int a, Expr b) {
1351 user_assert(b.defined()) << "operator- of undefined Expr\n";
1352 Type t = b.type();
1353 Internal::check_representable(t, a);
1354 return Internal::Div::make(Internal::make_const(t, a), std::move(b));
1355 }
1356
operator %(Expr a,Expr b)1357 Expr operator%(Expr a, Expr b) {
1358 user_assert(a.defined() && b.defined()) << "operator% of undefined Expr\n";
1359 Internal::match_types(a, b);
1360 return Internal::Mod::make(std::move(a), std::move(b));
1361 }
1362
operator %(Expr a,int b)1363 Expr operator%(Expr a, int b) {
1364 user_assert(a.defined()) << "operator% of undefined Expr\n";
1365 Type t = a.type();
1366 Internal::check_representable(t, b);
1367 return Internal::Mod::make(std::move(a), Internal::make_const(t, b));
1368 }
1369
operator %(int a,Expr b)1370 Expr operator%(int a, Expr b) {
1371 user_assert(b.defined()) << "operator% of undefined Expr\n";
1372 Type t = b.type();
1373 Internal::check_representable(t, a);
1374 return Internal::Mod::make(Internal::make_const(t, a), std::move(b));
1375 }
1376
operator >(Expr a,Expr b)1377 Expr operator>(Expr a, Expr b) {
1378 user_assert(a.defined() && b.defined()) << "operator> of undefined Expr\n";
1379 Internal::match_types(a, b);
1380 return Internal::GT::make(std::move(a), std::move(b));
1381 }
1382
operator >(Expr a,int b)1383 Expr operator>(Expr a, int b) {
1384 user_assert(a.defined()) << "operator> of undefined Expr\n";
1385 Type t = a.type();
1386 Internal::check_representable(t, b);
1387 return Internal::GT::make(std::move(a), Internal::make_const(t, b));
1388 }
1389
operator >(int a,Expr b)1390 Expr operator>(int a, Expr b) {
1391 user_assert(b.defined()) << "operator> of undefined Expr\n";
1392 Type t = b.type();
1393 Internal::check_representable(t, a);
1394 return Internal::GT::make(Internal::make_const(t, a), std::move(b));
1395 }
1396
operator <(Expr a,Expr b)1397 Expr operator<(Expr a, Expr b) {
1398 user_assert(a.defined() && b.defined()) << "operator< of undefined Expr\n";
1399 Internal::match_types(a, b);
1400 return Internal::LT::make(std::move(a), std::move(b));
1401 }
1402
operator <(Expr a,int b)1403 Expr operator<(Expr a, int b) {
1404 user_assert(a.defined()) << "operator< of undefined Expr\n";
1405 Type t = a.type();
1406 Internal::check_representable(t, b);
1407 return Internal::LT::make(std::move(a), Internal::make_const(t, b));
1408 }
1409
operator <(int a,Expr b)1410 Expr operator<(int a, Expr b) {
1411 user_assert(b.defined()) << "operator< of undefined Expr\n";
1412 Type t = b.type();
1413 Internal::check_representable(t, a);
1414 return Internal::LT::make(Internal::make_const(t, a), std::move(b));
1415 }
1416
operator <=(Expr a,Expr b)1417 Expr operator<=(Expr a, Expr b) {
1418 user_assert(a.defined() && b.defined()) << "operator<= of undefined Expr\n";
1419 Internal::match_types(a, b);
1420 return Internal::LE::make(std::move(a), std::move(b));
1421 }
1422
operator <=(Expr a,int b)1423 Expr operator<=(Expr a, int b) {
1424 user_assert(a.defined()) << "operator<= of undefined Expr\n";
1425 Type t = a.type();
1426 Internal::check_representable(t, b);
1427 return Internal::LE::make(std::move(a), Internal::make_const(t, b));
1428 }
1429
operator <=(int a,Expr b)1430 Expr operator<=(int a, Expr b) {
1431 user_assert(b.defined()) << "operator<= of undefined Expr\n";
1432 Type t = b.type();
1433 Internal::check_representable(t, a);
1434 return Internal::LE::make(Internal::make_const(t, a), std::move(b));
1435 }
1436
operator >=(Expr a,Expr b)1437 Expr operator>=(Expr a, Expr b) {
1438 user_assert(a.defined() && b.defined()) << "operator>= of undefined Expr\n";
1439 Internal::match_types(a, b);
1440 return Internal::GE::make(std::move(a), std::move(b));
1441 }
1442
// Expr >= int: check_representable is run on the int against a's type, then
// the int becomes a constant of that type.
Expr operator>=(const Expr &a, int b) {
    user_assert(a.defined()) << "operator>= of undefined Expr\n";
    const Type lhs_type = a.type();
    Internal::check_representable(lhs_type, b);
    Expr rhs = Internal::make_const(lhs_type, b);
    return Internal::GE::make(a, std::move(rhs));
}
1449
// int >= Expr: check_representable is run on the int against b's type, then
// the int becomes a constant of that type.
Expr operator>=(int a, const Expr &b) {
    user_assert(b.defined()) << "operator>= of undefined Expr\n";
    const Type rhs_type = b.type();
    Internal::check_representable(rhs_type, a);
    Expr lhs = Internal::make_const(rhs_type, a);
    return Internal::GE::make(std::move(lhs), b);
}
1456
operator ==(Expr a,Expr b)1457 Expr operator==(Expr a, Expr b) {
1458 user_assert(a.defined() && b.defined()) << "operator== of undefined Expr\n";
1459 Internal::match_types(a, b);
1460 return Internal::EQ::make(std::move(a), std::move(b));
1461 }
1462
operator ==(Expr a,int b)1463 Expr operator==(Expr a, int b) {
1464 user_assert(a.defined()) << "operator== of undefined Expr\n";
1465 Type t = a.type();
1466 Internal::check_representable(t, b);
1467 return Internal::EQ::make(std::move(a), Internal::make_const(t, b));
1468 }
1469
operator ==(int a,Expr b)1470 Expr operator==(int a, Expr b) {
1471 user_assert(b.defined()) << "operator== of undefined Expr\n";
1472 Type t = b.type();
1473 Internal::check_representable(t, a);
1474 return Internal::EQ::make(Internal::make_const(t, a), std::move(b));
1475 }
1476
operator !=(Expr a,Expr b)1477 Expr operator!=(Expr a, Expr b) {
1478 user_assert(a.defined() && b.defined()) << "operator!= of undefined Expr\n";
1479 Internal::match_types(a, b);
1480 return Internal::NE::make(std::move(a), std::move(b));
1481 }
1482
operator !=(Expr a,int b)1483 Expr operator!=(Expr a, int b) {
1484 user_assert(a.defined()) << "operator!= of undefined Expr\n";
1485 Type t = a.type();
1486 Internal::check_representable(t, b);
1487 return Internal::NE::make(std::move(a), Internal::make_const(t, b));
1488 }
1489
operator !=(int a,Expr b)1490 Expr operator!=(int a, Expr b) {
1491 user_assert(b.defined()) << "operator!= of undefined Expr\n";
1492 Type t = b.type();
1493 Internal::check_representable(t, a);
1494 return Internal::NE::make(Internal::make_const(t, a), std::move(b));
1495 }
1496
operator &&(Expr a,Expr b)1497 Expr operator&&(Expr a, Expr b) {
1498 Internal::match_types(a, b);
1499 return Internal::And::make(std::move(a), std::move(b));
1500 }
1501
operator &&(Expr a,bool b)1502 Expr operator&&(Expr a, bool b) {
1503 internal_assert(a.defined()) << "operator&& of undefined Expr\n";
1504 internal_assert(a.type().is_bool()) << "operator&& of Expr of type " << a.type() << "\n";
1505 if (b) {
1506 return a;
1507 } else {
1508 return Internal::make_zero(a.type());
1509 }
1510 }
1511
operator &&(bool a,Expr b)1512 Expr operator&&(bool a, Expr b) {
1513 return std::move(b) && a;
1514 }
1515
operator ||(Expr a,Expr b)1516 Expr operator||(Expr a, Expr b) {
1517 Internal::match_types(a, b);
1518 return Internal::Or::make(std::move(a), std::move(b));
1519 }
1520
operator ||(Expr a,bool b)1521 Expr operator||(Expr a, bool b) {
1522 internal_assert(a.defined()) << "operator|| of undefined Expr\n";
1523 internal_assert(a.type().is_bool()) << "operator|| of Expr of type " << a.type() << "\n";
1524 if (b) {
1525 return Internal::make_one(a.type());
1526 } else {
1527 return a;
1528 }
1529 }
1530
operator ||(bool a,Expr b)1531 Expr operator||(bool a, Expr b) {
1532 return std::move(b) || a;
1533 }
1534
operator !(Expr a)1535 Expr operator!(Expr a) {
1536 return Internal::Not::make(std::move(a));
1537 }
1538
max(Expr a,Expr b)1539 Expr max(Expr a, Expr b) {
1540 user_assert(a.defined() && b.defined())
1541 << "max of undefined Expr\n";
1542 Internal::match_types(a, b);
1543 return Internal::Max::make(std::move(a), std::move(b));
1544 }
1545
max(Expr a,int b)1546 Expr max(Expr a, int b) {
1547 user_assert(a.defined()) << "max of undefined Expr\n";
1548 Type t = a.type();
1549 Internal::check_representable(t, b);
1550 return Internal::Max::make(std::move(a), Internal::make_const(t, b));
1551 }
1552
max(int a,Expr b)1553 Expr max(int a, Expr b) {
1554 user_assert(b.defined()) << "max of undefined Expr\n";
1555 Type t = b.type();
1556 Internal::check_representable(t, a);
1557 return Internal::Max::make(Internal::make_const(t, a), std::move(b));
1558 }
1559
min(Expr a,Expr b)1560 Expr min(Expr a, Expr b) {
1561 user_assert(a.defined() && b.defined())
1562 << "min of undefined Expr\n";
1563 Internal::match_types(a, b);
1564 return Internal::Min::make(std::move(a), std::move(b));
1565 }
1566
min(Expr a,int b)1567 Expr min(Expr a, int b) {
1568 user_assert(a.defined()) << "max of undefined Expr\n";
1569 Type t = a.type();
1570 Internal::check_representable(t, b);
1571 return Internal::Min::make(std::move(a), Internal::make_const(t, b));
1572 }
1573
min(int a,Expr b)1574 Expr min(int a, Expr b) {
1575 user_assert(b.defined()) << "max of undefined Expr\n";
1576 Type t = b.type();
1577 Internal::check_representable(t, a);
1578 return Internal::Min::make(Internal::make_const(t, a), std::move(b));
1579 }
1580
cast(Type t,Expr a)1581 Expr cast(Type t, Expr a) {
1582 user_assert(a.defined()) << "cast of undefined Expr\n";
1583 if (a.type() == t) {
1584 return a;
1585 }
1586
1587 if (t.is_handle() && !a.type().is_handle()) {
1588 user_error << "Can't cast \"" << a << "\" to a handle. "
1589 << "The only legal cast from scalar types to a handle is: "
1590 << "reinterpret(Handle(), cast<uint64_t>(" << a << "));\n";
1591 } else if (a.type().is_handle() && !t.is_handle()) {
1592 user_error << "Can't cast handle \"" << a << "\" to type " << t << ". "
1593 << "The only legal cast from handles to scalar types is: "
1594 << "reinterpret(UInt(64), " << a << ");\n";
1595 }
1596
1597 // Fold constants early
1598 if (const int64_t *i = as_const_int(a)) {
1599 return Internal::make_const(t, *i);
1600 }
1601 if (const uint64_t *u = as_const_uint(a)) {
1602 return Internal::make_const(t, *u);
1603 }
1604 if (const double *f = as_const_float(a)) {
1605 return Internal::make_const(t, *f);
1606 }
1607
1608 if (t.is_vector()) {
1609 if (a.type().is_scalar()) {
1610 return Internal::Broadcast::make(cast(t.element_of(), std::move(a)), t.lanes());
1611 } else if (const Internal::Broadcast *b = a.as<Internal::Broadcast>()) {
1612 internal_assert(b->lanes == t.lanes());
1613 return Internal::Broadcast::make(cast(t.element_of(), b->value), t.lanes());
1614 }
1615 }
1616 return Internal::Cast::make(t, std::move(a));
1617 }
1618
clamp(Expr a,const Expr & min_val,const Expr & max_val)1619 Expr clamp(Expr a, const Expr &min_val, const Expr &max_val) {
1620 user_assert(a.defined() && min_val.defined() && max_val.defined())
1621 << "clamp of undefined Expr\n";
1622 Expr n_min_val = lossless_cast(a.type(), min_val);
1623 user_assert(n_min_val.defined())
1624 << "Type mismatch in call to clamp. First argument ("
1625 << a << ") has type " << a.type() << ", but second argument ("
1626 << min_val << ") has type " << min_val.type() << ". Use an explicit cast.\n";
1627 Expr n_max_val = lossless_cast(a.type(), max_val);
1628 user_assert(n_max_val.defined())
1629 << "Type mismatch in call to clamp. First argument ("
1630 << a << ") has type " << a.type() << ", but third argument ("
1631 << max_val << ") has type " << max_val.type() << ". Use an explicit cast.\n";
1632 return Internal::Max::make(Internal::Min::make(std::move(a), std::move(n_max_val)), std::move(n_min_val));
1633 }
1634
abs(Expr a)1635 Expr abs(Expr a) {
1636 user_assert(a.defined())
1637 << "abs of undefined Expr\n";
1638 Type t = a.type();
1639 if (t.is_uint()) {
1640 user_warning << "Warning: abs of an unsigned type is a no-op\n";
1641 return a;
1642 }
1643 return Internal::Call::make(t.with_code(t.is_int() ? Type::UInt : t.code()),
1644 Internal::Call::abs, {std::move(a)}, Internal::Call::PureIntrinsic);
1645 }
1646
absd(Expr a,Expr b)1647 Expr absd(Expr a, Expr b) {
1648 user_assert(a.defined() && b.defined()) << "absd of undefined Expr\n";
1649 Internal::match_types(a, b);
1650 Type t = a.type();
1651
1652 if (t.is_float()) {
1653 // Floats can just use abs.
1654 return abs(std::move(a) - std::move(b));
1655 }
1656
1657 // The argument may be signed, but the return type is unsigned.
1658 return Internal::Call::make(t.with_code(t.is_int() ? Type::UInt : t.code()),
1659 Internal::Call::absd, {std::move(a), std::move(b)},
1660 Internal::Call::PureIntrinsic);
1661 }
1662
sin(Expr x)1663 Expr sin(Expr x) {
1664 user_assert(x.defined()) << "sin of undefined Expr\n";
1665 if (x.type() == Float(64)) {
1666 return Internal::Call::make(Float(64), "sin_f64", {std::move(x)}, Internal::Call::PureExtern);
1667 } else if (x.type() == Float(16)) {
1668 return Internal::Call::make(Float(16), "sin_f16", {std::move(x)}, Internal::Call::PureExtern);
1669 } else {
1670 return Internal::Call::make(Float(32), "sin_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1671 }
1672 }
1673
asin(Expr x)1674 Expr asin(Expr x) {
1675 user_assert(x.defined()) << "asin of undefined Expr\n";
1676 if (x.type() == Float(64)) {
1677 return Internal::Call::make(Float(64), "asin_f64", {std::move(x)}, Internal::Call::PureExtern);
1678 } else if (x.type() == Float(16)) {
1679 return Internal::Call::make(Float(16), "asin_f16", {std::move(x)}, Internal::Call::PureExtern);
1680 } else {
1681 return Internal::Call::make(Float(32), "asin_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1682 }
1683 }
1684
cos(Expr x)1685 Expr cos(Expr x) {
1686 user_assert(x.defined()) << "cos of undefined Expr\n";
1687 if (x.type() == Float(64)) {
1688 return Internal::Call::make(Float(64), "cos_f64", {std::move(x)}, Internal::Call::PureExtern);
1689 } else if (x.type() == Float(16)) {
1690 return Internal::Call::make(Float(16), "cos_f16", {std::move(x)}, Internal::Call::PureExtern);
1691 } else {
1692 return Internal::Call::make(Float(32), "cos_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1693 }
1694 }
1695
acos(Expr x)1696 Expr acos(Expr x) {
1697 user_assert(x.defined()) << "acos of undefined Expr\n";
1698 if (x.type() == Float(64)) {
1699 return Internal::Call::make(Float(64), "acos_f64", {std::move(x)}, Internal::Call::PureExtern);
1700 } else if (x.type() == Float(16)) {
1701 return Internal::Call::make(Float(16), "acos_f16", {std::move(x)}, Internal::Call::PureExtern);
1702 } else {
1703 return Internal::Call::make(Float(32), "acos_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1704 }
1705 }
1706
tan(Expr x)1707 Expr tan(Expr x) {
1708 user_assert(x.defined()) << "tan of undefined Expr\n";
1709 if (x.type() == Float(64)) {
1710 return Internal::Call::make(Float(64), "tan_f64", {std::move(x)}, Internal::Call::PureExtern);
1711 } else if (x.type() == Float(16)) {
1712 return Internal::Call::make(Float(16), "tan_f16", {std::move(x)}, Internal::Call::PureExtern);
1713 } else {
1714 return Internal::Call::make(Float(32), "tan_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1715 }
1716 }
1717
atan(Expr x)1718 Expr atan(Expr x) {
1719 user_assert(x.defined()) << "atan of undefined Expr\n";
1720 if (x.type() == Float(64)) {
1721 return Internal::Call::make(Float(64), "atan_f64", {std::move(x)}, Internal::Call::PureExtern);
1722 } else if (x.type() == Float(16)) {
1723 return Internal::Call::make(Float(16), "atan_f16", {std::move(x)}, Internal::Call::PureExtern);
1724 } else {
1725 return Internal::Call::make(Float(32), "atan_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1726 }
1727 }
1728
atan2(Expr y,Expr x)1729 Expr atan2(Expr y, Expr x) {
1730 user_assert(x.defined() && y.defined()) << "atan2 of undefined Expr\n";
1731
1732 if (y.type() == Float(64)) {
1733 x = cast<double>(x);
1734 return Internal::Call::make(Float(64), "atan2_f64", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1735 } else if (y.type() == Float(16)) {
1736 x = cast<float16_t>(x);
1737 return Internal::Call::make(Float(16), "atan2_f16", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1738 } else {
1739 y = cast<float>(y);
1740 x = cast<float>(x);
1741 return Internal::Call::make(Float(32), "atan2_f32", {std::move(y), std::move(x)}, Internal::Call::PureExtern);
1742 }
1743 }
1744
sinh(Expr x)1745 Expr sinh(Expr x) {
1746 user_assert(x.defined()) << "sinh of undefined Expr\n";
1747 if (x.type() == Float(64)) {
1748 return Internal::Call::make(Float(64), "sinh_f64", {std::move(x)}, Internal::Call::PureExtern);
1749 } else if (x.type() == Float(16)) {
1750 return Internal::Call::make(Float(16), "sinh_f16", {std::move(x)}, Internal::Call::PureExtern);
1751 } else {
1752 return Internal::Call::make(Float(32), "sinh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1753 }
1754 }
1755
asinh(Expr x)1756 Expr asinh(Expr x) {
1757 user_assert(x.defined()) << "asinh of undefined Expr\n";
1758 if (x.type() == Float(64)) {
1759 return Internal::Call::make(Float(64), "asinh_f64", {std::move(x)}, Internal::Call::PureExtern);
1760 } else if (x.type() == Float(16)) {
1761 return Internal::Call::make(Float(16), "asinh_f16", {std::move(x)}, Internal::Call::PureExtern);
1762 } else {
1763 return Internal::Call::make(Float(32), "asinh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1764 }
1765 }
1766
cosh(Expr x)1767 Expr cosh(Expr x) {
1768 user_assert(x.defined()) << "cosh of undefined Expr\n";
1769 if (x.type() == Float(64)) {
1770 return Internal::Call::make(Float(64), "cosh_f64", {std::move(x)}, Internal::Call::PureExtern);
1771 } else if (x.type() == Float(16)) {
1772 return Internal::Call::make(Float(16), "cosh_f16", {std::move(x)}, Internal::Call::PureExtern);
1773 } else {
1774 return Internal::Call::make(Float(32), "cosh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1775 }
1776 }
1777
acosh(Expr x)1778 Expr acosh(Expr x) {
1779 user_assert(x.defined()) << "acosh of undefined Expr\n";
1780 if (x.type() == Float(64)) {
1781 return Internal::Call::make(Float(64), "acosh_f64", {std::move(x)}, Internal::Call::PureExtern);
1782 } else if (x.type() == Float(16)) {
1783 return Internal::Call::make(Float(16), "acosh_f16", {std::move(x)}, Internal::Call::PureExtern);
1784 } else {
1785 return Internal::Call::make(Float(32), "acosh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1786 }
1787 }
1788
tanh(Expr x)1789 Expr tanh(Expr x) {
1790 user_assert(x.defined()) << "tanh of undefined Expr\n";
1791 if (x.type() == Float(64)) {
1792 return Internal::Call::make(Float(64), "tanh_f64", {std::move(x)}, Internal::Call::PureExtern);
1793 } else if (x.type() == Float(16)) {
1794 return Internal::Call::make(Float(16), "tanh_f16", {std::move(x)}, Internal::Call::PureExtern);
1795 } else {
1796 return Internal::Call::make(Float(32), "tanh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1797 }
1798 }
1799
atanh(Expr x)1800 Expr atanh(Expr x) {
1801 user_assert(x.defined()) << "atanh of undefined Expr\n";
1802 if (x.type() == Float(64)) {
1803 return Internal::Call::make(Float(64), "atanh_f64", {std::move(x)}, Internal::Call::PureExtern);
1804 } else if (x.type() == Float(16)) {
1805 return Internal::Call::make(Float(16), "atanh_f16", {std::move(x)}, Internal::Call::PureExtern);
1806 } else {
1807 return Internal::Call::make(Float(32), "atanh_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1808 }
1809 }
1810
sqrt(Expr x)1811 Expr sqrt(Expr x) {
1812 user_assert(x.defined()) << "sqrt of undefined Expr\n";
1813 if (x.type() == Float(64)) {
1814 return Internal::Call::make(Float(64), "sqrt_f64", {std::move(x)}, Internal::Call::PureExtern);
1815 } else if (x.type() == Float(16)) {
1816 return Internal::Call::make(Float(16), "sqrt_f16", {std::move(x)}, Internal::Call::PureExtern);
1817 } else {
1818 return Internal::Call::make(Float(32), "sqrt_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1819 }
1820 }
1821
// Returns sqrt(x * x + y * y), built from the Expr arithmetic operators and
// the sqrt helper.
Expr hypot(const Expr &x, const Expr &y) {
    Expr sum_of_squares = x * x + y * y;
    return sqrt(std::move(sum_of_squares));
}
1825
exp(Expr x)1826 Expr exp(Expr x) {
1827 user_assert(x.defined()) << "exp of undefined Expr\n";
1828 if (x.type() == Float(64)) {
1829 return Internal::Call::make(Float(64), "exp_f64", {std::move(x)}, Internal::Call::PureExtern);
1830 } else if (x.type() == Float(16)) {
1831 return Internal::Call::make(Float(16), "exp_f16", {std::move(x)}, Internal::Call::PureExtern);
1832 } else {
1833 return Internal::Call::make(Float(32), "exp_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1834 }
1835 }
1836
log(Expr x)1837 Expr log(Expr x) {
1838 user_assert(x.defined()) << "log of undefined Expr\n";
1839 if (x.type() == Float(64)) {
1840 return Internal::Call::make(Float(64), "log_f64", {std::move(x)}, Internal::Call::PureExtern);
1841 } else if (x.type() == Float(16)) {
1842 return Internal::Call::make(Float(16), "log_f16", {std::move(x)}, Internal::Call::PureExtern);
1843 } else {
1844 return Internal::Call::make(Float(32), "log_f32", {cast<float>(std::move(x))}, Internal::Call::PureExtern);
1845 }
1846 }
1847
pow(Expr x,Expr y)1848 Expr pow(Expr x, Expr y) {
1849 user_assert(x.defined() && y.defined()) << "pow of undefined Expr\n";
1850
1851 if (const int64_t *i = as_const_int(y)) {
1852 return raise_to_integer_power(std::move(x), *i);
1853 }
1854
1855 if (x.type() == Float(64)) {
1856 y = cast<double>(std::move(y));
1857 return Internal::Call::make(Float(64), "pow_f64", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1858 } else if (x.type() == Float(16)) {
1859 y = cast<float16_t>(std::move(y));
1860 return Internal::Call::make(Float(16), "pow_f16", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1861 } else {
1862 x = cast<float>(std::move(x));
1863 y = cast<float>(std::move(y));
1864 return Internal::Call::make(Float(32), "pow_f32", {std::move(x), std::move(y)}, Internal::Call::PureExtern);
1865 }
1866 }
1867
// Validates the argument and delegates to Internal::halide_erf.
// x must be defined and of type exactly Float(32); both conditions are
// enforced as user errors.
Expr erf(const Expr &x) {
    user_assert(x.defined()) << "erf of undefined Expr\n";
    user_assert(x.type() == Float(32)) << "erf only takes float arguments\n";
    Expr result = Internal::halide_erf(x);
    return result;
}
1873
fast_pow(Expr x,Expr y)1874 Expr fast_pow(Expr x, Expr y) {
1875 if (const int64_t *i = as_const_int(y)) {
1876 return raise_to_integer_power(std::move(x), *i);
1877 }
1878
1879 x = cast<float>(std::move(x));
1880 y = cast<float>(std::move(y));
1881 return select(x == 0.0f, 0.0f, fast_exp(fast_log(x) * std::move(y)));
1882 }
1883
fast_inverse(Expr x)1884 Expr fast_inverse(Expr x) {
1885 user_assert(x.type() == Float(32)) << "fast_inverse only takes float arguments\n";
1886 Type t = x.type();
1887 return Internal::Call::make(t, "fast_inverse_f32", {std::move(x)}, Internal::Call::PureExtern);
1888 }
1889
fast_inverse_sqrt(Expr x)1890 Expr fast_inverse_sqrt(Expr x) {
1891 user_assert(x.type() == Float(32)) << "fast_inverse_sqrt only takes float arguments\n";
1892 Type t = x.type();
1893 return Internal::Call::make(t, "fast_inverse_sqrt_f32", {std::move(x)}, Internal::Call::PureExtern);
1894 }
1895
floor(Expr x)1896 Expr floor(Expr x) {
1897 user_assert(x.defined()) << "floor of undefined Expr\n";
1898 Type t = x.type();
1899 if (t.element_of() == Float(64)) {
1900 return Internal::Call::make(t, "floor_f64", {std::move(x)}, Internal::Call::PureExtern);
1901 } else if (t.element_of() == Float(16)) {
1902 return Internal::Call::make(t, "floor_f16", {std::move(x)}, Internal::Call::PureExtern);
1903 } else {
1904 t = Float(32, t.lanes());
1905 if (t.is_int() || t.is_uint()) {
1906 // Already an integer
1907 return cast(t, std::move(x));
1908 } else {
1909 return Internal::Call::make(t, "floor_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1910 }
1911 }
1912 }
1913
ceil(Expr x)1914 Expr ceil(Expr x) {
1915 user_assert(x.defined()) << "ceil of undefined Expr\n";
1916 Type t = x.type();
1917 if (t.element_of() == Float(64)) {
1918 return Internal::Call::make(t, "ceil_f64", {std::move(x)}, Internal::Call::PureExtern);
1919 } else if (x.type().element_of() == Float(16)) {
1920 return Internal::Call::make(t, "ceil_f16", {std::move(x)}, Internal::Call::PureExtern);
1921 } else {
1922 t = Float(32, t.lanes());
1923 if (t.is_int() || t.is_uint()) {
1924 // Already an integer
1925 return cast(t, std::move(x));
1926 } else {
1927 return Internal::Call::make(t, "ceil_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1928 }
1929 }
1930 }
1931
round(Expr x)1932 Expr round(Expr x) {
1933 user_assert(x.defined()) << "round of undefined Expr\n";
1934 Type t = x.type();
1935 if (t.element_of() == Float(64)) {
1936 return Internal::Call::make(t, "round_f64", {std::move(x)}, Internal::Call::PureExtern);
1937 } else if (t.element_of() == Float(16)) {
1938 return Internal::Call::make(t, "round_f16", {std::move(x)}, Internal::Call::PureExtern);
1939 } else {
1940 t = Float(32, t.lanes());
1941 if (t.is_int() || t.is_uint()) {
1942 // Already an integer
1943 return cast(t, std::move(x));
1944 } else {
1945 return Internal::Call::make(t, "round_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1946 }
1947 }
1948 }
1949
trunc(Expr x)1950 Expr trunc(Expr x) {
1951 user_assert(x.defined()) << "trunc of undefined Expr\n";
1952 Type t = x.type();
1953 if (t.element_of() == Float(64)) {
1954 return Internal::Call::make(t, "trunc_f64", {std::move(x)}, Internal::Call::PureExtern);
1955 } else if (t.element_of() == Float(16)) {
1956 return Internal::Call::make(t, "trunc_f16", {std::move(x)}, Internal::Call::PureExtern);
1957 } else {
1958 t = Float(32, t.lanes());
1959 if (t.is_int() || t.is_uint()) {
1960 // Already an integer
1961 return cast(t, std::move(x));
1962 } else {
1963 return Internal::Call::make(t, "trunc_f32", {cast(t, std::move(x))}, Internal::Call::PureExtern);
1964 }
1965 }
1966 }
1967
is_nan(Expr x)1968 Expr is_nan(Expr x) {
1969 user_assert(x.defined()) << "is_nan of undefined Expr\n";
1970 user_assert(x.type().is_float()) << "is_nan only works for float";
1971 Type t = Bool(x.type().lanes());
1972 if (!is_const(x)) {
1973 x = strict_float(x);
1974 }
1975 if (x.type().element_of() == Float(64)) {
1976 return Internal::Call::make(t, "is_nan_f64", {std::move(x)}, Internal::Call::PureExtern);
1977 } else if (x.type().element_of() == Float(16)) {
1978 return Internal::Call::make(t, "is_nan_f16", {std::move(x)}, Internal::Call::PureExtern);
1979 } else {
1980 Type ft = Float(32, t.lanes());
1981 return Internal::Call::make(t, "is_nan_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
1982 }
1983 }
1984
is_inf(Expr x)1985 Expr is_inf(Expr x) {
1986 user_assert(x.defined()) << "is_inf of undefined Expr\n";
1987 user_assert(x.type().is_float()) << "is_inf only works for float";
1988 Type t = Bool(x.type().lanes());
1989 if (!is_const(x)) {
1990 x = strict_float(x);
1991 }
1992 if (x.type().element_of() == Float(64)) {
1993 return Internal::Call::make(t, "is_inf_f64", {std::move(x)}, Internal::Call::PureExtern);
1994 } else if (x.type().element_of() == Float(16)) {
1995 return Internal::Call::make(t, "is_inf_f16", {std::move(x)}, Internal::Call::PureExtern);
1996 } else {
1997 Type ft = Float(32, t.lanes());
1998 return Internal::Call::make(t, "is_inf_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
1999 }
2000 }
2001
is_finite(Expr x)2002 Expr is_finite(Expr x) {
2003 user_assert(x.defined()) << "is_finite of undefined Expr\n";
2004 user_assert(x.type().is_float()) << "is_finite only works for float";
2005 Type t = Bool(x.type().lanes());
2006 if (!is_const(x)) {
2007 x = strict_float(x);
2008 }
2009 if (x.type().element_of() == Float(64)) {
2010 return Internal::Call::make(t, "is_finite_f64", {std::move(x)}, Internal::Call::PureExtern);
2011 } else if (x.type().element_of() == Float(16)) {
2012 return Internal::Call::make(t, "is_finite_f16", {std::move(x)}, Internal::Call::PureExtern);
2013 } else {
2014 Type ft = Float(32, t.lanes());
2015 return Internal::Call::make(t, "is_finite_f32", {cast(ft, std::move(x))}, Internal::Call::PureExtern);
2016 }
2017 }
2018
// Returns x - trunc(x). Raises a user error if x is undefined.
Expr fract(const Expr &x) {
    user_assert(x.defined()) << "fract of undefined Expr\n";
    Expr truncated = trunc(x);
    return x - std::move(truncated);
}
2023
reinterpret(Type t,Expr e)2024 Expr reinterpret(Type t, Expr e) {
2025 user_assert(e.defined()) << "reinterpret of undefined Expr\n";
2026 int from_bits = e.type().bits() * e.type().lanes();
2027 int to_bits = t.bits() * t.lanes();
2028 user_assert(from_bits == to_bits)
2029 << "Reinterpret cast from type " << e.type()
2030 << " which has " << from_bits
2031 << " bits, to type " << t
2032 << " which has " << to_bits << " bits\n";
2033 return Internal::Call::make(t, Internal::Call::reinterpret, {std::move(e)}, Internal::Call::PureIntrinsic);
2034 }
2035
operator &(Expr x,Expr y)2036 Expr operator&(Expr x, Expr y) {
2037 match_types_bitwise(x, y, "bitwise and");
2038 Type t = x.type();
2039 return Internal::Call::make(t, Internal::Call::bitwise_and, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2040 }
2041
operator &(Expr x,int y)2042 Expr operator&(Expr x, int y) {
2043 Type t = x.type();
2044 Internal::check_representable(t, y);
2045 return Internal::Call::make(t, Internal::Call::bitwise_and, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2046 }
2047
operator &(int x,Expr y)2048 Expr operator&(int x, Expr y) {
2049 Type t = y.type();
2050 Internal::check_representable(t, x);
2051 return Internal::Call::make(t, Internal::Call::bitwise_and, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2052 }
2053
operator |(Expr x,Expr y)2054 Expr operator|(Expr x, Expr y) {
2055 match_types_bitwise(x, y, "bitwise or");
2056 Type t = x.type();
2057 return Internal::Call::make(t, Internal::Call::bitwise_or, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2058 }
2059
operator |(Expr x,int y)2060 Expr operator|(Expr x, int y) {
2061 Type t = x.type();
2062 Internal::check_representable(t, y);
2063 return Internal::Call::make(t, Internal::Call::bitwise_or, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2064 }
2065
operator |(int x,Expr y)2066 Expr operator|(int x, Expr y) {
2067 Type t = y.type();
2068 Internal::check_representable(t, x);
2069 return Internal::Call::make(t, Internal::Call::bitwise_or, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2070 }
2071
operator ^(Expr x,Expr y)2072 Expr operator^(Expr x, Expr y) {
2073 match_types_bitwise(x, y, "bitwise xor");
2074 Type t = x.type();
2075 return Internal::Call::make(t, Internal::Call::bitwise_xor, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2076 }
2077
operator ^(Expr x,int y)2078 Expr operator^(Expr x, int y) {
2079 Type t = x.type();
2080 Internal::check_representable(t, y);
2081 return Internal::Call::make(t, Internal::Call::bitwise_xor, {std::move(x), Internal::make_const(t, y)}, Internal::Call::PureIntrinsic);
2082 }
2083
operator ^(int x,Expr y)2084 Expr operator^(int x, Expr y) {
2085 Type t = y.type();
2086 Internal::check_representable(t, x);
2087 return Internal::Call::make(t, Internal::Call::bitwise_xor, {Internal::make_const(t, x), std::move(y)}, Internal::Call::PureIntrinsic);
2088 }
2089
operator ~(Expr x)2090 Expr operator~(Expr x) {
2091 user_assert(x.defined()) << "bitwise not of undefined Expr\n";
2092 user_assert(x.type().is_int() || x.type().is_uint())
2093 << "Argument to bitwise not must be an integer or unsigned integer";
2094 Type t = x.type();
2095 return Internal::Call::make(t, Internal::Call::bitwise_not, {std::move(x)}, Internal::Call::PureIntrinsic);
2096 }
2097
operator <<(Expr x,Expr y)2098 Expr operator<<(Expr x, Expr y) {
2099 if (y.type().is_vector() && !x.type().is_vector()) {
2100 x = Internal::Broadcast::make(x, y.type().lanes());
2101 }
2102 match_bits(x, y);
2103 Type t = x.type();
2104 return Internal::Call::make(t, Internal::Call::shift_left, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2105 }
2106
operator <<(Expr x,int y)2107 Expr operator<<(Expr x, int y) {
2108 Type t = Int(x.type().bits(), x.type().lanes());
2109 Internal::check_representable(t, y);
2110 return std::move(x) << Internal::make_const(t, y);
2111 }
2112
operator >>(Expr x,Expr y)2113 Expr operator>>(Expr x, Expr y) {
2114 if (y.type().is_vector() && !x.type().is_vector()) {
2115 x = Internal::Broadcast::make(x, y.type().lanes());
2116 }
2117 match_bits(x, y);
2118 Type t = x.type();
2119 return Internal::Call::make(t, Internal::Call::shift_right, {std::move(x), std::move(y)}, Internal::Call::PureIntrinsic);
2120 }
2121
operator >>(Expr x,int y)2122 Expr operator>>(Expr x, int y) {
2123 Type t = Int(x.type().bits(), x.type().lanes());
2124 Internal::check_representable(t, y);
2125 return std::move(x) >> Internal::make_const(t, y);
2126 }
2127
lerp(Expr zero_val,Expr one_val,Expr weight)2128 Expr lerp(Expr zero_val, Expr one_val, Expr weight) {
2129 user_assert(zero_val.defined()) << "lerp with undefined zero value";
2130 user_assert(one_val.defined()) << "lerp with undefined one value";
2131 user_assert(weight.defined()) << "lerp with undefined weight";
2132
2133 // We allow integer constants through, so that you can say things
2134 // like lerp(0, cast<uint8_t>(x), alpha) and produce an 8-bit
2135 // result. Note that lerp(0.0f, cast<uint8_t>(x), alpha) will
2136 // produce an error, as will lerp(0.0f, cast<double>(x),
2137 // alpha). lerp(0, cast<float>(x), alpha) is also allowed and will
2138 // produce a float result.
2139 if (as_const_int(zero_val)) {
2140 zero_val = cast(one_val.type(), std::move(zero_val));
2141 }
2142 if (as_const_int(one_val)) {
2143 one_val = cast(zero_val.type(), std::move(one_val));
2144 }
2145
2146 user_assert(zero_val.type() == one_val.type())
2147 << "Can't lerp between " << zero_val << " of type " << zero_val.type()
2148 << " and " << one_val << " of different type " << one_val.type() << "\n";
2149 user_assert((weight.type().is_uint() || weight.type().is_float()))
2150 << "A lerp weight must be an unsigned integer or a float, but "
2151 << "lerp weight " << weight << " has type " << weight.type() << ".\n";
2152 user_assert((zero_val.type().is_float() || zero_val.type().lanes() <= 32))
2153 << "Lerping between 64-bit integers is not supported\n";
2154 // Compilation error for constant weight that is out of range for integer use
2155 // as this seems like an easy to catch gotcha.
2156 if (!zero_val.type().is_float()) {
2157 const double *const_weight = as_const_float(weight);
2158 if (const_weight) {
2159 user_assert(*const_weight >= 0.0 && *const_weight <= 1.0)
2160 << "Floating-point weight for lerp with integer arguments is "
2161 << *const_weight << ", which is not in the range [0.0, 1.0].\n";
2162 }
2163 }
2164 Type t = zero_val.type();
2165 return Internal::Call::make(t, Internal::Call::lerp,
2166 {std::move(zero_val), std::move(one_val), std::move(weight)},
2167 Internal::Call::PureIntrinsic);
2168 }
2169
popcount(Expr x)2170 Expr popcount(Expr x) {
2171 user_assert(x.defined()) << "popcount of undefined Expr\n";
2172 Type t = x.type();
2173 user_assert(t.is_uint() || t.is_int())
2174 << "Argument to popcount must be an integer\n";
2175 return Internal::Call::make(t, Internal::Call::popcount,
2176 {std::move(x)}, Internal::Call::PureIntrinsic);
2177 }
2178
count_leading_zeros(Expr x)2179 Expr count_leading_zeros(Expr x) {
2180 user_assert(x.defined()) << "count leading zeros of undefined Expr\n";
2181 Type t = x.type();
2182 user_assert(t.is_uint() || t.is_int())
2183 << "Argument to count_leading_zeros must be an integer\n";
2184 return Internal::Call::make(t, Internal::Call::count_leading_zeros,
2185 {std::move(x)}, Internal::Call::PureIntrinsic);
2186 }
2187
count_trailing_zeros(Expr x)2188 Expr count_trailing_zeros(Expr x) {
2189 user_assert(x.defined()) << "count trailing zeros of undefined Expr\n";
2190 Type t = x.type();
2191 user_assert(t.is_uint() || t.is_int())
2192 << "Argument to count_trailing_zeros must be an integer\n";
2193 return Internal::Call::make(t, Internal::Call::count_trailing_zeros,
2194 {std::move(x)}, Internal::Call::PureIntrinsic);
2195 }
2196
div_round_to_zero(Expr x,Expr y)2197 Expr div_round_to_zero(Expr x, Expr y) {
2198 user_assert(x.defined()) << "div_round_to_zero of undefined dividend\n";
2199 user_assert(y.defined()) << "div_round_to_zero of undefined divisor\n";
2200 Internal::match_types(x, y);
2201 if (x.type().is_uint()) {
2202 return std::move(x) / std::move(y);
2203 }
2204 user_assert(x.type().is_int()) << "First argument to div_round_to_zero is not an integer: " << x << "\n";
2205 user_assert(y.type().is_int()) << "Second argument to div_round_to_zero is not an integer: " << y << "\n";
2206 Type t = x.type();
2207 return Internal::Call::make(t, Internal::Call::div_round_to_zero,
2208 {std::move(x), std::move(y)},
2209 Internal::Call::Intrinsic);
2210 }
2211
mod_round_to_zero(Expr x,Expr y)2212 Expr mod_round_to_zero(Expr x, Expr y) {
2213 user_assert(x.defined()) << "mod_round_to_zero of undefined dividend\n";
2214 user_assert(y.defined()) << "mod_round_to_zero of undefined divisor\n";
2215 Internal::match_types(x, y);
2216 if (x.type().is_uint()) {
2217 return std::move(x) % std::move(y);
2218 }
2219 user_assert(x.type().is_int()) << "First argument to mod_round_to_zero is not an integer: " << x << "\n";
2220 user_assert(y.type().is_int()) << "Second argument to mod_round_to_zero is not an integer: " << y << "\n";
2221 Type t = x.type();
2222 return Internal::Call::make(t, Internal::Call::mod_round_to_zero,
2223 {std::move(x), std::move(y)},
2224 Internal::Call::Intrinsic);
2225 }
2226
random_float(Expr seed)2227 Expr random_float(Expr seed) {
2228 // Random floats get even IDs
2229 static std::atomic<int> counter{0};
2230 int id = (counter++) * 2;
2231
2232 std::vector<Expr> args;
2233 if (seed.defined()) {
2234 user_assert(seed.type() == Int(32))
2235 << "The seed passed to random_float must have type Int(32), but instead is "
2236 << seed << " of type " << seed.type() << "\n";
2237 args.push_back(std::move(seed));
2238 }
2239 args.emplace_back(id);
2240
2241 // This is (surprisingly) pure - it's a fixed psuedo-random
2242 // function of its inputs.
2243 return Internal::Call::make(Float(32), Internal::Call::random,
2244 args, Internal::Call::PureIntrinsic);
2245 }
2246
random_uint(Expr seed)2247 Expr random_uint(Expr seed) {
2248 // Random ints get odd IDs
2249 static std::atomic<int> counter{0};
2250 int id = (counter++) * 2 + 1;
2251
2252 std::vector<Expr> args;
2253 if (seed.defined()) {
2254 user_assert(seed.type() == Int(32) || seed.type() == UInt(32))
2255 << "The seed passed to random_int must have type Int(32) or UInt(32), but instead is "
2256 << seed << " of type " << seed.type() << "\n";
2257 args.push_back(std::move(seed));
2258 }
2259 args.emplace_back(id);
2260
2261 return Internal::Call::make(UInt(32), Internal::Call::random,
2262 args, Internal::Call::PureIntrinsic);
2263 }
2264
random_int(Expr seed)2265 Expr random_int(Expr seed) {
2266 return cast<int32_t>(random_uint(std::move(seed)));
2267 }
2268
likely(Expr e)2269 Expr likely(Expr e) {
2270 Type t = e.type();
2271 return Internal::Call::make(t, Internal::Call::likely,
2272 {std::move(e)}, Internal::Call::PureIntrinsic);
2273 }
2274
likely_if_innermost(Expr e)2275 Expr likely_if_innermost(Expr e) {
2276 Type t = e.type();
2277 return Internal::Call::make(t, Internal::Call::likely_if_innermost,
2278 {std::move(e)}, Internal::Call::PureIntrinsic);
2279 }
2280
strict_float(Expr e)2281 Expr strict_float(Expr e) {
2282 Type t = e.type();
2283 return Internal::Call::make(t, Internal::Call::strict_float,
2284 {std::move(e)}, Internal::Call::PureIntrinsic);
2285 }
2286
undef(Type t)2287 Expr undef(Type t) {
2288 return Internal::Call::make(t, Internal::Call::undef,
2289 std::vector<Expr>(),
2290 Internal::Call::PureIntrinsic);
2291 }
2292
2293 } // namespace Halide
2294